@spiritnotes
2016-03-01T13:48:32.000000Z
字数 3650
阅读 2326
机器学习
FP-growth算法用于发现频繁项集,其基本过程分为两步:(1)扫描事务集,构建FP树;(2)从FP树中挖掘频繁项集。
一个元素可以在FP树中出现多次。FP树会存储项集的出现频率,而每个项集以路径的方式存储在树中。存在相似元素的集合会共享树的一部分。只有当集合完全不同时才会分叉。树节点上给出集合中单个元素以及其在序列中出现的次数,路径给出该序列的出现次数。
创建一个树node类
在构建FP树之前,需要先统计每个元素的出现次数,去掉不满足最小支持度的元素,并把每条事务中剩余的元素按绝对出现频率从高到低排序。构建时从空树(只有根节点)开始,将过滤、排序后的事务依次添加到树中:如果路径上已存在对应元素的节点,则增加该节点的计数;如果不存在,则在树中新建一个分支。
github地址: https://github.com/spiritwiki/codes/tree/master/FP-Tree
import logging
import itertools
import collections
#logging.basicConfig(level=logging.DEBUG)
class FrequentPatternTree():
    """FP-tree (frequent-pattern tree) for mining frequent itemsets.

    A tree is built from a list of transactions (each an iterable of
    hashable, mutually comparable items) by ``create_tree``; frequent
    itemsets are then extracted with ``get_freq_sets``. ``mine_trans`` does
    both steps in one call.

    Attributes:
        root_node: root ``Node`` of the tree (its ``name`` is ``None``).
        header_table: dict mapping each frequent item to a header ``Node``
            holding the item's total count and the head of the linked list
            of tree nodes carrying that item.
    """

    class Node():
        """Node type shared by the header table and the tree itself."""

        # __eq__ is defined without a matching hash, so nodes are unhashable.
        __hash__ = None

        def __init__(self, name=None, parent=None, value=0):
            self.name = name      # item label; None for the root node
            self.parent = parent  # parent tree node; None for root/header nodes
            self.value = value    # occurrence count
            self.child = {}       # item -> child Node
            self.next = None      # next node in the same-item linked list
            self._tail = None     # cached last node of the chain, see link()

        def link(self, new_node):
            """Append ``new_node`` to the end of this node's ``next`` chain."""
            # Start from the cached tail so repeated appends do not re-walk
            # the whole chain; the loop still covers a stale/missing cache.
            tail = self._tail if self._tail is not None else self
            while tail.next is not None:
                tail = tail.next
            tail.next = new_node
            self._tail = new_node

        def increase(self, value=1):
            """Add ``value`` (default 1) to this node's count."""
            self.value += value

        def __eq__(self, other):
            """Nodes are equal when both name and count match.

            Children and ``next`` links are not compared.
            """
            if not isinstance(other, FrequentPatternTree.Node):
                return NotImplemented
            return self.name == other.name and self.value == other.value

        def __repr__(self):
            return str(self.value) + (str(self.child) if self.child else '')

    def __init__(self):
        # Start as a well-defined empty tree so _is_empty/__repr__/__eq__
        # work before create_tree has been called.
        self.root_node = FrequentPatternTree.Node()
        self.header_table = {}

    def __eq__(self, other):
        """Equality of two FP-trees (used by tests).

        NOTE: Node equality only looks at name and count, so the root
        comparison does not descend into children; in practice this compares
        the header tables (item -> count).
        """
        if not isinstance(other, FrequentPatternTree):
            return NotImplemented
        return (self.root_node == other.root_node
                and self.header_table == other.header_table)

    def __repr__(self):
        """String form showing the header table and the tree."""
        return '{table:' + str(self.header_table) + ' ; tree:' + str(self.root_node) + '}'

    def _is_empty(self):
        """Return True when no item made it into the header table."""
        return not self.header_table

    @staticmethod
    def mine_trans(trans, support_ratio=0.5, support=None):
        """Return the frequent itemsets of ``trans``.

        Args:
            trans: sized collection of transactions.
            support_ratio: minimum support as a fraction of ``len(trans)``;
                only used when ``support`` is None.
            support: absolute minimum support count.

        Returns:
            A list of itemsets, each a list of items.
        """
        # `is None` rather than `or`: an explicit support of 0 is a valid
        # threshold and must not fall back to the ratio.
        if support is None:
            support = len(trans) * support_ratio
        tree = FrequentPatternTree().create_tree(trans, support=support)
        return tree.get_freq_sets(support)

    def create_tree(self, trans, support_ratio=0.5, support=None):
        """Build the FP-tree from ``trans`` and return ``self``.

        Items below the support threshold are dropped; the remaining items
        of each transaction are ordered by descending (count, item) before
        being inserted. Repeated items within one transaction count once.
        """
        if support is None:
            support = len(trans) * support_ratio
        header_table = self._create_header_table(trans, support)
        root_node = FrequentPatternTree.Node(None, None)
        for tran in trans:
            # Lazy %-style args: the (potentially large) reprs are only
            # rendered when DEBUG logging is actually enabled.
            logging.debug('before transform tran: %s', tran)
            # dict.fromkeys drops duplicate items while keeping first-seen order.
            items = [i for i in dict.fromkeys(tran) if i in header_table]
            if not items:
                continue
            items.sort(key=lambda i: (header_table[i].value, i), reverse=True)
            logging.debug('after transform tran: %s', items)
            self._add_item_fptree(items, root_node, header_table)
            logging.debug('after add item tree: %s', root_node)
        self.root_node = root_node
        self.header_table = header_table
        return self

    def _add_item_fptree(self, tran, node, header_table):
        """Insert one filtered, ordered transaction below ``node``.

        Walks down the tree item by item, creating missing children (and
        linking them into the header table chain) and incrementing counts.
        Iterative so long transactions cannot exhaust the recursion limit.
        """
        for item in tran:
            child = node.child.get(item)
            if child is None:
                child = FrequentPatternTree.Node(item, node)
                node.child[item] = child
                header_table[item].link(child)
            child.increase()
            node = child

    def _create_header_table(self, trans, support):
        """Return ``{item: Node(item, value=count)}`` for frequent items.

        ``count`` is the number of transactions containing the item; only
        items with ``count >= support`` are kept.
        """
        counts = collections.Counter()
        for tran in trans:
            # Count each item at most once per transaction.
            for item in dict.fromkeys(tran):
                counts[item] += 1
        return {
            item: FrequentPatternTree.Node(item, value=count)
            for item, count in counts.items()
            if count >= support
        }

    def get_freq_sets(self, support):
        """Return all frequent itemsets found in the already-built tree."""
        freq_sets = []
        self._find_freq_sets_byprefix(support, [], freq_sets)
        return freq_sets

    def _find_freq_sets_byprefix(self, support, pre_fix, freq_sets):
        """Recursively mine itemsets by building conditional FP-trees.

        For every frequent item (ascending by (count, item)), ``pre_fix +
        [item]`` is appended to ``freq_sets``; a conditional tree is then
        built from the item's prefix paths and mined with the extended
        prefix.
        """
        frequent = sorted(
            ((name, node) for name, node in self.header_table.items()
             if node.value >= support),
            key=lambda entry: (entry[1].value, entry[0]),
        )
        for item, _ in frequent:
            new_freq_set = list(pre_fix)
            new_freq_set.append(item)
            freq_sets.append(new_freq_set)
            condition_bases = self._find_prefix_path(item)
            new_tree = FrequentPatternTree().create_tree(condition_bases, support=support)
            if not new_tree._is_empty():
                new_tree._find_freq_sets_byprefix(support, new_freq_set, freq_sets)

    def _find_prefix_path(self, item):
        """Return the conditional pattern base of ``item``.

        For each tree node carrying ``item``, the path of ancestor item
        names (nearest ancestor first, root excluded) is collected and
        repeated ``node.value`` times. Nodes directly below the root
        contribute nothing.
        """
        condition_path = []
        node = self.header_table[item].next
        while node is not None:
            path = []
            ancestor = node.parent
            # The root is the only tree node without a parent; stop there.
            while ancestor is not None and ancestor.parent is not None:
                path.append(ancestor.name)
                ancestor = ancestor.parent
            if path:
                # The same list object is repeated; create_tree copies each
                # transaction before sorting, so sharing is safe.
                condition_path.extend([path] * node.value)
            node = node.next
        return condition_path