[关闭]
@HUST-SuWB 2015-12-24T07:38:10.000000Z 字数 6657 阅读 834

基于Mahout的分类实例

项目实战


基本概念

分类算法是解决分类问题的方法,是数据挖掘、机器学习和模式识别中一个重要的研究领域。分类算法通过对已知类别训练集的分析,从中发现分类规则,以此预测新数据的类别。分类算法的应用非常广泛,银行中风险评估、客户类别分类、文本检索和搜索引擎分类、安全领域中的入侵检测以及软件项目中的应用等等。
本例使用的分类算法是逻辑回归。逻辑回归是将所有特征经过归一化处理,再用梯度下降法,求线性方程的解的一种分类方法。

需求分析

我的需求是基于项目结项的历史数据,建立项目结项的预测模型,预测具体项目是否能结项。通过对具体数据的分析,最后提炼出以下几个维度的数据作为预测因子,预测的目标变量是结项年数YEARS

维度名 简介
UNIVTYPE 学校结构类型(部署/地方)
UNIVCATEGORY 学校性质类型(综合/理工等)
LASTDEGREE 负责人学位
TITLE 负责人职称的级别
GENDER 负责人性别
AGEGROUP 负责人年龄段
PASSMID 此项目是否通过中检

代码实例

  1. import java.io.BufferedReader;
  2. import java.io.File;
  3. import java.io.FileOutputStream;
  4. import java.io.FileReader;
  5. import java.io.IOException;
  6. import java.io.OutputStream;
  7. import java.util.ArrayList;
  8. import java.util.HashMap;
  9. import java.util.HashSet;
  10. import java.util.List;
  11. import java.util.Locale;
  12. import java.util.Map;
  13. import org.apache.mahout.classifier.sgd.CsvRecordFactory;
  14. import org.apache.mahout.classifier.sgd.LogisticModelParameters;
  15. import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
  16. import org.apache.mahout.math.RandomAccessSparseVector;
  17. import org.apache.mahout.math.SequentialAccessSparseVector;
  18. import org.apache.mahout.math.Vector;
  19. import tool.ClassifierHelper;
  20. import tool.CsvTool;
  21. import com.google.common.collect.Lists;
  22. import com.google.common.io.Closeables;
  23. /**
  24. * 分类预测测试
  25. * @author suwb
  26. * @since 2015-12-17
  27. *
  28. */
  29. public class Classification {
  30. private static final String TRAINFILE = "F:\\source\\classification\\instp_end_train.csv";
  31. private static final String MODELFILE = "F:\\source\\classification\\instp_end.model";
  32. private static final String PREDICTFILE = "F:\source\\classification\\instp_end_predict.csv";
  33. private static final int NUMBER_OF_ARGUMENTS = 7;
  34. private static final int PASSES = 100;
  35. //训练数据/预测数据
  36. public void prepareData(int dataType){
  37. String sql = "SQL";
  38. if (dataType == 1) {//训练集
  39. sql += "SQL";
  40. }else if (dataType == 2) {//预测集
  41. sql += "SQL";
  42. }
  43. List<Object[]> dataList = dao.queryBySql(sql);//数据集
  44. String[] datas = null;
  45. HashSet sets = new HashSet();
  46. // 数据处理略(清洗、转换)
  47. String[] trainHeader = {"PROJECTID", "PROJECTNAME", "PROJECTYEAR", "UNIVTYPE", "UNIVCATEGORY", "LASTDEGREE", "TITLE",
  48. "GENDER", "AGEGROUP", "PASSMID", "YEARS"};
  49. String[] predictHeader = {"PROJECTID", "PROJECTNAME", "PROJECTYEAR", "UNIVTYPE", "UNIVCATEGORY", "LASTDEGREE", "TITLE",
  50. "GENDER", "AGEGROUP", "PASSMID"};
  51. if (dataType == 1) {
  52. CsvTool.writeCsv(TRAINFILE, trainHeader, dataList);
  53. }else CsvTool.writeCsv(PREDICTFILE, predictHeader, dataList);
  54. System.out.println(sets.toString());
  55. }
  56. public void train() throws Exception{
  57. LogisticModelParameters lmp = new LogisticModelParameters();
  58. lmp.setTargetVariable("YEARS"); //使用指定的变量作为目标(这里是YEARS)。
  59. lmp.setMaxTargetCategories(12); //目标变量的数量(这里是是结项年数)。
  60. lmp.setNumFeatures(NUMBER_OF_ARGUMENTS); //设置用于构建模型的特征向量大小,当输入为text-like类型的值时,大的值是更好的。
  61. lmp.setUseBias(true); //Eliminates the intercept term (a built-in constant predictor variable) from the model. Occasionally this is a good idea, but generally it isn’t since the SGD learning algorithm can usually eliminate the intercept term if warranted.是否使用常量bias,默认为1
  62. lmp.setLambda(0);
  63. lmp.setLearningRate(50);
  64. List<String> typeList = Lists.newArrayList(); //预测变量的类型,只能是 numeric, word, or text中的一种.
  65. List<String> predictorList = Lists.newArrayList();//指定预测因子(变量)的名称。
  66. typeList.add("numeric");
  67. predictorList.add("UNIVTYPE");
  68. typeList.add("numeric");
  69. predictorList.add("UNIVCATEGORY");
  70. typeList.add("numeric");
  71. predictorList.add("LASTDEGREE");
  72. typeList.add("numeric");
  73. predictorList.add("TITLE");
  74. typeList.add("numeric");
  75. predictorList.add("GENDER");
  76. typeList.add("numeric");
  77. predictorList.add("AGEGROUP");
  78. typeList.add("numeric");
  79. predictorList.add("PASSMID");
  80. lmp.setTypeMap(predictorList, typeList);//设置预测变量(因子)的类型,用于后面读取CSV中的数据.
  81. CsvRecordFactory csv = lmp.getCsvRecordFactory();//csv文件处理器的初始化
  82. OnlineLogisticRegression lr = lmp.createRegression();//生成预测模型,此处为空
  83. for (int pass = 0; pass < PASSES; pass++) {
  84. BufferedReader in = ClassifierHelper.open(TRAINFILE);//从输入的预测集文件读取数据;
  85. try {
  86. // 读取标题行,第一行变量名
  87. csv.firstLine(in.readLine());
  88. // 读取下一行,数据行第一行
  89. String line = in.readLine();
  90. while (line != null) {
  91. Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
  92. String[] values = line.split(",");
  93. // update model
  94. csv.processLine(line, input);
  95. lr.train(Integer.valueOf(values[10]), input);
  96. line = in.readLine();
  97. }
  98. }finally {
  99. Closeables.close(in, true);
  100. }
  101. }
  102. // 输出模型
  103. OutputStream modelOutput = new FileOutputStream(MODELFILE);
  104. try {
  105. lmp.saveTo(modelOutput);
  106. } finally {
  107. Closeables.close(modelOutput, false);
  108. }
  109. // 模型解析
  110. System.out.println(lmp.getNumFeatures());
  111. System.out.println(lmp.getTargetVariable() + " ~ ");
  112. String sep = "";
  113. for (String v : csv.getTraceDictionary().keySet()) {
  114. double weight = ClassifierHelper.predictorWeight(lr, 0, csv, v);
  115. if (weight != 0) {
  116. System.out.printf(Locale.SIMPLIFIED_CHINESE, "%s%.3f*%s", sep, weight, v);
  117. sep = " + ";
  118. }
  119. }
  120. System.out.printf("%n");
  121. }
  122. public void predict() throws IOException{
  123. LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(MODELFILE));
  124. CsvRecordFactory csv = lmp.getCsvRecordFactory();
  125. OnlineLogisticRegression lr = lmp.createRegression();
  126. BufferedReader in = ClassifierHelper.open(PREDICTFILE);
  127. String line = in.readLine();
  128. csv.firstLine(line);
  129. line = in.readLine();
  130. while (line != null) {
  131. Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
  132. csv.processLine(line, v);
  133. Vector a = lr.classify(v);//调用模型进行分类预测,返回的score值用于后续的评估
  134. int target = a.maxValueIndex();
  135. double score = a.get(target);
  136. System.out.println("预测结果为:" + target + "\n得分为:" + score);
  137. line = in.readLine();
  138. }
  139. }
  140. public void run() throws Exception{
  141. prepareData(1);
  142. prepareData(2);
  143. train();
  144. predict();
  145. }
  146. }
  1. package tool;
  2. import java.io.BufferedReader;
  3. import java.io.File;
  4. import java.io.FileInputStream;
  5. import java.io.IOException;
  6. import java.io.InputStream;
  7. import java.io.InputStreamReader;
  8. import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
  9. import org.apache.mahout.classifier.sgd.RecordFactory;
  10. import com.google.common.base.Charsets;
  11. import com.google.common.io.Resources;
  12. /**
  13. * 分类预测帮助类
  14. * @author Batys
  15. *
  16. */
  17. public class ClassifierHelper {
  18. /**
  19. * 打开并阅读输入文档
  20. * @param inputFile
  21. * @return
  22. * @throws IOException
  23. */
  24. public static BufferedReader open(String inputFile) throws IOException {
  25. InputStream in;
  26. try {
  27. in = Resources.getResource(inputFile).openStream();
  28. } catch (IllegalArgumentException e) {
  29. in = new FileInputStream(new File(inputFile));
  30. }
  31. return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8));
  32. }
  33. /**
  34. * 获取预测变量的权重
  35. * @param lr 模型
  36. * @param row 行号
  37. * @param csv csv数据
  38. * @param predictor 预测变量名
  39. * @return predictor的权重值
  40. */
  41. public static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
  42. double weight = 0;
  43. for (Integer column : csv.getTraceDictionary().get(predictor)) {
  44. weight += lr.getBeta().get(row, column);
  45. }
  46. return weight;
  47. }
  48. }

代码解释

核心代码就是生产模型的过程,要注意对预测因子的类型进行限定,我这里把所有的原始文本型的数据全部转换成了类别型,个人觉得对模型的准确率有一定的帮助。
最后的模型解析出来的结果如下

  1. 4.617*AGEGROUP + -0.372*GENDER + -8.158*Intercept Term + -0.372*LASTDEGREE + -3.476*PASSMID + 3.066*TITLE + -8.158*UNIVCATEGORY + 4.617*UNIVTYPE

最后

提供一个有意义的参考资料:
http://blog.trifork.com/2014/02/04/an-introduction-to-mahouts-logistic-regression-sgd-classifier/
当然,Mahout本身提供example,大家可以在源码里找一找,或者看看Mahout的github:
https://github.com/apache/mahout

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注