@HUST-SuWB
2015-12-24T07:38:10.000000Z
字数 6657
阅读 971
项目实战
分类算法是解决分类问题的方法,是数据挖掘、机器学习和模式识别中一个重要的研究领域。分类算法通过对已知类别训练集的分析,从中发现分类规则,以此预测新数据的类别。分类算法的应用非常广泛,银行中风险评估、客户类别分类、文本检索和搜索引擎分类、安全领域中的入侵检测以及软件项目中的应用等等。
本例使用的分类算法是逻辑回归。逻辑回归是一种分类方法:先对所有特征做归一化处理,再用梯度下降法求解线性模型的参数,最后通过 sigmoid 函数把线性组合的结果映射为类别概率。
我的需求是基于项目结项的历史数据,建立项目结项的预测模型,预测具体项目是否能结项。通过对具体数据的分析,最后提炼出以下几个维度的数据作为预测因子,预测的目标变量是结项年数YEARS。
| 维度名 | 简介 |
|---|---|
| UNIVTYPE | 学校结构类型(部属/地方) |
| UNIVCATEGORY | 学校性质类型(综合/理工等) |
| LASTDEGREE | 负责人学位 |
| TITLE | 负责人职称的级别 |
| GENDER | 负责人性别 |
| AGEGROUP | 负责人年龄段 |
| PASSMID | 此项目是否通过中检 |
import java.io.BufferedReader;import java.io.File;import java.io.FileOutputStream;import java.io.FileReader;import java.io.IOException;import java.io.OutputStream;import java.util.ArrayList;import java.util.HashMap;import java.util.HashSet;import java.util.List;import java.util.Locale;import java.util.Map;import org.apache.mahout.classifier.sgd.CsvRecordFactory;import org.apache.mahout.classifier.sgd.LogisticModelParameters;import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;import org.apache.mahout.math.RandomAccessSparseVector;import org.apache.mahout.math.SequentialAccessSparseVector;import org.apache.mahout.math.Vector;import tool.ClassifierHelper;import tool.CsvTool;import com.google.common.collect.Lists;import com.google.common.io.Closeables;/*** 分类预测测试* @author suwb* @since 2015-12-17**/public class Classification {private static final String TRAINFILE = "F:\\source\\classification\\instp_end_train.csv";private static final String MODELFILE = "F:\\source\\classification\\instp_end.model";private static final String PREDICTFILE = "F:\source\\classification\\instp_end_predict.csv";private static final int NUMBER_OF_ARGUMENTS = 7;private static final int PASSES = 100;//训练数据/预测数据public void prepareData(int dataType){String sql = "SQL";if (dataType == 1) {//训练集sql += "SQL";}else if (dataType == 2) {//预测集sql += "SQL";}List<Object[]> dataList = dao.queryBySql(sql);//数据集String[] datas = null;HashSet sets = new HashSet();// 数据处理略(清洗、转换)String[] trainHeader = {"PROJECTID", "PROJECTNAME", "PROJECTYEAR", "UNIVTYPE", "UNIVCATEGORY", "LASTDEGREE", "TITLE","GENDER", "AGEGROUP", "PASSMID", "YEARS"};String[] predictHeader = {"PROJECTID", "PROJECTNAME", "PROJECTYEAR", "UNIVTYPE", "UNIVCATEGORY", "LASTDEGREE", "TITLE","GENDER", "AGEGROUP", "PASSMID"};if (dataType == 1) {CsvTool.writeCsv(TRAINFILE, trainHeader, dataList);}else CsvTool.writeCsv(PREDICTFILE, predictHeader, dataList);System.out.println(sets.toString());}public void train() throws 
Exception{LogisticModelParameters lmp = new LogisticModelParameters();lmp.setTargetVariable("YEARS"); //使用指定的变量作为目标(这里是YEARS)。lmp.setMaxTargetCategories(12); //目标变量的数量(这里是是结项年数)。lmp.setNumFeatures(NUMBER_OF_ARGUMENTS); //设置用于构建模型的特征向量大小,当输入为text-like类型的值时,大的值是更好的。lmp.setUseBias(true); //Eliminates the intercept term (a built-in constant predictor variable) from the model. Occasionally this is a good idea, but generally it isn’t since the SGD learning algorithm can usually eliminate the intercept term if warranted.是否使用常量bias,默认为1lmp.setLambda(0);lmp.setLearningRate(50);List<String> typeList = Lists.newArrayList(); //预测变量的类型,只能是 numeric, word, or text中的一种.List<String> predictorList = Lists.newArrayList();//指定预测因子(变量)的名称。typeList.add("numeric");predictorList.add("UNIVTYPE");typeList.add("numeric");predictorList.add("UNIVCATEGORY");typeList.add("numeric");predictorList.add("LASTDEGREE");typeList.add("numeric");predictorList.add("TITLE");typeList.add("numeric");predictorList.add("GENDER");typeList.add("numeric");predictorList.add("AGEGROUP");typeList.add("numeric");predictorList.add("PASSMID");lmp.setTypeMap(predictorList, typeList);//设置预测变量(因子)的类型,用于后面读取CSV中的数据.CsvRecordFactory csv = lmp.getCsvRecordFactory();//csv文件处理器的初始化OnlineLogisticRegression lr = lmp.createRegression();//生成预测模型,此处为空for (int pass = 0; pass < PASSES; pass++) {BufferedReader in = ClassifierHelper.open(TRAINFILE);//从输入的预测集文件读取数据;try {// 读取标题行,第一行变量名csv.firstLine(in.readLine());// 读取下一行,数据行第一行String line = in.readLine();while (line != null) {Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());String[] values = line.split(",");// update modelcsv.processLine(line, input);lr.train(Integer.valueOf(values[10]), input);line = in.readLine();}}finally {Closeables.close(in, true);}}// 输出模型OutputStream modelOutput = new FileOutputStream(MODELFILE);try {lmp.saveTo(modelOutput);} finally {Closeables.close(modelOutput, false);}// 
模型解析System.out.println(lmp.getNumFeatures());System.out.println(lmp.getTargetVariable() + " ~ ");String sep = "";for (String v : csv.getTraceDictionary().keySet()) {double weight = ClassifierHelper.predictorWeight(lr, 0, csv, v);if (weight != 0) {System.out.printf(Locale.SIMPLIFIED_CHINESE, "%s%.3f*%s", sep, weight, v);sep = " + ";}}System.out.printf("%n");}public void predict() throws IOException{LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(MODELFILE));CsvRecordFactory csv = lmp.getCsvRecordFactory();OnlineLogisticRegression lr = lmp.createRegression();BufferedReader in = ClassifierHelper.open(PREDICTFILE);String line = in.readLine();csv.firstLine(line);line = in.readLine();while (line != null) {Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());csv.processLine(line, v);Vector a = lr.classify(v);//调用模型进行分类预测,返回的score值用于后续的评估int target = a.maxValueIndex();double score = a.get(target);System.out.println("预测结果为:" + target + "\n得分为:" + score);line = in.readLine();}}public void run() throws Exception{prepareData(1);prepareData(2);train();predict();}}
package tool;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.RecordFactory;

import com.google.common.base.Charsets;
import com.google.common.io.Resources;

/**
 * Helper utilities for the classification demo: opening input files and
 * extracting per-predictor weights from a trained model.
 *
 * @author Batys
 */
public class ClassifierHelper {

    /**
     * Opens an input file as a UTF-8 buffered reader. Tries the classpath
     * first, then falls back to the file system.
     *
     * @param inputFile classpath resource name or file-system path
     * @return a buffered UTF-8 reader over the file; the caller must close it
     * @throws IOException if the file cannot be opened
     */
    public static BufferedReader open(String inputFile) throws IOException {
        InputStream in;
        try {
            // Classpath lookup; throws IllegalArgumentException when absent.
            in = Resources.getResource(inputFile).openStream();
        } catch (IllegalArgumentException e) {
            // Not on the classpath — treat the argument as a file path.
            in = new FileInputStream(new File(inputFile));
        }
        // StandardCharsets.UTF_8 replaces Guava's deprecated Charsets.UTF_8.
        return new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8));
    }

    /**
     * Sums the model coefficients for one predictor variable.
     *
     * <p>A predictor may be hashed into several feature-vector columns; its
     * effective weight is the sum of the betas of those columns.
     *
     * @param lr        trained model
     * @param row       target-category row of the beta matrix
     * @param csv       record factory whose trace dictionary maps predictor
     *                  names to feature columns; {@code predictor} must be a
     *                  key of that dictionary (otherwise this throws NPE, as
     *                  in the original)
     * @param predictor predictor variable name
     * @return the summed weight of {@code predictor} for {@code row}
     */
    public static double predictorWeight(OnlineLogisticRegression lr, int row,
            RecordFactory csv, String predictor) {
        double weight = 0;
        for (Integer column : csv.getTraceDictionary().get(predictor)) {
            weight += lr.getBeta().get(row, column);
        }
        return weight;
    }
}
核心代码就是生产模型的过程,要注意对预测因子的类型进行限定,我这里把所有的原始文本型的数据全部转换成了类别型,个人觉得对模型的准确率有一定的帮助。
最后的模型解析出来的结果如下
4.617*AGEGROUP + -0.372*GENDER + -8.158*Intercept Term + -0.372*LASTDEGREE + -3.476*PASSMID + 3.066*TITLE + -8.158*UNIVCATEGORY + 4.617*UNIVTYPE
提供一个有意义的参考资料:
http://blog.trifork.com/2014/02/04/an-introduction-to-mahouts-logistic-regression-sgd-classifier/
当然,Mahout本身提供example,大家可以在源码里找一找,或者看看Mahout的github:
https://github.com/apache/mahout