@spiritnotes
2016-02-25T17:33:41.000000Z
字数 679
阅读 1768
机器学习实践
def Visualization(data, feature_indexs=None):
    """Show one scatter plot for every pair of selected feature columns.

    Args:
        data: Dict-like object providing ``data['data']`` (2-D sample x
            feature values), ``data['feature_names']`` (sequence of column
            labels, or ``None``) and ``data['target']`` (one class label per
            sample).
        feature_indexs: Column indexes to plot against each other. When
            ``None``, every column is used.

    Each pair ``(i, j)`` with ``j > i`` taken from ``feature_indexs`` is
    passed to :func:`show_plt`, which opens a blocking matplotlib window.
    """
    import numpy as np

    features = np.asarray(data['data'])
    names = data['feature_names']
    target = np.asarray(data['target'])

    if feature_indexs is None:
        # show_plt() copes with missing names, so do not require them here:
        # fall back to the number of columns in the feature matrix.
        n_features = features.shape[1] if names is None else len(names)
        feature_indexs = list(range(n_features))

    for i in feature_indexs:
        for j in feature_indexs:
            # Plot each unordered pair once and skip a feature against itself.
            if j <= i:
                continue
            show_plt(features, i, j, target, names)


def show_plt(data, i, j, target, names):
    """Scatter-plot column ``i`` against column ``j``, styled per class.

    Args:
        data: 2-D array-like of feature values (rows are samples).
        i: Column index drawn on the x axis.
        j: Column index drawn on the y axis.
        target: 1-D array-like holding the class label of every row.
        names: Sequence of axis labels indexed by column, or ``None``/empty
            to label the axes with the column indexes instead.
    """
    import numpy as np
    from matplotlib import pyplot as plt

    # Accept plain lists as well: boolean-mask indexing below needs arrays,
    # and ``list == scalar`` would yield a single bool rather than a mask.
    data = np.asarray(data)
    target = np.asarray(target)

    vis_types = ">ox"
    vis_colors = "rgb"

    # np.unique yields the classes in sorted order, so a class always gets
    # the same marker/colour from one run to the next.
    for k, class_ in enumerate(np.unique(target)):
        # The first three classes get ('>', 'r'), ('o', 'g'), ('x', 'b').
        # Further classes shift the colour by one per full pass over the
        # markers, giving nine distinct combinations instead of silently
        # dropping every class after the third.
        marker = vis_types[k % len(vis_types)]
        color = vis_colors[(k + k // len(vis_types)) % len(vis_colors)]
        rows = (target == class_)
        plt.scatter(data[rows, i], data[rows, j], marker=marker, c=color)

    # ``if names:`` raises ValueError for a multi-element NumPy array, so
    # test for presence and length explicitly.
    if names is not None and len(names) > 0:
        xlabel, ylabel = names[i], names[j]
    else:
        xlabel, ylabel = str(i), str(j)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.autoscale(tight=True)
    plt.show()
