[关闭]
@tianxingjian 2020-11-21T15:31:55.000000Z 字数 19604 阅读 1862

机器学习

《Machine Learning in Action》—— 小朋友,快来玩啊,决策树呦

在上篇文章中,《Machine Learning in Action》—— Taoye给你讲讲决策树到底是支什么“鬼”主要讲述了决策树的理论内容,介绍了什么决策树,以及生成决策树时所需要优先选取的三种决策标准。有学习的过SVM,或阅读过Taoye之前写的几篇SVM内容的文章可以发现,决策树相对于SVM来讲要简单很多,没有太多且复杂的公式推导。

我们在把之前的内容稍微回顾下:

有了前面的内容做基础,阅读本篇文章就会很轻松了。本篇主要讲述以下几部分的内容:

一、基于ID3算法手动构建决策树,并通过Matplotlib进行可视化

构建决策树的算法有好几种,比如像ID3、C4.5、CART之类的,限于篇幅、时间和精力的关系,本篇文章主要采用ID3算法来进行构建,使用到的决策标准(指标)是上篇文章中所提到的信息增益。关于C4.5和CART算法构建决策树,有兴趣的读者可以参考上期中的增益率和尼基指数的内容。

本次构建决策树所需要用到的数据集仍然是李航——《统计学习方法》中的贷款数据,这里再次把数据集放出来瞅瞅:

前面也有提到,ID3算法主要是基于信息增益作为选取属性特征的准则,在上期我们也计算过各个属性特征所对应的信息增益值,如下:

而根据ID3算法过程,我们可以知道,需要优先选取信息增益最大的属性进行决策,即房子,也就是说将房子作为决策树的根节点。由于房子这一个属性特征有两个取值,所以引申出两个子节点,一个对应“有房子”,另一个对应“无房子”,而“有房子”的六个样本的类别都是允许放款,也就是有同一类样本点,所以它理应成为一个叶子节点,且节点的类不妨标记为“允许”。

这样一来,我们的根节点以及其中的一个叶子节点就确定了。接下来,我们需要将“无房子”所对应的样本集再次选取一个新的属性特征进行决策。注意:此次做决策的数据样本总体就不再是初始数据了,而是“无房子”所对应的所有样本,这一点需要格外注意。

我们在“无房子的”的数据样本中再次计算其他属性所对应的信息增益,我们不妨讲此次的数据样本集合记为,计算结果如下:

与上同样分析,可以发现此时工作所对应的信息增益值最高,也就是说第二个优先选取的属性为“工作”。而且,我们可以发现,在数据样本集中,总共有9个,其中允许放款的有三个,拒绝的有6个,且结果标签与工作值刚好完全都对应上了,也就是说有工作的都允许放款了,没工作的都拒绝放款了。所以,在第二个属性特征选取完成之后,此时产生了俩个叶子节点,节点结果与是否有工作对应。

通过如上分析,我们就得到了此次基于ID3算法所构建出的决策树,决策树如下:

接下来我们通过代码来生成这颗决策树,对于树形结构的数据,我们可以通过字典或者说是json类型来进行保存。比如上图中的决策树,我们可以通过如下结果来进行表示:

  1. {"房子": {
  2. "1": "Y",
  3. "0":{"工作": {
  4. "1": "Y",
  5. "0": "N"
  6. }}
  7. }}

上述表示的数据格式我们一般称其为Json,这个在前后端、爬虫,亦或是在其他各种领域中都是接触的非常多的。另外,我们可以发现在决策树生成的过程中,在一个属性特征选取完成之后,需要经过同样的操作再次选取一个属性特征,其实就相当于一个周期,换句话讲正好满足了递归的特性,只是我们的数据总体发生了变化而已。既然我们明确了保存树形结构数据所需要的类型,下面我们通过代码来实现:

此次的代码相较于上篇文章中计算信息增益的变化主要有三个地方:

  1. """
  2. Author: Taoye
  3. 微信公众号: 玩世不恭的Coder
  4. Explain:创建训数据集
  5. """
  6. def establish_data():
  7. data = [[0, 0, 0, 0, 'N'], # 样本数据集相关信息,前面几项代表属性特征,最后一项表示是否放款
  8. [0, 0, 0, 1, 'N'],
  9. [0, 1, 0, 1, 'Y'],
  10. [0, 1, 1, 0, 'Y'],
  11. [0, 0, 0, 0, 'N'],
  12. [1, 0, 0, 0, 'N'],
  13. [1, 0, 0, 1, 'N'],
  14. [1, 1, 1, 1, 'Y'],
  15. [1, 0, 1, 2, 'Y'],
  16. [1, 0, 1, 2, 'Y'],
  17. [2, 0, 1, 2, 'Y'],
  18. [2, 0, 1, 1, 'Y'],
  19. [2, 1, 0, 1, 'Y'],
  20. [2, 1, 0, 2, 'Y'],
  21. [2, 0, 0, 0, 'N']]
  22. labels = ["年纪", "工作", "房子", "信用"]
  23. return np.array(data), labels
  1. """
  2. Author: Taoye
  3. 微信公众号: 玩世不恭的Coder
  4. Explain:找出对应属性特征值的样本,比如找出所有年纪为青年的样本数据集
  5. """
  6. def handle_data(data, axis, value):
  7. result_data = list()
  8. for item in data:
  9. if item[axis] == value:
  10. reduced_data = item[: axis].tolist()
  11. reduced_data.extend(item[axis + 1:])
  12. result_data.append(reduced_data)
  13. return result_data
  1. """
  2. Author: Taoye
  3. 微信公众号: 玩世不恭的Coder
  4. Explain:创建决策树
  5. """
  6. def establish_decision_tree(data, labels, feat_labels):
  7. cat_list = [item[-1] for item in data]
  8. if (cat_list.count(cat_list[0]) == len(cat_list)): return cat_list[0] # 数据集中的类别只有一种
  9. best_feature_index = calc_information_gain(data) # 通过信息增益优先选取最好的属性特征
  10. best_label = labels[best_feature_index] # 属性特征对应的标签内容
  11. # feat_labels表示已选取的属性;新建一个决策树节点;将属性标签列表中删除已选取的属性
  12. feat_labels.append(best_label); decision_tree = {best_label: dict()}; del(labels[best_feature_index])
  13. feature_values = [item[best_feature_index] for item in data]
  14. unique_values = set(feature_values) # 获取最优属性对应值的set集合
  15. for value in unique_values:
  16. sub_label = labels[:]
  17. decision_tree[best_label][value] = establish_decision_tree(np.array(handle_data(data, best_feature_index, value)), sub_label, feat_labels)
  18. return decision_tree

该部分的完整代码如下所示:

  1. import numpy as np
  2. import pandas as pd
  3. np.__version__
  4. pd.__version__
  5. """
  6. Author: Taoye
  7. 微信公众号: 玩世不恭的Coder
  8. Explain:创建训数据集
  9. """
  10. def establish_data():
  11. data = [[0, 0, 0, 0, 'N'], # 样本数据集相关信息,前面几项代表属性特征,最后一项表示是否放款
  12. [0, 0, 0, 1, 'N'],
  13. [0, 1, 0, 1, 'Y'],
  14. [0, 1, 1, 0, 'Y'],
  15. [0, 0, 0, 0, 'N'],
  16. [1, 0, 0, 0, 'N'],
  17. [1, 0, 0, 1, 'N'],
  18. [1, 1, 1, 1, 'Y'],
  19. [1, 0, 1, 2, 'Y'],
  20. [1, 0, 1, 2, 'Y'],
  21. [2, 0, 1, 2, 'Y'],
  22. [2, 0, 1, 1, 'Y'],
  23. [2, 1, 0, 1, 'Y'],
  24. [2, 1, 0, 2, 'Y'],
  25. [2, 0, 0, 0, 'N']]
  26. labels = ["年纪", "工作", "房子", "信用"]
  27. return np.array(data), labels
  28. """
  29. Author: Taoye
  30. 微信公众号: 玩世不恭的Coder
  31. Explain:计算信息熵
  32. """
  33. def calc_information_entropy(data):
  34. data_number, _ = data.shape
  35. information_entropy = 0
  36. for item in pd.DataFrame(data).groupby(_ - 1):
  37. proportion = item[1].shape[0] / data_number
  38. information_entropy += - proportion * np.log2(proportion)
  39. return information_entropy
  40. """
  41. Author: Taoye
  42. 微信公众号: 玩世不恭的Coder
  43. Explain:找出对应属性特征值的样本,比如找出所有年纪为青年的样本数据集
  44. """
  45. def handle_data(data, axis, value):
  46. result_data = list()
  47. for item in data:
  48. if item[axis] == value:
  49. reduced_data = item[: axis].tolist()
  50. reduced_data.extend(item[axis + 1:])
  51. result_data.append(reduced_data)
  52. return result_data
  53. """
  54. Author: Taoye
  55. 微信公众号: 玩世不恭的Coder
  56. Explain:计算最大的信息增益,并输出其所对应的特征索引
  57. """
  58. def calc_information_gain(data):
  59. feature_number = data.shape[1] - 1 # 属性特征的数量
  60. base_entropy = calc_information_entropy(data) # 计算总体数据集的信息熵
  61. max_information_gain, best_feature = 0.0, -1 # 初始化最大信息增益和对应的特征索引
  62. for index in range(feature_number):
  63. feat_list = [item[index] for item in data]
  64. feat_set = set(feat_list)
  65. new_entropy = 0.0
  66. for set_item in feat_set: # 计算属性特征划分后的信息增益
  67. sub_data = handle_data(data, index, set_item)
  68. proportion = len(sub_data) / float(data.shape[0]) # 计算子集的比例
  69. new_entropy += proportion * calc_information_entropy(np.array(sub_data))
  70. temp_information_gain = base_entropy - new_entropy # 计算信息增益
  71. print("第%d个属性特征所对应的的增益为%.3f" % (index + 1, temp_information_gain)) # 输出每个特征的信息增益
  72. if (temp_information_gain > max_information_gain):
  73. max_information_gain, best_feature = temp_information_gain, index # 更新信息增益,确定的最大的信息增益对应的索引
  74. return best_feature
  75. """
  76. Author: Taoye
  77. 微信公众号: 玩世不恭的Coder
  78. Explain:创建决策树
  79. """
  80. def establish_decision_tree(data, labels, feat_labels):
  81. cat_list = [item[-1] for item in data]
  82. if (cat_list.count(cat_list[0]) == len(cat_list)): return cat_list[0] # 数据集中的类别只有一种
  83. best_feature_index = calc_information_gain(data) # 通过信息增益优先选取最好的属性特征
  84. best_label = labels[best_feature_index] # 属性特征对应的标签内容
  85. # feat_labels表示已选取的属性;新建一个决策树节点;将属性标签列表中删除已选取的属性
  86. feat_labels.append(best_label); decision_tree = {best_label: dict()}; del(labels[best_feature_index])
  87. feature_values = [item[best_feature_index] for item in data]
  88. unique_values = set(feature_values) # 获取最优属性对应值的set集合
  89. for value in unique_values:
  90. sub_label = labels[:]
  91. decision_tree[best_label][value] = establish_decision_tree(np.array(handle_data(data, best_feature_index, value)), sub_label, feat_labels)
  92. return decision_tree
  93. if __name__ == "__main__":
  94. data, labels = establish_data()
  95. print(establish_decision_tree(data, labels, list()))

运行结果:

  1. {'房子': {'1': 'Y', '0': {'工作': {'1': 'Y', '0': 'N'}}}}

可见,代码运行的结果与我们手动创建的决策树如出一辙,完美,哈哈哈~~~

可是,如上决策树的显示未免有点太不亲民了,生成的决策树比较简单那还好点,假如我们生成的决策树比较复杂,那通过Json格式的数据来输出决策树就有点懵了。

对此,我们需要将决策树进行可视化,可视化主要使用到了Matplotlib包。关于Matplotlib的使用大家可以参考文档及其他资料,Taoye后期也会整理出一篇自己常用的接口。

  1. import numpy as np
  2. import pandas as pd
  3. """
  4. Author: Taoye
  5. 微信公众号: 玩世不恭的Coder
  6. Explain:创建训数据集
  7. """
  8. def establish_data():
  9. data = [[0, 0, 0, 0, 'N'], # 样本数据集相关信息,前面几项代表属性特征,最后一项表示是否放款
  10. [0, 0, 0, 1, 'N'],
  11. [0, 1, 0, 1, 'Y'],
  12. [0, 1, 1, 0, 'Y'],
  13. [0, 0, 0, 0, 'N'],
  14. [1, 0, 0, 0, 'N'],
  15. [1, 0, 0, 1, 'N'],
  16. [1, 1, 1, 1, 'Y'],
  17. [1, 0, 1, 2, 'Y'],
  18. [1, 0, 1, 2, 'Y'],
  19. [2, 0, 1, 2, 'Y'],
  20. [2, 0, 1, 1, 'Y'],
  21. [2, 1, 0, 1, 'Y'],
  22. [2, 1, 0, 2, 'Y'],
  23. [2, 0, 0, 0, 'N']]
  24. labels = ["年纪", "工作", "房子", "信用"]
  25. return np.array(data), labels
  26. """
  27. Author: Taoye
  28. 微信公众号: 玩世不恭的Coder
  29. Explain:计算信息熵
  30. """
  31. def calc_information_entropy(data):
  32. data_number, _ = data.shape
  33. information_entropy = 0
  34. for item in pd.DataFrame(data).groupby(_ - 1):
  35. proportion = item[1].shape[0] / data_number
  36. information_entropy += - proportion * np.log2(proportion)
  37. return information_entropy
  38. """
  39. Author: Taoye
  40. 微信公众号: 玩世不恭的Coder
  41. Explain:找出对应属性特征值的样本,比如找出所有年纪为青年的样本数据集
  42. """
  43. def handle_data(data, axis, value):
  44. result_data = list()
  45. for item in data:
  46. if item[axis] == value:
  47. reduced_data = item[: axis].tolist()
  48. reduced_data.extend(item[axis + 1:])
  49. result_data.append(reduced_data)
  50. return result_data
  51. """
  52. Author: Taoye
  53. 微信公众号: 玩世不恭的Coder
  54. Explain:计算最大的信息增益,并输出其所对应的特征索引
  55. """
  56. def calc_information_gain(data):
  57. feature_number = data.shape[1] - 1 # 属性特征的数量
  58. base_entropy = calc_information_entropy(data) # 计算总体数据集的信息熵
  59. max_information_gain, best_feature = 0.0, -1 # 初始化最大信息增益和对应的特征索引
  60. for index in range(feature_number):
  61. feat_list = [item[index] for item in data]
  62. feat_set = set(feat_list)
  63. new_entropy = 0.0
  64. for set_item in feat_set: # 计算属性特征划分后的信息增益
  65. sub_data = handle_data(data, index, set_item)
  66. proportion = len(sub_data) / float(data.shape[0]) # 计算子集的比例
  67. new_entropy += proportion * calc_information_entropy(np.array(sub_data))
  68. temp_information_gain = base_entropy - new_entropy # 计算信息增益
  69. print("第%d个属性特征所对应的的增益为%.3f" % (index + 1, temp_information_gain)) # 输出每个特征的信息增益
  70. if (temp_information_gain > max_information_gain):
  71. max_information_gain, best_feature = temp_information_gain, index # 更新信息增益,确定的最大的信息增益对应的索引
  72. return best_feature
  73. """
  74. Author: Taoye
  75. 微信公众号: 玩世不恭的Coder
  76. Explain:创建决策树
  77. """
  78. def establish_decision_tree(data, labels, feat_labels):
  79. cat_list = [item[-1] for item in data]
  80. if (cat_list.count(cat_list[0]) == len(cat_list)): return cat_list[0] # 数据集中的类别只有一种
  81. best_feature_index = calc_information_gain(data) # 通过信息增益优先选取最好的属性特征
  82. best_label = labels[best_feature_index] # 属性特征对应的标签内容
  83. # feat_labels表示已选取的属性;新建一个决策树节点;将属性标签列表中删除已选取的属性
  84. feat_labels.append(best_label); decision_tree = {best_label: dict()}; del(labels[best_feature_index])
  85. feature_values = [item[best_feature_index] for item in data]
  86. unique_values = set(feature_values) # 获取最优属性对应值的set集合
  87. for value in unique_values:
  88. sub_label = labels[:]
  89. decision_tree[best_label][value] = establish_decision_tree(np.array(handle_data(data, best_feature_index, value)), sub_label, feat_labels)
  90. return decision_tree
  91. """
  92. Author: Taoye
  93. 微信公众号: 玩世不恭的Coder
  94. Explain:统计决策树当中的叶子节点数目,以及决策树的深度
  95. """
  96. def get_leaf_number_and_tree_depth(decision_tree):
  97. leaf_number, first_key, tree_depth = 0, next(iter(decision_tree)), 0; second_dict = decision_tree[first_key]
  98. for key in second_dict.keys():
  99. if type(second_dict.get(key)).__name__ == "dict":
  100. temp_number, temp_depth = get_leaf_number_and_tree_depth(second_dict[key])
  101. leaf_number, curr_depth = leaf_number + temp_number, 1 + temp_depth
  102. else: leaf_number += 1; curr_depth = 1
  103. if curr_depth > tree_depth: tree_depth = curr_depth
  104. return leaf_number, tree_depth
  105. from matplotlib.font_manager import FontProperties
  106. """
  107. Author: Taoye
  108. 微信公众号: 玩世不恭的Coder
  109. Explain:绘制节点
  110. """
  111. def plot_node(node_text, center_pt, parent_pt, node_type):
  112. arrow_args = dict(arrowstyle = "<-")
  113. font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14) # 设置字体
  114. create_plot.ax1.annotate(node_text, xy=parent_pt, xycoords='axes fraction',
  115. xytext=center_pt, textcoords='axes fraction',
  116. va="center", ha="center", bbox=node_type, arrowprops=arrow_args, FontProperties=font)
  117. """
  118. Author: Taoye
  119. 微信公众号: 玩世不恭的Coder
  120. Explain:标注有向边的值
  121. """
  122. def tag_text(cntr_pt, parent_pt, node_text):
  123. x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
  124. y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
  125. create_plot.ax1.text(x_mid, y_mid, node_text, va="center", ha="center", rotation=30)
  126. """
  127. Author: Taoye
  128. 微信公众号: 玩世不恭的Coder
  129. Explain:绘制决策树
  130. """
  131. def plot_tree(decision_tree, parent_pt, node_text):
  132. decision_node = dict(boxstyle="sawtooth", fc="0.8")
  133. leaf_node = dict(boxstyle = "round4", fc = "0.8")
  134. leaf_number, tree_depth = get_leaf_number_and_tree_depth(decision_tree)
  135. first_key = next(iter(decision_tree))
  136. cntr_pt = (plot_tree.xOff + (1.0 + float(leaf_number)) / 2.0 / plot_tree.totalW, plot_tree.yOff)
  137. tag_text(cntr_pt, parent_pt, node_text); plot_node(first_key, cntr_pt, parent_pt, decision_node)
  138. second_dict = decision_tree[first_key]
  139. plot_tree.yOff = plot_tree.yOff - 1.0 / plot_tree.totalD
  140. for key in second_dict.keys():
  141. if type(second_dict[key]).__name__ == 'dict': plot_tree(second_dict[key], cntr_pt, str(key))
  142. else:
  143. plot_tree.xOff = plot_tree.xOff + 1.0 / plot_tree.totalW
  144. plot_node(second_dict[key], (plot_tree.xOff, plot_tree.yOff), cntr_pt, leaf_node)
  145. tag_text((plot_tree.xOff, plot_tree.yOff), cntr_pt, str(key))
  146. plot_tree.yOff = plot_tree.yOff + 1.0 / plot_tree.totalD
  147. from matplotlib import pyplot as plt
  148. """
  149. Author: Taoye
  150. 微信公众号: 玩世不恭的Coder
  151. Explain:创建决策树
  152. """
  153. def create_plot(in_tree):
  154. fig = plt.figure(1, facecolor = "white")
  155. fig.clf()
  156. axprops = dict(xticks = [], yticks = [])
  157. create_plot.ax1 = plt.subplot(111, frameon = False, **axprops)
  158. leaf_number, tree_depth = get_leaf_number_and_tree_depth(in_tree)
  159. plot_tree.totalW, plot_tree.totalD = float(leaf_number), float(tree_depth)
  160. plot_tree.xOff = -0.5 / plot_tree.totalW; plot_tree.yOff = 1.0
  161. plot_tree(in_tree, (0.5,1.0), '')
  162. plt.show()
  163. if __name__ == "__main__":
  164. data, labels = establish_data()
  165. decision_tree = establish_decision_tree(data, labels, list())
  166. print(decision_tree)
  167. print("决策树的叶子节点数和深度:", get_leaf_number_and_tree_depth(decision_tree))
  168. create_plot(decision_tree)

手动可视化决策树的结果如下所示:

实话实说,通过Matplotlib手动对决策树进行可视化,对于之前没什么经验的码手来讲确实有点不友好。上述代码能看懂就行,没必要死揪着不放,后面的话会介绍通过Graphviz来绘制决策树。这里对上方的代码做个简短的说明:

二、基于已经构建好的决策树进行分类预测

依靠训练数据构造了决策树之后,我们既可以通过该决策树模型应用于实际数据来进行分类。在对数据进行分类时,需要使用决策树以及用于构造决策树的标签向量;然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入到叶子节点;最后将测试数据定义为叶子节点所属的类型。——《机器学习实战》

在已经获取到决策树模型的前提下对测试数据进行分类还是挺好理解的

对此,我们定义一个classify方法来进行分类:

  1. """
  2. Author: Taoye
  3. 微信公众号: 玩世不恭的Coder
  4. Explain:通过决策树模型对测试数据进行分类
  5. """
  6. def classify(decision_tree, best_feature_labels, test_data):
  7. first_node = next(iter(decision_tree))
  8. second_dict = decision_tree[first_node]
  9. feat_index = best_feature_labels.index(first_node)
  10. for key in second_dict.keys():
  11. if int(test_data[feat_index]) == int(key):
  12. if type(second_dict[key]).__name__ == "dict": # 为字典说明还没到叶子节点
  13. result_label = classify(second_dict[key], best_feature_labels, test_data)
  14. else: result_label = second_dict[key]
  15. return result_label

我们分别对四个数据样本进行测试,样本分别是(有房子,没工作),(有房子,有工作),(没房子,没工作),(没房子,有工作),使用列表表示分别是[1, 0], [1, 1], [0, 0], [0, 1],运行结果如下:

可见,四组数据都能够分类成功。

三、构建好的决策树模型应当如何保存和读取?

构建好决策树之后,我们要想保存该模型就可以通过pickle模块来进行。

保存好模型之后,下次使用该模型就不需要的再次训练了,只需要加载模型即可。保存和加载模型的示例代码如下(挺简单的,就不多说了):

  1. import pickle
  2. with open("DecisionTreeModel.txt", "wb") as f:
  3. pickle.dump(decision_tree, f) # 保存决策树模型
  4. f = open("DecisionTreeModel.txt", "rb")
  5. decision_tree = pickle.load(f) # 加载决策树模型

四、通过鸢尾花(Iris)数据集,使用Sklearn构建决策树

现在我们通过sklearn来实现一个小案例,数据集采用的是机器学习中比较常用的鸢尾花(Iris)数据集。更多其他关于决策树分类的案例,大家可以去sklearn.tree.DecisionTreeClassifier的文档中学习:https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

在sklearn中实现决策树分类主要用到的接口是sklearn.tree.DecisionTreeClassifier,这个主要是通过数据样本集构建一个决策树模型。此外,如果我们要想将决策树可视化,还需要使用到export_graphviz。当然了,在sklearn.tree`下还有其他接口可供大家调用,这里不做的过多介绍了,读者可自行学习。

关于sklearn.tree.DecisionTreeClassifier的使用,可参考:https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html。其中内置了很多的参数,这里主要记录8个参数,也方便后期回顾,其他的有机会用到再查找资料:

此外,在DecisionTreeClassifier下也有很多方法可供调用,详情可参考文档进行使用,如下:

接下来,我们就用sklearn来对鸢尾花数据集进行分类吧。参考资料:https://scikit-learn.org/stable/auto_examples/tree/plot_iris_dtc.html#sphx-glr-auto-examples-tree-plot-iris-dtc-py

构建决策树本身的代码并不难,主要在于可视化,其中涉及到了Matplotlib的不少操作,以增强可视化效果,完整代码如下:

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from sklearn.datasets import load_iris
  4. from sklearn.tree import DecisionTreeClassifier, plot_tree
  5. class IrisDecisionTree:
  6. """
  7. Explain:属性的初始化
  8. Parameters:
  9. n_classes: 鸢尾花的类别数
  10. plot_colors: 不同类别花的颜色
  11. plot_step: meshgrid网格的步长
  12. """
  13. def __init__(self, n_classes, plot_colors, plot_step):
  14. self.n_classes = n_classes
  15. self.plot_colors = plot_colors
  16. self.plot_step = plot_step
  17. """
  18. Explain: 通过load_iris构建数据集
  19. """
  20. def establish_data(self):
  21. iris_info = load_iris()
  22. return iris_info.data, iris_info.target, iris_info.feature_names, iris_info.target_names
  23. """
  24. Explain:分类的可视化
  25. """
  26. def show_result(self, x_data, y_label, feature_names, target_names):
  27. # 选取两个属性来构建决策树,以方便可视化,其中列表内部元素代表属性对应的索引
  28. for index, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
  29. sub_x_data, sub_y_label = x_data[:, pair], y_label
  30. clf = DecisionTreeClassifier().fit(sub_x_data, sub_y_label) # 选取两个属性构建决策树
  31. plt.subplot(2, 3, index + 1)
  32. x_min, x_max = sub_x_data[:, 0].min() - 1, sub_x_data[:, 0].max() + 1 # 第一个属性
  33. y_min, y_max = sub_x_data[:, 1].min() - 1, sub_x_data[:, 1].max() + 1 # 第二个属性
  34. xx, yy = np.meshgrid(np.arange(x_min, x_max, self.plot_step), np.arange(y_min, y_max, self.plot_step))
  35. plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
  36. Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape) # 预测meshgrid内部每个元素的分类
  37. cs = plt.contourf(xx, yy, Z, cmap = plt.cm.RdYlBu) # 绘制带有颜色的网格图
  38. plt.xlabel(feature_names[pair[0]]); plt.ylabel(feature_names[pair[1]]) # 标注坐标轴标签
  39. for i, color in zip(range(self.n_classes), self.plot_colors):
  40. idx = np.where(sub_y_label == i)
  41. plt.scatter(sub_x_data[idx, 0], sub_x_data[idx, 1], c=color, label=target_names[i],
  42. cmap=plt.cm.RdYlBu, edgecolor='black', s=15) # 绘制数据样本集的散点图
  43. from matplotlib.font_manager import FontProperties
  44. font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14) # 定义中文字体
  45. plt.suptitle("通过决策树对鸢尾花数据进行可视化", fontproperties=font)
  46. plt.legend(loc='lower right', borderpad=0, handletextpad=0)
  47. plt.axis("tight")
  48. plt.figure()
  49. clf = DecisionTreeClassifier().fit(x_data, y_label) # 针对鸢尾花数据集的多重属性来构建决策树
  50. plot_tree(clf, filled=True)
  51. plt.show()
  52. if __name__ == "__main__":
  53. iris_decision_tree = IrisDecisionTree(3, "ryb", 0.02)
  54. x_data, y_label, feature_names, target_names = iris_decision_tree.establish_data()
  55. iris_decision_tree.show_result(x_data, y_label, feature_names, target_names)

运行结果如下所示:

通过可视化结果,我们可以发现主要有两个结果,我们分别对其进行说明下:于鸢尾花数据集来讲,总共有四种属性特征以及三种标签结果。为了方便可视化,第一张图只选取了两个属性来构建决策树,四种属性,选两个,很简单,学过排列组合的都应该知道有种可能,所第一张图中的每张子图分别对应一种可能。且颜色不同代表不同的分类,假如数据集颜色与网格内部颜色一致,则说明分类正确。因此,从直观上来看,选取sepal length和petal length这两种属性构建的决策树相对较好。而第二张图是针对数据集中的所有属性特征来构建决策树。具体的结果可自行运行上方代码查看(由于设置了font字体,所以上方代码需在windows下运行)

我们可以发现,上面代码可视化决策树的时候采用的是sklearn.tree.plot_tree,前面我们在讲解通过Matplotlib绘制决策树的时候也有说到,使用graphviz亦可可视化决策树,下面我们不妨来看看吧!

graphviz不能采用pip进行安装,采用anaconda安装的时候也会很慢,甚至多次尝试都可能安装失败,前几天帮同学安装就出现这种情况(windows下是这样的,linux环境下会很方便),所以这里我们采用直接通过whl文件来安装。

建议:对于使用Python有过一段时间的Pyer来讲,都会经常安装一些第三方模块,有些可以直接通过pip或者conda完美的解决,而有些在安装的过程中会遇到各种不明所以的错误。所以,对于在安装过程中遇到错误的读者不妨尝试通过whl文件进行安装,whl目标地址:https://www.lfd.uci.edu/~gohlke/pythonlibs/#wordcloud,其中整合了各种Python模块的whl文件。

打开上述url地址 --> ctrl + f搜索graphviz --> 下载需要的graphviz安装文件

在本地目标路径中执行安装即可:

  1. pip install graphviz0.15py3noneany.whl

此外对于windows来将 ,还需要前往官网安装graphviz的exe文件,然后将bin目录添加到环境变量即可。exe文件的下载地址:https://graphviz.org/download/

如果是Linux用户,那就比较方便了,直接通过命令安装即可:

  1. $ sudo apt install graphviz # Ubuntu
  2. $ sudo apt install graphviz # Debian

至此,Graphviz就已经安装好了。我们通过它来实现决策树的可视化吧,在IrisDecisionTree下添加如下show_result_by_graphviz方法:

  1. """
  2. Explain:通过graphviz实现决策树的可视化
  3. """
  4. def show_result_by_graphviz(self, x_data, y_label):
  5. clf = DecisionTreeClassifier().fit(x_data, y_label)
  6. iris_dot_data = tree.export_graphviz(clf, out_file=None,
  7. feature_names=iris.feature_names,
  8. class_names=iris.target_names,
  9. filled=True, rounded=True,
  10. special_characters=True)
  11. import graphviz
  12. graph = graphviz.Source(iris_dot_data); graph.render("iris")

运行之后会在当前目录下生成一个pdf文件,其中就是可视化之后的决策树。注意:以上只是实现简单的鸢尾花的决策树分类,读者可通过调解DecisionTreeClassifier的参数构建不同的决策树,以此来判别各个决策树的优劣。

以上就是本篇全部内容了,决策树的相关内容暂时就更新到这了,其他内容像过拟合、剪枝等等以后有时间再更新,下期的机器学习系列文章就是肝SVM的非线性模型了。

就不唠嗑了~~~

我是Taoye,爱专研,爱分享,热衷于各种技术,学习之余喜欢下象棋、听音乐、聊动漫,希望借此一亩三分地记录自己的成长过程以及生活点滴,也希望能结实更多志同道合的圈内朋友,更多内容欢迎来访微信公主号:玩世不恭的Coder

参考资料:

[1] 《机器学习实战》:Peter Harrington 人民邮电出版社
[2] 《统计学习方法》:李航 第二版 清华大学出版社
[3] 《机器学习》:周志华 清华大学出版社
[4] Python Extension Packages:https://www.lfd.uci.edu/~gohlke/pythonlibs/#wordcloud
[5] sklearn.tree.DecisionTreeClassifier:https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
[6] Graphviz官网:https://graphviz.org/

推荐阅读

《Machine Learning in Action》—— Taoye给你讲讲决策树到底是支什么“鬼”
《Machine Learning in Action》—— 剖析支持向量机,优化SMO
《Machine Learning in Action》—— 剖析支持向量机,单手狂撕线性SVM
print( "Hello,NumPy!" )
干啥啥不行,吃饭第一名
Taoye渗透到一家黑平台总部,背后的真相细思极恐
《大话数据库》-SQL语句执行时,底层究竟做了什么小动作?
那些年,我们玩过的Git,真香
基于Ubuntu+Python+Tensorflow+Jupyter notebook搭建深度学习环境
网络爬虫之页面花式解析
手握手带你了解Docker容器技术

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注