[关闭]
@qidiandasheng 2018-06-27T09:33:28.000000Z 字数 5848 阅读 1325

TensorFlow MNIST 入门

机器学习


代码

训练模型，保存检查点 model2.ckpt，并导出计算图 graph.pb（graph.pb 只包含图结构，不含训练得到的权重；权重保存在检查点中）：

  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
# Import data: read the MNIST data sets from MNIST_data/ with one-hot labels
# (each label is a length-10 vector, matching the shape of y_ below).
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# InteractiveSession makes itself the default session, which is why the code
# below can call accuracy.eval(...) / train_step.run(...) without a session.
sess = tf.InteractiveSession()
# Create the model
# Input images, flattened to 784 values (reshaped to 28x28x1 for the CNN).
x = tf.placeholder(tf.float32, [None, 784])
# Ground-truth labels, one-hot over the 10 digit classes.
y_ = tf.placeholder(tf.float32, [None, 10])
# NOTE(review): W, b and y (a plain softmax-regression model) are not used by
# the CNN that is trained below. W and b are still variables that end up in the
# checkpoint, and the prediction script creates the same two variables first;
# removing them here would presumably shift TF's auto-generated variable names
# and break saver.restore there. Keep creation order in sync between scripts.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
  12. def weight_variable(shape):
  13. initial = tf.truncated_normal(shape, stddev=0.1)
  14. return tf.Variable(initial)
  15. def bias_variable(shape):
  16. initial = tf.constant(0.1, shape=shape)
  17. return tf.Variable(initial)
  18. def conv2d(x, W):
  19. return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
  20. def max_pool_2x2(x):
  21. return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
  22. strides=[1, 2, 2, 1], padding='SAME')
  23. W_conv1 = weight_variable([5, 5, 1, 32])
  24. b_conv1 = bias_variable([32])
  25. x_image = tf.reshape(x, [-1,28,28,1])
  26. h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
  27. h_pool1 = max_pool_2x2(h_conv1)
  28. W_conv2 = weight_variable([5, 5, 32, 64])
  29. b_conv2 = bias_variable([64])
  30. h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
  31. h_pool2 = max_pool_2x2(h_conv2)
  32. W_fc1 = weight_variable([7 * 7 * 64, 1024])
  33. b_fc1 = bias_variable([1024])
  34. h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
  35. h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
  36. keep_prob = tf.placeholder(tf.float32)
  37. h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
  38. W_fc2 = weight_variable([1024, 10])
  39. b_fc2 = bias_variable([10])
  40. y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
  41. # Define loss and optimizer
  42. cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
  43. train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
  44. correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
  45. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  46. """
  47. Train the model and save the model to disk as a model2.ckpt file
  48. file is stored in the same directory as this python script is started
  49. Based on the documentatoin at
  50. https://www.tensorflow.org/versions/master/how_tos/variables/index.html
  51. """
  52. saver = tf.train.Saver()
  53. sess.run(tf.global_variables_initializer())
  54. #with tf.Session() as sess:
  55. #sess.run(init_op)
  56. for i in range(20000):
  57. batch = mnist.train.next_batch(50)
  58. if i%100 == 0:
  59. train_accuracy = accuracy.eval(feed_dict={
  60. x:batch[0], y_: batch[1], keep_prob: 1.0})
  61. print("step %d, training accuracy %g"%(i, train_accuracy))
  62. train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
  63. save_path = saver.save(sess, "./model2.ckpt")
  64. tf.train.write_graph(sess.graph_def, '', 'graph.pb')
  65. print ("Model saved in file: ", save_path)
  66. print("test accuracy %g"%accuracy.eval(feed_dict={
  67. x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

识别图片（从 model2.ckpt 恢复模型，预测命令行参数所给图片中的数字）：

  1. import sys
  2. import tensorflow as tf
  3. from PIL import Image, ImageFilter
  4. def predictint(imvalue):
  5. """
  6. This function returns the predicted integer.
  7. The imput is the pixel values from the imageprepare() function.
  8. """
  9. # Define the model (same as when creating the model file)
  10. x = tf.placeholder(tf.float32, [None, 784])
  11. W = tf.Variable(tf.zeros([784, 10]))
  12. b = tf.Variable(tf.zeros([10]))
  13. def weight_variable(shape):
  14. initial = tf.truncated_normal(shape, stddev=0.1)
  15. return tf.Variable(initial)
  16. def bias_variable(shape):
  17. initial = tf.constant(0.1, shape=shape)
  18. return tf.Variable(initial)
  19. def conv2d(x, W):
  20. return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
  21. def max_pool_2x2(x):
  22. return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
  23. W_conv1 = weight_variable([5, 5, 1, 32])
  24. b_conv1 = bias_variable([32])
  25. x_image = tf.reshape(x, [-1,28,28,1])
  26. h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
  27. h_pool1 = max_pool_2x2(h_conv1)
  28. W_conv2 = weight_variable([5, 5, 32, 64])
  29. b_conv2 = bias_variable([64])
  30. h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
  31. h_pool2 = max_pool_2x2(h_conv2)
  32. W_fc1 = weight_variable([7 * 7 * 64, 1024])
  33. b_fc1 = bias_variable([1024])
  34. h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
  35. h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
  36. keep_prob = tf.placeholder(tf.float32)
  37. h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
  38. W_fc2 = weight_variable([1024, 10])
  39. b_fc2 = bias_variable([10])
  40. y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
  41. init_op = tf.global_variables_initializer()
  42. saver = tf.train.Saver()
  43. """
  44. Load the model2.ckpt file
  45. file is stored in the same directory as this python script is started
  46. Use the model to predict the integer. Integer is returend as list.
  47. Based on the documentatoin at
  48. https://www.tensorflow.org/versions/master/how_tos/variables/index.html
  49. """
  50. with tf.Session() as sess:
  51. sess.run(init_op)
  52. saver.restore(sess, "model2.ckpt")
  53. #print ("Model restored.")
  54. prediction=tf.argmax(y_conv,1)
  55. return prediction.eval(feed_dict={x: [imvalue],keep_prob: 1.0}, session=sess)
  56. def imageprepare(argv):
  57. """
  58. This function returns the pixel values.
  59. The imput is a png file location.
  60. """
  61. im = Image.open(argv).convert('L')
  62. width = float(im.size[0])
  63. height = float(im.size[1])
  64. newImage = Image.new('L', (28, 28), (255)) #creates white canvas of 28x28 pixels
  65. if width > height: #check which dimension is bigger
  66. #Width is bigger. Width becomes 20 pixels.
  67. nheight = int(round((20.0/width*height),0)) #resize height according to ratio width
  68. if (nheigth == 0): #rare case but minimum is 1 pixel
  69. nheigth = 1
  70. # resize and sharpen
  71. img = im.resize((20,nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN)
  72. wtop = int(round(((28 - nheight)/2),0)) #caculate horizontal pozition
  73. newImage.paste(img, (4, wtop)) #paste resized image on white canvas
  74. else:
  75. #Height is bigger. Heigth becomes 20 pixels.
  76. nwidth = int(round((20.0/height*width),0)) #resize width according to ratio height
  77. if (nwidth == 0): #rare case but minimum is 1 pixel
  78. nwidth = 1
  79. # resize and sharpen
  80. img = im.resize((nwidth,20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN)
  81. wleft = int(round(((28 - nwidth)/2),0)) #caculate vertical pozition
  82. newImage.paste(img, (wleft, 4)) #paste resized image on white canvas
  83. #newImage.save("sample.png")
  84. tv = list(newImage.getdata()) #get pixel values
  85. #normalize pixels to 0 and 1. 0 is pure white, 1 is pure black.
  86. tva = [ (255-x)*1.0/255.0 for x in tv]
  87. return tva
  88. #print(tva)
  89. def main(argv):
  90. """
  91. Main function.
  92. """
  93. imvalue = imageprepare(argv)
  94. predint = predictint(imvalue)
  95. print (predint[0]) #first value in list
  96. if __name__ == "__main__":
  97. main(sys.argv[1])

参考

mnist 源码
mnist 源码解析
深度学习 - Tensorflow on iOS 入门 + MNIST
机器学习Tensorflow笔记1:Hello World到MNIST实验
机器学习Tensorflow笔记2:超详细剖析MNIST实验
MNIST机器学习入门
写给初学者的深度学习教程之 MNIST 数字识别
tensorflow-mnist-predict

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注