[关闭]
@spiritnotes 2016-03-01T08:13:57.000000Z 字数 3789 阅读 1729

Apriori算法

机器学习


关联分析

在购物网站等分析中经常需要分析用户购买的物品之间的关系,被称为关联规则学习(association rule learning)或关联分析(association analysis)。

在关联分析中,经常出现在一起的物品的集合称为频繁项集(frequent item sets),其由一个支持度(support)指标用以衡量。关联规则(例如:{尿布}{啤酒})用以定义物品者之间存在很强的关系,其由可信度(confidence)衡量。

支持度(support)
一个项集的支持度(support)被定义为数据集中包含该项集的记录所占的比例。支持度是针对项集来说的,因此可以定义一个最小支持度,而只保留满足最小支持度的项集。
尿尿
可信度或置信度(confidence)
是针对一条诸如{尿布}{葡萄酒}的关联规则来定义的。其计算公式为
尿尿尿

Apriori算法

Apriori的重点是Apriori原理,即是说某个项集是频繁的,那么它的所有子集也是频繁的。反过来说,如果一个项集是非频繁集,那么它所有的超集也是非频繁的。

这样在进行频繁集生成的时候,就可以在当前已经判断为频繁的频繁集上进行扩展已生成新的频繁项集。

其主要流程如下:

在代码中生成k+1集时,将每一个项集排序,只有前k-1个元素相等时才合并,可以简化操作。假设能够生成某k+1项集,则当前项集集合中会存在k+1个项集,而这k+1个项集是只对应于该k+1项集,如果L1,L2为排序的话,则只有一种会存在L1和L2中前k-1相同,即[1,...,k-1,k]和[1,...,k-1,k+1],而其他情况下两者的不同不会是最后一位。

Python实现

github地址: https://github.com/spiritwiki/codes/tree/master/apriori

  1. import logging
  2. #logging.basicConfig(level=logging.DEBUG)
  3. class Apriori():
  4. def fit(self, trans, min_support, min_confidence):
  5. """对事务进行训练,获得对应的频繁集和关联规则"""
  6. if len(trans) == 0:
  7. return [],[]
  8. frequent_sets = self.calc_frequent_sets(trans, min_support)
  9. association_rules = self.calc_association_rules(frequent_sets, trans, min_confidence)
  10. return frequent_sets, association_rules
  11. def calc_association_rules(self, freq_sets_length, trans, min_confidence):
  12. '''计算关联规则'''
  13. rules = []
  14. for freq_sets in freq_sets_length[1:]:
  15. for freq_set in freq_sets:
  16. temp = self._calc_rules_4_set(freq_set, trans, min_confidence)
  17. rules += temp
  18. return rules
  19. def _calc_rules_4_set(self, freq_set, trans, min_confidence):
  20. '''针对单个频繁集计算其关联规则'''
  21. rules = []
  22. curr_after_items = [frozenset([i]) for i in freq_set]
  23. while curr_after_items:
  24. filter_afters = []
  25. for after in curr_after_items:
  26. before = freq_set - after
  27. confidence = self._calc_itemset_support(trans, freq_set) / self._calc_itemset_support(trans, before)
  28. if confidence >= min_confidence:
  29. rules.append([before, after])
  30. filter_afters.append(after)
  31. curr_after_items = self._get_next_level_freqset(filter_afters)
  32. return rules
  33. def _filter_freq_sets_with_support(self, freq_sets, trans, min_support):
  34. '''根据support对当前的频繁集进行筛选'''
  35. filter_sets = []
  36. for freq_set in freq_sets:
  37. support = self._calc_itemset_support(trans, freq_set)
  38. if support >= min_support:
  39. filter_sets.append(freq_set)
  40. return filter_sets
  41. def calc_frequent_sets(self, trans, support_ratio):
  42. """获得频繁项集"""
  43. frequent_sets_all = []
  44. logger = logging.getLogger('calc_frequent_sets')
  45. curr_freq_sets = self._create_all_1_item_sets(trans)
  46. while len(curr_freq_sets) :
  47. logger.debug('-> create freq sets {}'.format(curr_freq_sets))
  48. filtered_freq_set = self._filter_freq_sets_with_support(curr_freq_sets, trans, support_ratio)
  49. frequent_sets_all.append(filtered_freq_set)
  50. logger.debug('-> filted freq sets {}'.format(filtered_freq_set))
  51. curr_freq_sets = self._get_next_level_freqset(filtered_freq_set)
  52. logger.debug('-> all freq sets {}'.format(frequent_sets_all))
  53. return frequent_sets_all
  54. def _get_next_level_freqset(self, freq_set):
  55. '''根据当前的频繁集组成长度+1的频繁集'''
  56. next_level_freq_set = []
  57. freq_set_num = len(freq_set)
  58. for i in range(0, freq_set_num):
  59. for j in range(i+1, freq_set_num):
  60. list_i = list(freq_set[i])
  61. list_j = list(freq_set[j])
  62. list_i.sort(), list_j.sort()
  63. if list_i[:-1] == list_j[:-1]:
  64. next_level_freq_set.append(frozenset(list_i + list_j))
  65. return next_level_freq_set
  66. def _calc_itemset_support(self, trans, item_set):
  67. '''计算support'''
  68. hit_num = sum(item_set.issubset(tran) for tran in trans)
  69. support_ratio = hit_num / len(trans)
  70. logging.debug('calc support for {} hit:{} value:{}'.format(item_set, hit_num, support_ratio))
  71. return support_ratio
  72. def _create_all_1_item_sets(self, trans):
  73. '''创建所有长度为1的频繁集'''
  74. item_sets = set()
  75. for tran in trans:
  76. for item in tran:
  77. item_sets.add(frozenset([item]))
  78. return list(item_sets)
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注