[关闭]
@jfruan 2017-04-14T00:50:58.000000Z 字数 15958 阅读 1643

FP-Tree算法的实现

DM


在关联规则挖掘领域最经典的算法法是Apriori,其致命的缺点是需要多次扫描事务数据库。于是人们提出了各种裁剪(prune)数据集的方法以减少I/O开支,韩嘉炜老师的FP-Tree算法就是其中非常高效的一种。

1. 名词约定

举个例子,设事务数据库为:

A  E  F  G
A  F  G
A  B  E  F  G
E  F  G

每一行为一个事务,事务由若干个互不相同的项目构成,任意几个项目的组合称为一个模式。
上例中一共有4个事务。

模式{A,F,G}的支持数为3,支持度为3/4。支持数大于阈值minSuport的模式称为频繁模式(Frequent Patten)。

{F,G}的支持度数为4,支持度为4/4。

{A}的支持度数为3,支持度为3/4。

{F,G}=>{A}的置信度为:{A,F,G}的支持度数 除以 {F,G}的支持度数,即3/4

{A}=>{F,G}的置信度为:{A,F,G}的支持度数 除以 {A}的支持度数,即3/3

强关联规则挖掘是在满足一定支持度的情况下寻找置信度达到阈值的所有模式。

2. FP-Tree算法描述

算法描述:

  1. 输入:事务集合 List<List<String>> transactions
  2. 输出:频繁模式集合及相应的频数 Map<List<String>,Integer> FrequentPattens
  3. 初始化 PostModel=[],CPB=transactions
  4. void FPGrowth(List<List<String>> CPB,List<String> PostModel){
  5. if CPB为空:
  6. return
  7. 统计CPB中每一个项目的计数,把计数小于最小支持数minSuport的删除掉,对于CPB中的每一条事务按项目计数降序排列。
  8. CPB构建FP-TreeFP-Tree中包含了表头项headers,每一个header都指向了一个链表HeaderLinkList,链表中的每个
  9. 元素都是FP-Tree上的一个节点,且节点名称与header.name相同。
  10. for header in headers:
  11. newPostModel=header.name+PostModel
  12. 把<newPostModel, header.count>加到FrequentPattens中。
  13. newCPB=[]
  14. for TreeNode in HeaderLinkList:
  15. 得到从FP-Tree的根节点到TreeNode的全路径path,把path作为一个事务添加到newCPB中,要重复添加TreeNode.count次。
  16. FPGrowth(newCPB,newPostModel)

算法的核心是FPGrowth函数,这是一个递归函数。CPB的全称是Conditional Pattern Base(条件模式基),我们可以把CPB理解为算法在不同阶段的事务集合。PostModel称为后缀模式,它是一个List。后文会详细讲CPB和PostModel是如何生成的,初始时令PostModel为空,令CPB就是原始的事务集合。

下面我们举个例子来详细讲解FPGrowth函数的完整实现。

事务数据库如下,一行表示一条购物记录:

牛奶,鸡蛋,面包,薯片
鸡蛋,爆米花,薯片,啤酒
鸡蛋,面包,薯片
牛奶,鸡蛋,面包,爆米花,薯片,啤酒
牛奶,面包,啤酒
鸡蛋,面包,啤酒
牛奶,面包,薯片
牛奶,鸡蛋,面包,黄油,薯片
牛奶,鸡蛋,黄油,薯片

令minSuport=3,统计每一个项目出现的次数,把次数低于minSuport的项目删除掉,剩下的项目按出现的次数降序排列,得到F1:
对于每一条事务,按照F1中的顺序重新排序,不在F1中的被删除掉。这样整个事务集合变为:

薯片,鸡蛋,面包,牛奶
薯片,鸡蛋,啤酒
薯片,鸡蛋,面包
薯片,鸡蛋,面包,牛奶,啤酒
面包,牛奶,啤酒
鸡蛋,面包,啤酒
薯片,面包,牛奶
薯片,鸡蛋,面包,牛奶
薯片,鸡蛋,牛奶

上面的事务集合即为当前的CPB,当前的PostModel依然为空。由CPB构建FP-Tree的步骤如下。
插入第一条事务(薯片,鸡蛋,面包,牛奶)之后

图片标题

插入第二条事务(薯片,鸡蛋,啤酒)
图片标题

插入第三条记录(面包,牛奶,啤酒)
图片标题

估计你也知道怎么插了,最终生成的FP-Tree是:

上图中左边的那一叫做表头项,树中相同名称的节点要链接起来,链表的第一个元素就是表头项里的元素。不论是表头项节点还是FP-Tree中有节点,它们至少有2个属性:name和count。

现在我们已进行完算法描述的第10行。go on

遍历表头项中的每一项,我们拿“牛奶:6”为例。

新的PostModel为“表头项+老的PostModel”,现在由于老的PostModel还是空list,所以新的PostModel为:[牛奶]。新的PostModel就是一条频繁模式,它的支持数即为表头项的count:6,所以此处可以输出一条频繁模式<[牛奶], 6>

从表头项“牛奶”开始,找到FP-Tree中所有的“牛奶”节点,然后找到从树的根节点到“牛奶”节点的路径。得到4条路径:

薯片:7,鸡蛋:6,牛奶:1
薯片:7,鸡蛋:6,面包:4,牛奶:3
薯片:7,面包:1,牛奶:1
面包:1,牛奶:1

对于每一条路径上的节点,其count都设置为牛奶的count:

薯片:1,鸡蛋:1,牛奶:1
薯片:3,鸡蛋:3,面包:3,牛奶:3
薯片:1,面包:1,牛奶:1
面包:1,牛奶:1

因为每一项末尾都是牛奶,可以把牛奶去掉,得到新的CPB:

薯片:1,鸡蛋:1
薯片:3,鸡蛋:3,面包:3
薯片:1,面包:1
面包:1

然后递归调用FPGrowth(新的CPB,新的PostModel),当发现新有CPB为空时递归就可以退出了。

3. 几点说明

  1. 可以在构建FP-Tree之前就把CPB中低于minSuport的项目删掉,也可以先不删,而是在构建FP-Tree的过程当中如果遇到低于minSuport的项目不把它插入到FP-Tree中就可以了。FP-Tree算法之所以高效,就是因为它在每次FPGrowth递归时都对数据进行了这种裁剪。
  2. 没必要每次FPGrowth递归时都把CPB中的事务按F1做一次重排序,只需要第一次构建CPB时按F1做一次排序,以后每次构建新的CPB时保持与老的CPB各项目顺序不变就可以了。
  3. 对于FP-Tree已经是单枝的情况,就没有必要再递归调用FPGrowth了,直接输出整条路径上所有节点的各种组合+postModel就可了。例如当FP-Tree为:
    图片标题

树上只有一条路径{A-B-C},在保证A-B-C这种顺序的前提下,这三个节点的所有组合是:A,B,C,AB,AC,BC,ABC。每一种组合与postModel拼接形成一条频繁模式,模式的支持数即为表头项的计数(单枝的情况下所有表头项和所有树节点的计数都是相同的)。

4. Java实现

  1. import java.util.List;
  2. public class StrongAssociationRule {
  3. public List<String> condition;
  4. public String result;
  5. public int support;
  6. public double confidence;
  7. }

TreeNode.java

  1. import java.util.ArrayList;
  2. import java.util.List;
  3. /**
  4. * @Description: FP树的节点
  5. * @Author orisun
  6. * @Date Jun 23, 2016
  7. */
  8. class TreeNode {
  9. /**
  10. * 节点名称
  11. **/
  12. private String name;
  13. /**
  14. * 频数
  15. **/
  16. private int count;
  17. private TreeNode parent;
  18. private List<TreeNode> children;
  19. /**
  20. * 下一个节点(由表头项维护的那个链表)
  21. **/
  22. private TreeNode nextHomonym;
  23. /**
  24. * 末节点(由表头项维护的那个链表)
  25. **/
  26. private TreeNode tail;
  27. @Override
  28. public String toString() {
  29. return name;
  30. }
  31. public TreeNode() {
  32. }
  33. public TreeNode(String name) {
  34. this.name = name;
  35. }
  36. public String getName() {
  37. return this.name;
  38. }
  39. public void setName(String name) {
  40. this.name = name;
  41. }
  42. public int getCount() {
  43. return this.count;
  44. }
  45. public void setCount(int count) {
  46. this.count = count;
  47. }
  48. public TreeNode getParent() {
  49. return this.parent;
  50. }
  51. public void setParent(TreeNode parent) {
  52. this.parent = parent;
  53. }
  54. public List<TreeNode> getChildren() {
  55. return this.children;
  56. }
  57. public void addChild(TreeNode child) {
  58. if (getChildren() == null) {
  59. List<TreeNode> list = new ArrayList<TreeNode>();
  60. list.add(child);
  61. setChildren(list);
  62. } else {
  63. getChildren().add(child);
  64. }
  65. }
  66. public TreeNode findChild(String name) {
  67. List<TreeNode> children = getChildren();
  68. if (children != null) {
  69. for (TreeNode child : children) {
  70. if (child.getName().equals(name)) {
  71. return child;
  72. }
  73. }
  74. }
  75. return null;
  76. }
  77. public void setChildren(List<TreeNode> children) {
  78. this.children = children;
  79. }
  80. public void printChildrenName() {
  81. List<TreeNode> children = getChildren();
  82. if (children != null) {
  83. for (TreeNode child : children)
  84. System.out.print(child.getName() + " ");
  85. } else
  86. System.out.print("null");
  87. }
  88. public TreeNode getNextHomonym() {
  89. return this.nextHomonym;
  90. }
  91. public void setNextHomonym(TreeNode nextHomonym) {
  92. this.nextHomonym = nextHomonym;
  93. }
  94. public void countIncrement(int n) {
  95. this.count += n;
  96. }
  97. public TreeNode getTail() {
  98. return tail;
  99. }
  100. public void setTail(TreeNode tail) {
  101. this.tail = tail;
  102. }
  103. }

FPTree.java

  1. import java.io.BufferedReader;
  2. import java.io.BufferedWriter;
  3. import java.io.FileReader;
  4. import java.io.FileWriter;
  5. import java.io.IOException;
  6. import java.text.DecimalFormat;
  7. import java.util.ArrayList;
  8. import java.util.Collections;
  9. import java.util.Comparator;
  10. import java.util.HashMap;
  11. import java.util.HashSet;
  12. import java.util.LinkedList;
  13. import java.util.List;
  14. import java.util.Map;
  15. import java.util.Map.Entry;
  16. import java.util.Set;
  17. /**
  18. * @Description: FPTree强关联规则挖掘算法
  19. * @Author orisun
  20. * @Date Jun 23, 2016
  21. */
  22. public class FPTree {
  23. /**
  24. * 频繁模式的最小支持数
  25. **/
  26. private int minSuport;
  27. /**
  28. * 关联规则的最小置信度
  29. **/
  30. private double confident;
  31. /**
  32. * 事务项的总数
  33. **/
  34. private int totalSize;
  35. /**
  36. * 存储每个频繁项及其对应的计数
  37. **/
  38. private Map<List<String>, Integer> frequentMap = new HashMap<List<String>, Integer>();
  39. /**
  40. * 关联规则中,哪些项可作为被推导的结果,默认情况下所有项都可以作为被推导的结果
  41. **/
  42. private Set<String> decideAttr = null;
  43. public int getMinSuport() {
  44. return this.minSuport;
  45. }
  46. /**
  47. * 设置最小支持数
  48. *
  49. * @param minSuport
  50. */
  51. public void setMinSuport(int minSuport) {
  52. this.minSuport = minSuport;
  53. }
  54. public double getConfident() {
  55. return confident;
  56. }
  57. /**
  58. * 设置最小置信度
  59. *
  60. * @param confident
  61. */
  62. public void setConfident(double confident) {
  63. this.confident = confident;
  64. }
  65. /**
  66. * 设置决策属性。如果要调用{@linkplain #readTransRocords(String[])},需要在调用{@code readTransRocords}
  67. * 之后再调用{@code setDecideAttr}
  68. *
  69. * @param decideAttr
  70. */
  71. public void setDecideAttr(Set<String> decideAttr) {
  72. this.decideAttr = decideAttr;
  73. }
  74. /**
  75. * 获取频繁项集
  76. *
  77. * @return
  78. * @Description:
  79. */
  80. public Map<List<String>, Integer> getFrequentItems() {
  81. return frequentMap;
  82. }
  83. public int getTotalSize() {
  84. return totalSize;
  85. }
  86. /**
  87. * 根据一条频繁模式得到若干关联规则
  88. *
  89. * @param list
  90. * @return
  91. */
  92. private List<StrongAssociationRule> getRules(List<String> list) {
  93. List<StrongAssociationRule> rect = new LinkedList<StrongAssociationRule>();
  94. if (list.size() > 1) {
  95. for (int i = 0; i < list.size(); i++) {
  96. String result = list.get(i);
  97. if (decideAttr.contains(result)) {
  98. List<String> condition = new ArrayList<String>();
  99. condition.addAll(list.subList(0, i));
  100. condition.addAll(list.subList(i + 1, list.size()));
  101. StrongAssociationRule rule = new StrongAssociationRule();
  102. rule.condition = condition;
  103. rule.result = result;
  104. rect.add(rule);
  105. }
  106. }
  107. }
  108. return rect;
  109. }
  110. /**
  111. * 从若干个文件中读入Transaction Record,同时把所有项设置为decideAttr
  112. *
  113. * @param filenames
  114. * @return
  115. * @Description:
  116. */
  117. public List<List<String>> readTransRocords(String[] filenames) {
  118. Set<String> set = new HashSet<String>();
  119. List<List<String>> transaction = null;
  120. if (filenames.length > 0) {
  121. transaction = new LinkedList<List<String>>();
  122. for (String filename : filenames) {
  123. try {
  124. FileReader fr = new FileReader(filename);
  125. BufferedReader br = new BufferedReader(fr);
  126. try {
  127. String line = null;
  128. // 一项事务占一行
  129. while ((line = br.readLine()) != null) {
  130. if (line.trim().length() > 0) {
  131. // 每个item之间用","分隔
  132. String[] str = line.split(",");
  133. //每一项事务中的重复项需要排重
  134. Set<String> record = new HashSet<String>();
  135. for (String w : str) {
  136. record.add(w);
  137. set.add(w);
  138. }
  139. List<String> rl = new ArrayList<String>();
  140. rl.addAll(record);
  141. transaction.add(rl);
  142. }
  143. }
  144. } finally {
  145. br.close();
  146. }
  147. } catch (IOException ex) {
  148. System.out.println("Read transaction records failed." + ex.getMessage());
  149. System.exit(1);
  150. }
  151. }
  152. }
  153. this.setDecideAttr(set);
  154. return transaction;
  155. }
  156. /**
  157. * 生成一个序列的各种子序列。(序列是有顺序的)
  158. *
  159. * @param residualPath
  160. * @param results
  161. */
  162. private void combine(LinkedList<TreeNode> residualPath, List<List<TreeNode>> results) {
  163. if (residualPath.size() > 0) {
  164. //如果residualPath太长,则会有太多的组合,内存会被耗尽的
  165. TreeNode head = residualPath.poll();
  166. List<List<TreeNode>> newResults = new ArrayList<List<TreeNode>>();
  167. for (List<TreeNode> list : results) {
  168. List<TreeNode> listCopy = new ArrayList<TreeNode>(list);
  169. newResults.add(listCopy);
  170. }
  171. for (List<TreeNode> newPath : newResults) {
  172. newPath.add(head);
  173. }
  174. results.addAll(newResults);
  175. List<TreeNode> list = new ArrayList<TreeNode>();
  176. list.add(head);
  177. results.add(list);
  178. combine(residualPath, results);
  179. }
  180. }
  181. private boolean isSingleBranch(TreeNode root) {
  182. boolean rect = true;
  183. while (root.getChildren() != null) {
  184. if (root.getChildren().size() > 1) {
  185. rect = false;
  186. break;
  187. }
  188. root = root.getChildren().get(0);
  189. }
  190. return rect;
  191. }
  192. /**
  193. * 计算事务集中每一项的频数
  194. *
  195. * @param transRecords
  196. * @return
  197. */
  198. private Map<String, Integer> getFrequency(List<List<String>> transRecords) {
  199. Map<String, Integer> rect = new HashMap<String, Integer>();
  200. for (List<String> record : transRecords) {
  201. for (String item : record) {
  202. Integer cnt = rect.get(item);
  203. if (cnt == null) {
  204. cnt = new Integer(0);
  205. }
  206. rect.put(item, ++cnt);
  207. }
  208. }
  209. return rect;
  210. }
  211. /**
  212. * 根据事务集合构建FPTree
  213. *
  214. * @param transRecords
  215. * @Description:
  216. */
  217. public void buildFPTree(List<List<String>> transRecords) {
  218. totalSize = transRecords.size();
  219. //计算每项的频数
  220. final Map<String, Integer> freqMap = getFrequency(transRecords);
  221. //先把频繁1项集添加到频繁模式中
  222. // for (Entry<String, Integer> entry : freqMap.entrySet()) {
  223. // String name = entry.getKey();
  224. // int cnt = entry.getValue();
  225. // if (cnt >= minSuport) {
  226. // List<String> rule = new ArrayList<String>();
  227. // rule.add(name);
  228. // frequentMap.put(rule, cnt);
  229. // }
  230. // }
  231. //每条事务中的项按F1排序
  232. for (List<String> transRecord : transRecords) {
  233. Collections.sort(transRecord, new Comparator<String>() {
  234. @Override
  235. public int compare(String o1, String o2) {
  236. return freqMap.get(o2) - freqMap.get(o1);
  237. }
  238. });
  239. }
  240. FPGrowth(transRecords, null);
  241. }
  242. /**
  243. * FP树递归生长,从而得到所有的频繁模式
  244. *
  245. * @param cpb 条件模式基
  246. * @param postModel 后缀模式
  247. */
  248. private void FPGrowth(List<List<String>> cpb, LinkedList<String> postModel) {
  249. // System.out.println("CPB is");
  250. // for (List<String> records : cpb) {
  251. // System.out.println(records);
  252. // }
  253. // System.out.println("PostPattern is " + postPattern);
  254. Map<String, Integer> freqMap = getFrequency(cpb);
  255. Map<String, TreeNode> headers = new HashMap<String, TreeNode>();
  256. for (Entry<String, Integer> entry : freqMap.entrySet()) {
  257. String name = entry.getKey();
  258. int cnt = entry.getValue();
  259. //每一次递归时都有可能出现一部分模式的频数低于阈值
  260. if (cnt >= minSuport) {
  261. TreeNode node = new TreeNode(name);
  262. node.setCount(cnt);
  263. headers.put(name, node);
  264. }
  265. }
  266. TreeNode treeRoot = buildSubTree(cpb, freqMap, headers);
  267. //如果只剩下虚根节点,则递归结束
  268. if ((treeRoot.getChildren() == null) || (treeRoot.getChildren().size() == 0)) {
  269. return;
  270. }
  271. //如果树是单枝的,则直接把“路径的各种组合+后缀模式”添加到频繁模式集中。这个技巧是可选的,即跳过此步进入下一轮
  272. //递归也可以得到正确的结果
  273. if (isSingleBranch(treeRoot)) {
  274. LinkedList<TreeNode> path = new LinkedList<TreeNode>();
  275. TreeNode currNode = treeRoot;
  276. while (currNode.getChildren() != null) {
  277. currNode = currNode.getChildren().get(0);
  278. path.add(currNode);
  279. }
  280. //调用combine时path不宜过长,否则会OutOfMemory
  281. if (path.size() <= 20) {
  282. List<List<TreeNode>> results = new ArrayList<List<TreeNode>>();
  283. combine(path, results);
  284. for (List<TreeNode> list : results) {
  285. int cnt = 0;
  286. List<String> rule = new ArrayList<String>();
  287. for (TreeNode node : list) {
  288. rule.add(node.getName());
  289. cnt = node.getCount();//cnt最FPTree叶节点的计数
  290. }
  291. if (postModel != null) {
  292. rule.addAll(postModel);
  293. }
  294. frequentMap.put(rule, cnt);
  295. }
  296. return;
  297. } else {
  298. System.err.println("length of path is too long: " + path.size());
  299. }
  300. }
  301. for (TreeNode header : headers.values()) {
  302. List<String> rule = new ArrayList<String>();
  303. rule.add(header.getName());
  304. if (postModel != null) {
  305. rule.addAll(postModel);
  306. }
  307. //表头项+后缀模式 构成一条频繁模式(频繁模式内部也是按照F1排序的),频繁度为表头项的计数
  308. frequentMap.put(rule, header.getCount());
  309. //新的后缀模式:表头项+上一次的后缀模式(注意保持顺序,始终按F1的顺序排列)
  310. LinkedList<String> newPostPattern = new LinkedList<String>();
  311. newPostPattern.add(header.getName());
  312. if (postModel != null) {
  313. newPostPattern.addAll(postModel);
  314. }
  315. //新的条件模式基
  316. List<List<String>> newCPB = new LinkedList<List<String>>();
  317. TreeNode nextNode = header;
  318. while ((nextNode = nextNode.getNextHomonym()) != null) {
  319. int counter = nextNode.getCount();
  320. //获得从虚根节点(不包括虚根节点)到当前节点(不包括当前节点)的路径,即一条条件模式基。注意保持顺序:你节点
  321. //在前,子节点在后,即始终保持频率高的在前
  322. LinkedList<String> path = new LinkedList<String>();
  323. TreeNode parent = nextNode;
  324. while ((parent = parent.getParent()).getName() != null) {//虚根节点的name为null
  325. path.push(parent.getName());//往表头插入
  326. }
  327. //事务要重复添加counter次
  328. while (counter-- > 0) {
  329. newCPB.add(path);
  330. }
  331. }
  332. FPGrowth(newCPB, newPostPattern);
  333. }
  334. }
  335. /**
  336. * 把所有事务插入到一个FP树当中
  337. *
  338. * @param transRecords
  339. * @param F1
  340. * @return
  341. */
  342. private TreeNode buildSubTree(List<List<String>> transRecords,
  343. final Map<String, Integer> freqMap,
  344. final Map<String, TreeNode> headers) {
  345. TreeNode root = new TreeNode();//虚根节点
  346. for (List<String> transRecord : transRecords) {
  347. LinkedList<String> record = new LinkedList<String>(transRecord);
  348. TreeNode subTreeRoot = root;
  349. TreeNode tmpRoot = null;
  350. if (root.getChildren() != null) {
  351. //延已有的分支,令各节点计数加1
  352. while (!record.isEmpty()
  353. && (tmpRoot = subTreeRoot.findChild(record.peek())) != null) {
  354. tmpRoot.countIncrement(1);
  355. subTreeRoot = tmpRoot;
  356. record.poll();
  357. }
  358. }
  359. //长出新的节点
  360. addNodes(subTreeRoot, record, headers);
  361. }
  362. return root;
  363. }
  364. /**
  365. * 往特定的节点下插入一串后代节点,同时维护表头项到同名节点的链表指针
  366. *
  367. * @param ancestor
  368. * @param record
  369. * @param headers
  370. */
  371. private void addNodes(TreeNode ancestor, LinkedList<String> record,
  372. final Map<String, TreeNode> headers) {
  373. while (!record.isEmpty()) {
  374. String item = (String) record.poll();
  375. //单个项的出现频数必须大于最小支持数,否则不允许插入FP树。达到最小支持度的项都在headers中。每一次递归根据条件
  376. //模式基本建立新的FPTree时,把要把频数低于minSuport的排除在外,这也正是FPTree比穷举法快的真正原因
  377. if (headers.containsKey(item)) {
  378. TreeNode leafnode = new TreeNode(item);
  379. leafnode.setCount(1);
  380. leafnode.setParent(ancestor);
  381. ancestor.addChild(leafnode);
  382. TreeNode header = headers.get(item);
  383. TreeNode tail = header.getTail();
  384. if (tail != null) {
  385. tail.setNextHomonym(leafnode);
  386. } else {
  387. header.setNextHomonym(leafnode);
  388. }
  389. header.setTail(leafnode);
  390. addNodes(leafnode, record, headers);
  391. }
  392. // else {
  393. // System.err.println(item + " is not F1");
  394. // }
  395. }
  396. }
  397. /**
  398. * 获取所有的强规则
  399. *
  400. * @return
  401. */
  402. public List<StrongAssociationRule> getAssociateRule() {
  403. assert totalSize > 0;
  404. List<StrongAssociationRule> rect = new ArrayList<StrongAssociationRule>();
  405. //遍历所有频繁模式
  406. for (Entry<List<String>, Integer> entry : frequentMap.entrySet()) {
  407. List<String> items = entry.getKey();
  408. int count1 = entry.getValue();
  409. //一条频繁模式可以生成很多关联规则
  410. List<StrongAssociationRule> rules = getRules(items);
  411. //计算每一条关联规则的支持度和置信度
  412. for (StrongAssociationRule rule : rules) {
  413. if (frequentMap.containsKey(rule.condition)) {
  414. int count2 = frequentMap.get(rule.condition);
  415. double confidence = 1.0 * count1 / count2;
  416. if (confidence >= this.confident) {
  417. rule.support = count1;
  418. rule.confidence = confidence;
  419. rect.add(rule);
  420. }
  421. } else {
  422. System.err.println(rule.condition + " is not a frequent pattern, however "
  423. + items + " is a frequent pattern");
  424. }
  425. }
  426. }
  427. return rect;
  428. }
  429. public static void main(String[] args) throws IOException {
  430. String infile = "trolley.txt";
  431. FPTree fpTree = new FPTree();
  432. fpTree.setConfident(0.6);
  433. fpTree.setMinSuport(3);
  434. if (args.length >= 2) {
  435. double confidence = Double.parseDouble(args[0]);
  436. int suport = Integer.parseInt(args[1]);
  437. fpTree.setConfident(confidence);
  438. fpTree.setMinSuport(suport);
  439. }
  440. List<List<String>> trans = fpTree.readTransRocords(new String[]{infile});
  441. Set<String> decideAttr = new HashSet<String>();
  442. decideAttr.add("鸡蛋");
  443. decideAttr.add("面包");
  444. fpTree.setDecideAttr(decideAttr);
  445. long begin = System.currentTimeMillis();
  446. fpTree.buildFPTree(trans);
  447. long end = System.currentTimeMillis();
  448. System.out.println("buildFPTree use time " + (end - begin));
  449. Map<List<String>, Integer> pattens = fpTree.getFrequentItems();
  450. String outfile = "pattens.txt";
  451. BufferedWriter bw = new BufferedWriter(new FileWriter(outfile));
  452. System.out.println("模式\t频数");
  453. bw.write("模式\t频数");
  454. bw.newLine();
  455. for (Entry<List<String>, Integer> entry : pattens.entrySet()) {
  456. System.out.println(entry.getKey() + "\t" + entry.getValue());
  457. bw.write(joinList(entry.getKey()) + "\t" + entry.getValue());
  458. bw.newLine();
  459. }
  460. bw.close();
  461. System.out.println();
  462. List<StrongAssociationRule> rules = fpTree.getAssociateRule();
  463. outfile = "rule.txt";
  464. bw = new BufferedWriter(new FileWriter(outfile));
  465. System.out.println("条件\t结果\t支持度\t置信度");
  466. bw.write("条件\t结果\t支持度\t置信度");
  467. bw.newLine();
  468. DecimalFormat dfm = new DecimalFormat("#.##");
  469. for (StrongAssociationRule rule : rules) {
  470. System.out.println(rule.condition + "->" + rule.result + "\t" + dfm.format(rule.support)
  471. + "\t" + dfm.format(rule.confidence));
  472. bw.write(rule.condition + "->" + rule.result + "\t" + dfm.format(rule.support) + "\t"
  473. + dfm.format(rule.confidence));
  474. bw.newLine();
  475. }
  476. bw.close();
  477. }
  478. private static String joinList(List<String> list) {
  479. if (list == null || list.size() == 0) {
  480. return "";
  481. }
  482. StringBuilder sb = new StringBuilder();
  483. for (String ele : list) {
  484. sb.append(ele);
  485. sb.append(",");
  486. }
  487. //把最后一个逗号去掉
  488. return sb.substring(0, sb.length() - 1);
  489. }
  490. }

输入trolley.txt

  1. 牛奶,鸡蛋,面包,薯片
  2. 鸡蛋,爆米花,薯片,啤酒
  3. 鸡蛋,面包,薯片
  4. 牛奶,鸡蛋,面包,爆米花,薯片,啤酒
  5. 牛奶,面包,啤酒
  6. 鸡蛋,面包,啤酒
  7. 牛奶,面包,薯片
  8. 牛奶,鸡蛋,面包,黄油,薯片
  9. 牛奶,鸡蛋,黄油,薯片

输出pattens.txt

  1. 模式 频数
  2. 面包,啤酒 3
  3. 鸡蛋,牛奶 4
  4. 面包,薯片 5
  5. 薯片,鸡蛋 6
  6. 啤酒 4
  7. 薯片 7
  8. 面包,薯片,鸡蛋,牛奶 3
  9. 鸡蛋,啤酒 3
  10. 面包,牛奶 5
  11. 薯片,鸡蛋,牛奶 4
  12. 面包,鸡蛋,牛奶 3
  13. 面包 7
  14. 牛奶 6
  15. 面包,薯片,鸡蛋 4
  16. 薯片,牛奶 5
  17. 鸡蛋 7
  18. 面包,鸡蛋 5
  19. 面包,薯片,牛奶 4

输出rule.txt

  1. 条件 结果 支持度 置信度
  2. [啤酒]->面包 3 0.75
  3. [牛奶]->鸡蛋 4 0.67
  4. [薯片]->面包 5 0.71
  5. [薯片]->鸡蛋 6 0.86
  6. [薯片, 鸡蛋, 牛奶]->面包 3 0.75
  7. [面包, 薯片, 牛奶]->鸡蛋 3 0.75
  8. [啤酒]->鸡蛋 3 0.75
  9. [牛奶]->面包 5 0.83
  10. [薯片, 牛奶]->鸡蛋 4 0.8
  11. [鸡蛋, 牛奶]->面包 3 0.75
  12. [面包, 牛奶]->鸡蛋 3 0.6
  13. [薯片, 鸡蛋]->面包 4 0.67
  14. [面包, 薯片]->鸡蛋 4 0.8
  15. [鸡蛋]->面包 5 0.71
  16. [面包]->鸡蛋 5 0.71
  17. [薯片, 牛奶]->面包 4 0.8
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注