@HUST-SuWB
2015-12-24T07:38:10.000000Z
字数 6657
阅读 834
项目实战
分类算法是解决分类问题的方法,是数据挖掘、机器学习和模式识别中一个重要的研究领域。分类算法通过对已知类别训练集的分析,从中发现分类规则,以此预测新数据的类别。分类算法的应用非常广泛,银行中风险评估、客户类别分类、文本检索和搜索引擎分类、安全领域中的入侵检测以及软件项目中的应用等等。
本例使用的分类算法是逻辑回归。逻辑回归是将所有特征经过归一化处理,再用梯度下降法,求线性方程的解的一种分类方法。
我的需求是基于项目结项的历史数据,建立项目结项的预测模型,预测具体项目是否能结项。通过对具体数据的分析,最后提炼出以下几个维度的数据作为预测因子,预测的目标变量是结项年数YEARS。
维度名 | 简介 |
---|---|
UNIVTYPE | 学校结构类型(部署/地方) |
UNIVCATEGORY | 学校性质类型(综合/理工等) |
LASTDEGREE | 负责人学位 |
TITLE | 负责人职称的级别 |
GENDER | 负责人性别 |
AGEGROUP | 负责人年龄段 |
PASSMID | 此项目是否通过中检 |
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.mahout.classifier.sgd.CsvRecordFactory;
import org.apache.mahout.classifier.sgd.LogisticModelParameters;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import tool.ClassifierHelper;
import tool.CsvTool;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
/**
* 分类预测测试
* @author suwb
* @since 2015-12-17
*
*/
public class Classification {
private static final String TRAINFILE = "F:\\source\\classification\\instp_end_train.csv";
private static final String MODELFILE = "F:\\source\\classification\\instp_end.model";
private static final String PREDICTFILE = "F:\source\\classification\\instp_end_predict.csv";
private static final int NUMBER_OF_ARGUMENTS = 7;
private static final int PASSES = 100;
//训练数据/预测数据
public void prepareData(int dataType){
String sql = "SQL";
if (dataType == 1) {//训练集
sql += "SQL";
}else if (dataType == 2) {//预测集
sql += "SQL";
}
List<Object[]> dataList = dao.queryBySql(sql);//数据集
String[] datas = null;
HashSet sets = new HashSet();
// 数据处理略(清洗、转换)
String[] trainHeader = {"PROJECTID", "PROJECTNAME", "PROJECTYEAR", "UNIVTYPE", "UNIVCATEGORY", "LASTDEGREE", "TITLE",
"GENDER", "AGEGROUP", "PASSMID", "YEARS"};
String[] predictHeader = {"PROJECTID", "PROJECTNAME", "PROJECTYEAR", "UNIVTYPE", "UNIVCATEGORY", "LASTDEGREE", "TITLE",
"GENDER", "AGEGROUP", "PASSMID"};
if (dataType == 1) {
CsvTool.writeCsv(TRAINFILE, trainHeader, dataList);
}else CsvTool.writeCsv(PREDICTFILE, predictHeader, dataList);
System.out.println(sets.toString());
}
public void train() throws Exception{
LogisticModelParameters lmp = new LogisticModelParameters();
lmp.setTargetVariable("YEARS"); //使用指定的变量作为目标(这里是YEARS)。
lmp.setMaxTargetCategories(12); //目标变量的数量(这里是是结项年数)。
lmp.setNumFeatures(NUMBER_OF_ARGUMENTS); //设置用于构建模型的特征向量大小,当输入为text-like类型的值时,大的值是更好的。
lmp.setUseBias(true); //Eliminates the intercept term (a built-in constant predictor variable) from the model. Occasionally this is a good idea, but generally it isn’t since the SGD learning algorithm can usually eliminate the intercept term if warranted.是否使用常量bias,默认为1
lmp.setLambda(0);
lmp.setLearningRate(50);
List<String> typeList = Lists.newArrayList(); //预测变量的类型,只能是 numeric, word, or text中的一种.
List<String> predictorList = Lists.newArrayList();//指定预测因子(变量)的名称。
typeList.add("numeric");
predictorList.add("UNIVTYPE");
typeList.add("numeric");
predictorList.add("UNIVCATEGORY");
typeList.add("numeric");
predictorList.add("LASTDEGREE");
typeList.add("numeric");
predictorList.add("TITLE");
typeList.add("numeric");
predictorList.add("GENDER");
typeList.add("numeric");
predictorList.add("AGEGROUP");
typeList.add("numeric");
predictorList.add("PASSMID");
lmp.setTypeMap(predictorList, typeList);//设置预测变量(因子)的类型,用于后面读取CSV中的数据.
CsvRecordFactory csv = lmp.getCsvRecordFactory();//csv文件处理器的初始化
OnlineLogisticRegression lr = lmp.createRegression();//生成预测模型,此处为空
for (int pass = 0; pass < PASSES; pass++) {
BufferedReader in = ClassifierHelper.open(TRAINFILE);//从输入的预测集文件读取数据;
try {
// 读取标题行,第一行变量名
csv.firstLine(in.readLine());
// 读取下一行,数据行第一行
String line = in.readLine();
while (line != null) {
Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
String[] values = line.split(",");
// update model
csv.processLine(line, input);
lr.train(Integer.valueOf(values[10]), input);
line = in.readLine();
}
}finally {
Closeables.close(in, true);
}
}
// 输出模型
OutputStream modelOutput = new FileOutputStream(MODELFILE);
try {
lmp.saveTo(modelOutput);
} finally {
Closeables.close(modelOutput, false);
}
// 模型解析
System.out.println(lmp.getNumFeatures());
System.out.println(lmp.getTargetVariable() + " ~ ");
String sep = "";
for (String v : csv.getTraceDictionary().keySet()) {
double weight = ClassifierHelper.predictorWeight(lr, 0, csv, v);
if (weight != 0) {
System.out.printf(Locale.SIMPLIFIED_CHINESE, "%s%.3f*%s", sep, weight, v);
sep = " + ";
}
}
System.out.printf("%n");
}
public void predict() throws IOException{
LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(MODELFILE));
CsvRecordFactory csv = lmp.getCsvRecordFactory();
OnlineLogisticRegression lr = lmp.createRegression();
BufferedReader in = ClassifierHelper.open(PREDICTFILE);
String line = in.readLine();
csv.firstLine(line);
line = in.readLine();
while (line != null) {
Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
csv.processLine(line, v);
Vector a = lr.classify(v);//调用模型进行分类预测,返回的score值用于后续的评估
int target = a.maxValueIndex();
double score = a.get(target);
System.out.println("预测结果为:" + target + "\n得分为:" + score);
line = in.readLine();
}
}
public void run() throws Exception{
prepareData(1);
prepareData(2);
train();
predict();
}
}
package tool;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.RecordFactory;
import com.google.common.base.Charsets;
import com.google.common.io.Resources;
/**
* 分类预测帮助类
* @author Batys
*
*/
public class ClassifierHelper {
/**
* 打开并阅读输入文档
* @param inputFile
* @return
* @throws IOException
*/
public static BufferedReader open(String inputFile) throws IOException {
InputStream in;
try {
in = Resources.getResource(inputFile).openStream();
} catch (IllegalArgumentException e) {
in = new FileInputStream(new File(inputFile));
}
return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8));
}
/**
* 获取预测变量的权重
* @param lr 模型
* @param row 行号
* @param csv csv数据
* @param predictor 预测变量名
* @return predictor的权重值
*/
public static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
double weight = 0;
for (Integer column : csv.getTraceDictionary().get(predictor)) {
weight += lr.getBeta().get(row, column);
}
return weight;
}
}
核心代码就是生产模型的过程,要注意对预测因子的类型进行限定,我这里把所有的原始文本型的数据全部转换成了类别型,个人觉得对模型的准确率有一定的帮助。
最后的模型解析出来的结果如下
4.617*AGEGROUP + -0.372*GENDER + -8.158*Intercept Term + -0.372*LASTDEGREE + -3.476*PASSMID + 3.066*TITLE + -8.158*UNIVCATEGORY + 4.617*UNIVTYPE
提供一个有意义的参考资料:
http://blog.trifork.com/2014/02/04/an-introduction-to-mahouts-logistic-regression-sgd-classifier/
当然,Mahout本身提供example,大家可以在源码里找一找,或者看看Mahout的github:
https://github.com/apache/mahout