[关闭]
@HUST-SuWB 2015-12-15T00:53:21.000000Z 字数 12307 阅读 593

基于Mahout的推荐实例

项目实战


基本概念

所谓推荐算法就是利用用户的一些行为,通过一些数学算法,推测出用户可能喜欢的东西。推荐算法的主要类型有基于内容的推荐、协同过滤推荐(基于用户、基于物品)、基于关联规则的推荐等。在Mahout的算法实现中,重点放在协同过滤推荐上,因为这种推荐方式是最通用的,理论上不受推荐对象的影响,可以用于各种事物的推荐。

需求分析

我的需求是为高校推荐他们后续可以合作的其他高校。基本思路是基于关联规则来推荐,即在历史数据中找出所有高校之间的关联度,以此关联度作为评分。当为某个具体的高校进行推荐的时候,推荐结果就是此高校没有合作过,但是是与跟此高校合作过的其他高校关联度最高的。
一个典型的输入数据如下

学校1 学校2 关联度
华中科技大学 武汉大学 10
武汉大学 华中科技大学 10
华中科技大学 武汉理工大学 8
武汉理工大学 华中科技大学 8

P.S. 给高校做推荐只是个例子,对于实际需求来说,给予学校推荐后续可以合作的高校并不合理,因为高校不是项目申报立项过程中的主体,主体是具体的人,所以人和人之间的关联关系放大到高校层面不具备指导意义。

代码实例

  1. import java.io.File;
  2. import java.io.IOException;
  3. import java.util.ArrayList;
  4. import java.util.Collection;
  5. import java.util.HashMap;
  6. import java.util.HashSet;
  7. import java.util.List;
  8. import java.util.Map;
  9. import java.util.Map.Entry;
  10. import java.util.Set;
  11. import org.apache.hadoop.conf.Configuration;
  12. import org.apache.hadoop.fs.FileSystem;
  13. import org.apache.hadoop.fs.Path;
  14. import org.apache.hadoop.io.SequenceFile;
  15. import org.apache.hadoop.io.Text;
  16. import org.apache.mahout.common.Pair;
  17. import org.apache.mahout.fpm.pfpgrowth.convertors.ContextStatusUpdater;
  18. import org.apache.mahout.fpm.pfpgrowth.convertors.SequenceFileOutputCollector;
  19. import org.apache.mahout.fpm.pfpgrowth.convertors.string.StringOutputConverter;
  20. import org.apache.mahout.fpm.pfpgrowth.convertors.string.TopKStringPatterns;
  21. import org.apache.mahout.fpm.pfpgrowth.fpgrowth.FPGrowth;
  22. import tool.CsvTool;
  23. /**
  24. * 关联分析测试
  25. * @author suwb
  26. * @since 2015-12-14
  27. */
  28. public class Association {
  29. private static int minSupport = 2;//设置最小支持度
  30. private static final String UNIVERSITY_PATH = "F:\\source\\association\\university_out.dat";
  31. private static final String CSV_PATH = "F:\\source\\association\\关联规则输出.csv";
  32. //准备源数据
  33. public List<Object[]> getData(){
  34. Dao jdbcDao = new Dao();
  35. //定义关联规则挖掘器的输入数据
  36. List<Object[]> dataItems = new ArrayList<Object[]>();
  37. //高校与高校之间的关联性
  38. //查询项目中有高校合作研究的情况,并转换为关联规则挖掘器的输入数据格式
  39. //所有立项项目所属高校
  40. List<Object[]> univs4Project = null;
  41. //所有立项项目项目成员所在高校
  42. List<Object[]> univs4Member = null;
  43. //所有立项项目所属高校
  44. String projectHql = "SQL";
  45. univs4Project = jdbcDao.queryBySql(projectHql);
  46. //所有立项项目项目成员所在高校
  47. String memberHql = "SQL";
  48. univs4Member = jdbcDao.queryBySql(memberHql);
  49. //项目id -> 高校集合的映射(set去重)
  50. Map<String, Set<String>> univMap = new HashMap<String, Set<String>>();
  51. //根据项目id,遍历项目所有成员的高校信息
  52. for (Object[] objs : univs4Member) {
  53. Set<String> univs = univMap.get(objs[0]);
  54. if (univs == null) {
  55. univs = new HashSet<String>();
  56. univMap.put((String)objs[0], univs);
  57. }
  58. univs.add((String)objs[1]);
  59. }
  60. //根据项目id,遍历项目的高校信息
  61. for (Object[] objs : univs4Project) {
  62. Set<String> univs = univMap.get(objs[0]);
  63. if (univs != null) {
  64. univs.add((String)objs[1]);
  65. }
  66. }
  67. //准备关联规则挖掘器的输入数据
  68. for (Entry<String, Set<String>> entry : univMap.entrySet()) {
  69. Set<String> univs = entry.getValue();
  70. dataItems.add(univs.toArray());
  71. }
  72. return dataItems;
  73. }
  74. //生成频繁模式,并序列化
  75. @SuppressWarnings("deprecation")
  76. public void getFrequentPatternFile() throws IOException{
  77. List<Object[]> dataItems = getData();
  78. //采用FP-bonsai pruning而实现更快的频繁模式增长(Frequent Pattern Growth)算法
  79. FPGrowth<String> fp = new FPGrowth<String>();
  80. // 所有事务集合
  81. Collection<Pair<List<String>, Long>> transactions = new ArrayList<Pair<List<String>, Long>>();
  82. // 构建transactions:pair事务集
  83. for (Object[] dataItem : dataItems) {
  84. List<String> list = new ArrayList<String>();
  85. for (int i = 0; i < dataItem.length; i++) {
  86. if (dataItem[i] != null) {
  87. list.add(String.valueOf(dataItem[i]));
  88. }
  89. }
  90. transactions.add(new Pair<List<String>, Long>(list, 1L));
  91. }
  92. //设置输出文件路径
  93. String tmpFilePath = UNIVERSITY_PATH;
  94. File tmpFile = new File(tmpFilePath);
  95. String tmpDirPath = null; //临时文件路径
  96. if (!tmpFile.exists()) {
  97. tmpDirPath = tmpFilePath.substring(0, tmpFilePath.indexOf("university_out.dat"));//从文件路径中截取文件夹得路径
  98. tmpFile = File.createTempFile("university_out", ".dat", new File(tmpDirPath));//在上一步的文件夹中创建新文件fpg_out.dat文件
  99. }
  100. Path path = new Path(tmpFile.getAbsolutePath());
  101. Configuration conf = new Configuration();
  102. FileSystem fs = FileSystem.get(conf);
  103. //构造序列化文件写入器
  104. SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, Text.class, TopKStringPatterns.class);
  105. // 在给定的事务流和最小支持度下,为每个属性生成前K频繁模
  106. fp.generateTopKFrequentPatterns(
  107. transactions.iterator(), //待挖掘的事务迭代器
  108. fp.generateFList(transactions.iterator(), (int) minSupport),
  109. minSupport, //最小支持度
  110. 1000, //各属性显示前K条(Number of top frequent patterns to keep)
  111. null,
  112. new StringOutputConverter(new SequenceFileOutputCollector<Text, TopKStringPatterns>(writer)),
  113. new ContextStatusUpdater(null));
  114. writer.close();
  115. }
  116. //将关联结果写入文件
  117. @SuppressWarnings("deprecation")
  118. public void writeAssociationToFile(){
  119. String tmpDirPath = UNIVERSITY_PATH.substring(0, UNIVERSITY_PATH.indexOf("university_out.dat"));//从文件路径中截取文件夹得路径
  120. File file = new File(tmpDirPath);
  121. File[] listFile = file.listFiles();
  122. Path path = null;
  123. for(File thisFile : listFile){
  124. if(thisFile.getName().contains("university") && !thisFile.getName().contains("crc")){
  125. path = new Path(thisFile.getAbsolutePath());
  126. }
  127. }
  128. Configuration conf = new Configuration();
  129. //调用mahout读取接口方法,从频繁模式库中读取频繁模式集
  130. List<Pair<String, TopKStringPatterns>> frequentPatterns = FPGrowth.readFrequentPattern(conf, path);
  131. Map<String, List<Object[]>> assoMap = new HashMap<String, List<Object[]>>();
  132. //对所有频繁模式进行遍历
  133. for (Pair<String, TopKStringPatterns> fps : frequentPatterns) {
  134. String key = fps.getFirst(); //如:key=华中科技大学
  135. TopKStringPatterns value = fps.getSecond();
  136. // System.out.println("key:" + key + " | value:" + value);
  137. List<Object[]> data = new ArrayList<Object[]>();
  138. //获取当前关键词key下所有模式
  139. List<Pair<List<String>, Long>> patterns = value.getPatterns();
  140. for(Pair<List<String>, Long> pair: patterns) {
  141. List<String> itemNames = pair.getFirst(); // 获取模式的元素名,如:[华中科技大学, 武汉大学, 湖北大学]
  142. Long occurrence = pair.getSecond(); // 获取模式的频繁度,如:10或6
  143. for(String itemName: itemNames) {
  144. if (!itemName.equals(key)) {//过滤掉key自身
  145. data.add(new Object[]{itemName, occurrence});
  146. }
  147. }
  148. }
  149. if (data.size() > 0) {
  150. assoMap.put(key, data);//map格式样例:{华中科技大学, [[武汉大学, 48], [湖北大学, 11]], [武汉大学, 7]]}
  151. }
  152. }
  153. List<String[]> dataList = new ArrayList<String[]>();//关联结果集
  154. for(String key : assoMap.keySet()){
  155. List<Object[]> value = assoMap.get(key);
  156. Map<String, Long> map = new HashMap<String, Long>();
  157. for(Object[] o : value){
  158. if(map.get(o[0]) == null){
  159. map.put(o[0].toString(), (Long)o[1]);
  160. }else {
  161. map.put(o[0].toString(), map.get(o[0]) + (Long)o[1]);
  162. }
  163. }
  164. for(String k : map.keySet()){
  165. dataList.add(new String[]{key, k, map.get(k).toString()});//dataList格式样例:[[华中科技大学, 武汉大学, 48], [武汉科技大学, 湖北大学, 11]]
  166. }
  167. }
  168. String[] header = {"user1", "user2", "评分"};
  169. CsvTool.writeCsv(CSV_PATH, header, dataList);
  170. }
  171. public void run(){
  172. try {
  173. getFrequentPatternFile();
  174. } catch (IOException e) {
  175. e.printStackTrace();
  176. }
  177. writeAssociationToFile();
  178. }
  179. }
  1. import java.io.File;
  2. import java.util.ArrayList;
  3. import java.util.HashMap;
  4. import java.util.List;
  5. import java.util.Map;
  6. import org.apache.mahout.cf.taste.common.TasteException;
  7. import org.apache.mahout.cf.taste.eval.IRStatistics;
  8. import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
  9. import org.apache.mahout.cf.taste.eval.RecommenderEvaluator;
  10. import org.apache.mahout.cf.taste.eval.RecommenderIRStatsEvaluator;
  11. import org.apache.mahout.cf.taste.impl.eval.AverageAbsoluteDifferenceRecommenderEvaluator;
  12. import org.apache.mahout.cf.taste.impl.eval.GenericRecommenderIRStatsEvaluator;
  13. import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
  14. import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
  15. import org.apache.mahout.cf.taste.impl.recommender.GenericUserBasedRecommender;
  16. import org.apache.mahout.cf.taste.impl.similarity.PearsonCorrelationSimilarity;
  17. import org.apache.mahout.cf.taste.model.DataModel;
  18. import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
  19. import org.apache.mahout.cf.taste.recommender.RecommendedItem;
  20. import org.apache.mahout.cf.taste.recommender.Recommender;
  21. import org.apache.mahout.cf.taste.similarity.UserSimilarity;
  22. import tool.CsvTool;
  23. /**
  24. * 推荐测试
  25. * Mahout中的推荐由多个组件组成
  26. * 1、数据模型,由DataModel实现;
  27. * 2、用户间的相似性度量 ,由UserSimilarity实现;
  28. * 3、用户邻域的定义,由UserNeighborhood实现;
  29. * 4、推荐引擎,由一个Recommender实现。
  30. * 基本原理见《Mahout实战》P36
  31. * @author suwb
  32. * @since 2015-12-08
  33. */
  34. public class Recommendation {
  35. private static final String ASSOCIATION_SOURCE = "F:\\source\\association\\关联规则输出.csv";
  36. private static final String RECOMMENDATION_SOURCE = "F:\\source\\recommendation\\推荐算法输入数据.csv";
  37. private static Map<String, Integer> map;
  38. private static String userName = "华中科技大学";
  39. public Recommendation(){
  40. if(map == null){
  41. initName2IdMap();
  42. }
  43. }
  44. //运行推荐算法
  45. public List<RecommendedItem> work(long userId) throws Exception{
  46. DataModel model = new FileDataModel(new File(RECOMMENDATION_SOURCE));
  47. UserSimilarity similarity = new PearsonCorrelationSimilarity(model);
  48. UserNeighborhood neighborhood = new NearestNUserNeighborhood(50, similarity, model);
  49. Recommender recommender = new GenericUserBasedRecommender(model, neighborhood, similarity);
  50. return recommender.recommend(userId, 5);
  51. }
  52. //算法效果评分
  53. //评分越小越好
  54. public void evaluateScore() throws Exception{
  55. DataModel model = new FileDataModel(new File(RECOMMENDATION_SOURCE));
  56. RecommenderEvaluator evaluator = new AverageAbsoluteDifferenceRecommenderEvaluator();
  57. RecommenderBuilder builder = new RecommenderBuilder() {
  58. @Override
  59. public Recommender buildRecommender(DataModel model) throws TasteException {
  60. UserSimilarity similarity = new PearsonCorrelationSimilarity(model);
  61. UserNeighborhood neighborhood = new NearestNUserNeighborhood(90, similarity, model);
  62. return new GenericUserBasedRecommender(model, neighborhood, similarity);
  63. }
  64. };
  65. double score = evaluator.evaluate(builder, null, model, 0.7, 1);
  66. System.out.println(score);
  67. }
  68. //评估查准率和查全率
  69. public void evaluateRate() throws Exception{
  70. DataModel model = new FileDataModel(new File(RECOMMENDATION_SOURCE));
  71. RecommenderIRStatsEvaluator evaluator = new GenericRecommenderIRStatsEvaluator();
  72. RecommenderBuilder builder = new RecommenderBuilder() {
  73. @Override
  74. public Recommender buildRecommender(DataModel model) throws TasteException {
  75. UserSimilarity similarity = new PearsonCorrelationSimilarity(model);
  76. UserNeighborhood neighborhood = new NearestNUserNeighborhood(90, similarity, model);
  77. return new GenericUserBasedRecommender(model, neighborhood, similarity);
  78. }
  79. };
  80. IRStatistics stats = evaluator.evaluate(builder, null, model, null, 5,
  81. GenericRecommenderIRStatsEvaluator.CHOOSE_THRESHOLD,
  82. 1);
  83. System.out.println(stats.getPrecision());
  84. System.out.println(stats.getRecall());
  85. }
  86. public String getNameById(int id){
  87. String name = "";
  88. for(String key : map.keySet()){
  89. if(map.get(key) == id){
  90. name = key;
  91. break;
  92. }
  93. }
  94. return name;
  95. }
  96. public void initName2IdMap(){
  97. map = new HashMap<String, Integer>();
  98. Dao dao = new Dao();
  99. int i = 1001;
  100. List<Object[]> unitName = dao.queryBySql("select c_name from t_agency where c_type='3' or c_type='4'");
  101. for(Object[] o : unitName){
  102. map.put(o[0].toString(), i);
  103. i++;
  104. }
  105. }
  106. public void run() throws Exception{
  107. List<Object[]> dataList = CsvTool.readCsv(ASSOCIATION_SOURCE);
  108. List<String[]> data = new ArrayList<String[]>();
  109. for(int i=1; i<dataList.size(); i++){
  110. Object[] o = dataList.get(i);
  111. data.add(new String[]{map.get(o[0]).toString(), map.get(o[1]).toString(), o[2].toString()});
  112. }
  113. CsvTool.writeCsv(RECOMMENDATION_SOURCE, null, data);
  114. evaluateScore();
  115. evaluateRate();
  116. List<RecommendedItem> recommendedList = work(new Long(map.get(userName)));
  117. for (RecommendedItem item : recommendedList) {
  118. System.out.printf("(%s,%f)", getNameById(Integer.parseInt(String.valueOf(item.getItemID()))), item.getValue());
  119. }
  120. }
  121. }
  1. package tool;
  2. import java.io.IOException;
  3. import java.nio.charset.Charset;
  4. import java.util.ArrayList;
  5. import java.util.List;
  6. import com.csvreader.CsvReader;
  7. import com.csvreader.CsvWriter;
  8. /**
  9. * CSV工具包
  10. * @author suwb
  11. */
  12. public class CsvTool {
  13. /**
  14. * 写CSV文件
  15. * @param outFilePath 数据文件路径
  16. * @param header 内容第一行标题
  17. * @param dataList 内容
  18. */
  19. public static void writeCsv(String outFilePath, String[] header, List<String[]> dataList) {
  20. CsvWriter writer = null;
  21. try {
  22. writer = new CsvWriter(outFilePath, ',', Charset.forName("UTF-8"));
  23. //写文件头
  24. if(header != null){
  25. writer.writeRecord(header);
  26. }
  27. //写文件内容
  28. for (String[] datas : dataList) {
  29. writer.writeRecord(datas);
  30. }
  31. } catch (IOException e) {
  32. e.printStackTrace();
  33. } finally {
  34. writer.close();
  35. }
  36. }
  37. /**
  38. * 读CSV文件
  39. * @param csvFilePath
  40. * @throws Exception
  41. */
  42. public static List<Object[]> readCsv(String csvFilePath) throws Exception {
  43. // 返回结果
  44. List<Object[]> datas = new ArrayList<Object[]>();
  45. CsvReader reader = new CsvReader(csvFilePath, ',', Charset.forName("UTF-8"));
  46. //读文件内容
  47. while (reader.readRecord()) {
  48. datas.add(reader.getValues());
  49. }
  50. return datas;
  51. }
  52. }

代码解释

推荐由以下几个组件组成:
1、数据模型,由DataModel实现;
2、用户间的相似性度量 ,由UserSimilarity实现;
3、用户邻域的定义,由UserNeighborhood实现;
4、推荐引擎,由一个Recommender实现。
做好这几部就能得到基本的推荐结果,剩下的就是调优了。
我这里有个问题在于,针对当下的推荐需求,我拿不到评分的数据,所以我的推荐程序的输入数据来源于先进行了一次关联分析得到的结果,以高校间的关联度作为高校的评分。

最后

几个典型的推荐结果如下:

  1. 对武汉大学的推荐:
  2. (武汉大学,24.060883)(西北师范大学,10.000000)(黑龙江大学,6.448706)(西安交通大学,5.522729)(兰州交通大学,5.000000)
  3. 对华中科技大学的推荐:
  4. (华中科技大学,27.968401)(新疆大学,16.533577)(新疆医科大学,12.936143)(海南大学,7.918249)(西安交通大学,6.500000)
  5. 对清华大学的推荐:
  6. (清华大学,10.495702)(中南财经政法大学,7.616104)(哈尔滨师范大学,5.000000)(华东政法大学,4.476559)(浙江工商大学,3.679918)
  7. 对北京大学的推荐:
  8. (新疆师范大学,7.000000)(南京师范大学,7.000000)(西安交通大学,5.500000)(聊城大学,5.468179)(西华师范大学,4.000000)
  9. 对电子科技大学的推荐:
  10. (广东工业大学,19.635860)(浙江大学,15.270090)(南京理工大学,15.048809)(广东财经大学,13.014910)(湖南商学院,10.524227)

解释一下为什么推荐结果集中会出现自己。原因在于我的输入数据中,user1、user2是对等的,就像我在需求分析里举的典型的输入数据,两所高校间的关联数据会出现两次,也就是说对于一所高校来说,它既是user,也是item。所以对一个user推荐一个item的时候,是可能推荐到自己的。而这里结果多次出现了自己,也正证实了推荐算法的有效性。算法是在一大堆的数据中发现item中的"华中科技大学"特别适合user中的"华中科技大学"的。

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注