@spiritnotes
2016-03-01T08:13:57.000000Z
字数 3789
阅读 1864
机器学习
在购物网站等场景的分析中,经常需要分析用户购买的物品之间的关系,这类问题被称为关联规则学习(association rule learning)或关联分析(association analysis)。
在关联分析中,经常出现在一起的物品的集合称为频繁项集(frequent item sets),用支持度(support)来衡量,即包含该项集的交易占全部交易的比例。关联规则(例如:{尿布}→{啤酒})表示两组物品之间存在很强的关系,用可信度(confidence)来衡量,即 $\mathrm{confidence}(X \to Y) = \mathrm{support}(X \cup Y) / \mathrm{support}(X)$。
Apriori的重点是Apriori原理,即如果某个项集是频繁的,那么它的所有子集也是频繁的。换言之(其逆否命题),如果一个项集是非频繁集,那么它所有的超集也是非频繁的。
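其主要流程如下:
1. 从所有交易中收集每个单独的物品,生成长度为1的候选项集;
2. 计算每个候选项集的支持度,去掉低于最小支持度的项集,得到当前长度的频繁项集;
3. 将长度为k的频繁项集两两合并,生成长度为k+1的候选项集,再重复第2步,直到无法生成新的候选项集;
4. 对每个长度不小于2的频繁项集,逐步把其中的物品移到规则右侧(后件),保留可信度不低于最小可信度的规则。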
在代码中生成k+1项集时,先将每一个k项集排序,只有当两个项集的前k-1个元素相同时才将它们合并,这样可以简化操作并避免重复生成。假设某个k+1项集{1,...,k+1}能够被生成,那么当前的k项集集合中最多包含它的k+1个k项子集。若两个子集L1、L2都已排序,则只有一对子集的前k-1个元素相同,即[1,...,k-1,k]和[1,...,k-1,k+1];其他任意两个子集之间的差异都不只出现在最后一位。因此每个k+1项集只会被生成一次。
github地址: https://github.com/spiritwiki/codes/tree/master/apriori
import logging
#logging.basicConfig(level=logging.DEBUG)
class Apriori():
    """Apriori frequent-itemset mining and association-rule generation.

    Transactions are iterables of hashable items. Itemsets are returned as
    ``frozenset`` objects and a rule is a two-element list
    ``[antecedent, consequent]`` of frozensets. Items must be mutually
    orderable once itemsets of two or more items have to be joined, because
    the join step sorts them.
    """

    def fit(self, trans, min_support, min_confidence):
        """Train on the transactions; return frequent itemsets and rules.

        :param trans: iterable of transactions, each an iterable of hashable
            items.
        :param min_support: minimum fraction of transactions an itemset must
            appear in to be kept.
        :param min_confidence: minimum confidence a rule must reach.
        :returns: ``(frequent_sets, association_rules)`` as produced by
            ``calc_frequent_sets`` and ``calc_association_rules``;
            ``([], [])`` when there are no transactions.
        """
        # Materialise once: the data is scanned for every candidate itemset,
        # and frozensets let each issubset() test use hash lookups directly
        # (this also allows a one-shot iterator to be passed in).
        trans = [frozenset(tran) for tran in trans]
        if not trans:
            return [], []
        frequent_sets = self.calc_frequent_sets(trans, min_support)
        association_rules = self.calc_association_rules(frequent_sets, trans, min_confidence)
        return frequent_sets, association_rules

    def calc_association_rules(self, freq_sets_length, trans, min_confidence):
        """Compute association rules from levelled frequent itemsets.

        :param freq_sets_length: list whose element ``k`` holds the frequent
            itemsets of size ``k + 1`` (the output of ``calc_frequent_sets``).
        :param trans: the transactions the supports are measured on.
        :param min_confidence: minimum confidence a rule must reach.
        :returns: list of ``[antecedent, consequent]`` frozenset pairs.
        """
        rules = []
        # Level 0 holds single items, which cannot be split into two sides.
        for freq_sets in freq_sets_length[1:]:
            for freq_set in freq_sets:
                rules.extend(self._calc_rules_4_set(freq_set, trans, min_confidence))
        return rules

    def _calc_rules_4_set(self, freq_set, trans, min_confidence):
        """Compute the rules whose items are exactly ``freq_set``.

        Confidence of ``before -> after`` is
        ``support(freq_set) / support(before)``. Consequents start at one
        item; only consequents that produced a rule are joined into larger
        ones, because moving items from the antecedent to the consequent can
        only lower the confidence.
        """
        rules = []
        if len(freq_set) < 2:
            # No way to leave a non-empty antecedent and consequent.
            return rules
        # Same for every rule of this set, so compute it once.
        set_support = self._calc_itemset_support(trans, freq_set)
        curr_after_items = [frozenset([i]) for i in freq_set]
        while curr_after_items:
            filter_afters = []
            for after in curr_after_items:
                before = freq_set - after
                if not before:
                    # Consequent is the whole set: no antecedent, so no rule.
                    continue
                before_support = self._calc_itemset_support(trans, before)
                if before_support == 0:
                    # Only reachable with min_support == 0; confidence is
                    # undefined, so the rule is not emitted.
                    continue
                if set_support / before_support >= min_confidence:
                    rules.append([before, after])
                    filter_afters.append(after)
            curr_after_items = self._get_next_level_freqset(filter_afters)
        return rules

    def _filter_freq_sets_with_support(self, freq_sets, trans, min_support):
        """Keep only the itemsets whose support is at least ``min_support``."""
        return [freq_set for freq_set in freq_sets
                if self._calc_itemset_support(trans, freq_set) >= min_support]

    def calc_frequent_sets(self, trans, support_ratio):
        """Compute frequent itemsets level by level.

        :param trans: the transactions (must be re-iterable).
        :param support_ratio: minimum support an itemset must reach.
        :returns: list whose element ``k`` holds the frequent itemsets of size
            ``k + 1``. The final level can be an empty list when none of the
            last candidates met the threshold.
        """
        frequent_sets_all = []
        logger = logging.getLogger('calc_frequent_sets')
        curr_freq_sets = self._create_all_1_item_sets(trans)
        while curr_freq_sets:
            logger.debug('-> create freq sets %s', curr_freq_sets)
            filtered_freq_set = self._filter_freq_sets_with_support(curr_freq_sets, trans, support_ratio)
            frequent_sets_all.append(filtered_freq_set)
            logger.debug('-> filted freq sets %s', filtered_freq_set)
            curr_freq_sets = self._get_next_level_freqset(filtered_freq_set)
        logger.debug('-> all freq sets %s', frequent_sets_all)
        return frequent_sets_all

    def _get_next_level_freqset(self, freq_set):
        """Join size-k itemsets into size-(k+1) candidate itemsets.

        Two itemsets are merged only when their sorted forms agree on all but
        the last item. When the input holds distinct itemsets of one size,
        every candidate is therefore produced exactly once.
        """
        if len(freq_set) < 2:
            # Nothing to pair up; returning early also avoids sorting (and so
            # comparing) the items of a lone itemset.
            return []
        # Sort every itemset once instead of once per pair.
        sorted_sets = [sorted(item_set) for item_set in freq_set]
        next_level_freq_set = []
        for i, list_i in enumerate(sorted_sets):
            for list_j in sorted_sets[i + 1:]:
                if list_i[:-1] == list_j[:-1]:
                    next_level_freq_set.append(frozenset(list_i + list_j))
        return next_level_freq_set

    def _calc_itemset_support(self, trans, item_set):
        """Return the fraction of transactions that contain all of ``item_set``.

        Raises ZeroDivisionError when ``trans`` is empty.
        """
        hit_num = sum(item_set.issubset(tran) for tran in trans)
        support_ratio = hit_num / len(trans)
        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logging.debug('calc support for %s hit:%s value:%s', item_set, hit_num, support_ratio)
        return support_ratio

    def _create_all_1_item_sets(self, trans):
        """Return every distinct single-item itemset occurring in ``trans``."""
        return list({frozenset([item]) for tran in trans for item in tran})