[关闭]
@NumberFairy 2017-10-10T08:03:36.000000Z 字数 1944 阅读 942

利用mnist data set训练模型,识别手写图片

总结


1、利用mnist手写数据集训练model;最简单的model如下:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
#获取到的数据:training set, test set, verify set.
import  tensorflow as tf
x = tf.placeholder(tf.float32, [None, 784])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# y对应的计算值
y = tf.nn.softmax(tf.matmul(x,w)+b)
# y_对应的是实际值
y_ = tf.placeholder(tf.float32, [None, 10])
# 创建交叉熵函数
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 运行迭代
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
for _ in range(100):
    #获取training set,每次获取100个samples
    batch_xs, batch_ys = mnist.train.next_batch(100)
    # 喂数据, 训练
    sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys})  
# 上面已经把model训练好了,下面开始测试(这是针对mnist测试集合而言,得到的结果是准确度的均值)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

2、手写一张图片,或者是PS 一张也可以。我这里手写一张数字“3”,然后,我用matlab进行处理成28×28的尺寸,同时还要把图片转化微黑白的。粘贴出matlab代码:

i = imread('C:\Users\Administrator\Desktop\3.jpg')
img = rgb2gray(i)
imwrite(img,'C:\Users\Administrator\Desktop\3.bmp')

其实在这之前,虽好把图片的像素调整到28*28,我直接用Windows上的画图工具调整了一下。到此为止,图片的处理结束。(原始图片的清晰度也会影响到最后的评估结果)

3、读取图片信息,得到图片的二维矩阵;代码如下:

from PIL import Image
import numpy as np
img = Image.open('MNIST_data/3.bmp')
data = img.getdata()
# 得到图片的二维矩阵28×28
data = np.matrix(data)
test_res = tf.argmax(y,1)
real_res = tf.argmax(y_,1)
# data.reshape((1,784)) 把矩阵转成1*784的格式
print('测试值:%s' % (sess.run(test_res, feed_dict={x: data.reshape((1,784)), y_: [[ 0. , 0. , 0. , 0. , 0. , 0.,  0. , 1. , 0. , 0. ]]})))
print('真实值:%s' % (sess.run(real_res, feed_dict={x: mnist.test.images[0].reshape((1,784)), y_: [[ 0. , 0. , 0. , 1. , 0. , 0.,  0. , 1. , 0. , 0. ]]})))

总结:上面的代码从上到下是完整的代码,可直接运行。其实上面的准确度很低,只有90%左右;我手写的“3”竟然被识别成了“5”,无奈。。。这上面主要阐述手写图片的过程,如果要改进准确度的话,也是比较简单的。可以尝试用CNN来构建模型,经测试,CNN构建出来的模型,准确度达到97%左右。

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注