@spiritnotes
2016-02-25T17:33:41.000000Z
字数 679
阅读 1654
机器学习实践
def Visualization(data, feature_indexs = None):
    """Scatter-plot every pair of selected features, coloured by class.

    Parameters
    ----------
    data : mapping
        Must provide the keys ``'data'`` (feature matrix indexed as
        ``features[rows, column]``), ``'feature_names'`` and ``'target'``
        (per-sample class labels). A scikit-learn dataset Bunch appears to
        match this layout -- assumption, confirm against callers.
    feature_indexs : iterable of int, optional
        Column indices to combine. When omitted, defaults to
        ``range(len(data['feature_names']))``.

    Every pair ``(i, j)`` drawn from ``feature_indexs`` with ``i < j`` is
    handed to ``show_plt``, so each combination is plotted once and no
    feature is plotted against itself.
    """
    features, names, target = data['data'], data['feature_names'], data['target']
    if feature_indexs is None:
        feature_indexs = list(range(len(names)))
    # Lazy generator: pairs are produced one at a time, interleaved with the
    # plotting calls, in the same (outer, inner) order as a nested loop.
    pairs = ((i, j) for i in feature_indexs for j in feature_indexs if i < j)
    for first, second in pairs:
        show_plt(features, first, second, target, names)
def show_plt(data, i, j, target, names):
    """Scatter-plot feature ``j`` against feature ``i`` with one style per class.

    Parameters
    ----------
    data : array-like
        2-D sample-by-feature matrix; columns ``i`` (x axis) and ``j``
        (y axis) are plotted.
    i, j : int
        Column indices for the x and y axes.
    target : array-like
        Per-sample class labels, aligned with the rows of ``data``.
    names : sequence of str or None
        Feature names used as axis labels. When ``None`` or empty, the
        column indices are used as labels instead.

    Classes are drawn in sorted order. The first three classes use the
    ``>``/red, ``o``/green and ``x``/blue styles. Later classes cycle through
    further marker/colour combinations rather than being silently dropped.
    A legend identifies each class, and the figure is displayed with
    ``plt.show()``.
    """
    from itertools import cycle

    import numpy as np
    from matplotlib import pyplot as plt

    data = np.asarray(data)
    # An ndarray is required so that ``target == class_`` is element-wise;
    # with a plain list it would collapse to a single bool.
    target = np.asarray(target)

    vis_types = ">ox"
    vis_colors = "rgbcmyk"
    # np.unique returns the labels sorted, so styling is deterministic
    # (set iteration order is not). Markers (period 3) and colours (period 7)
    # cycle independently, giving 21 distinct styles before any repeat,
    # instead of zip() truncating to the first 3 classes.
    classes = np.unique(target)
    for class_, marker, color in zip(classes, cycle(vis_types), cycle(vis_colors)):
        rows = target == class_
        plt.scatter(data[rows, i], data[rows, j], marker=marker, c=color,
                    label=str(class_))

    # ``if names:`` raises ValueError for multi-element numpy arrays, so
    # check for None / emptiness explicitly.
    if names is not None and len(names) > 0:
        xlabel, ylabel = names[i], names[j]
    else:
        xlabel, ylabel = str(i), str(j)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if classes.size:
        plt.legend()
    plt.autoscale(tight=True)
    plt.show()