[关闭]
@Pigmon 2017-05-28T00:30:27.000000Z 字数 6166 阅读 1126

MNIST 玩具

教案


说明

简单的很,只是做了个界面可以用鼠标画数字,给之前基于MNIST数据集训练好的模型来Predict,然后显示出识别的结果。
这东西毫无实用价值,看看就好。

MNIST.PNG-8.9kB

我做的比较简单,输入的图像的笔画是单像素单一颜色的,所以识别率并不高。下面链接里自带的模型,包括了我在这样的界面上输入的120个训练样本,但效果并不明显(输入太累手,不想再弄了。)

用这样的界面生成MNIST格式训练样本的程序之后我会整理下发出,因为包含MNIST格式说明的部分,所以文档内容会多点,不像这个没啥可说的。

下载链接:
链接:http://pan.baidu.com/s/1pLI2I1d 密码:7wdh

如果你机器上有 Python 3.5, PyQT5, Tensorflow,TFLearn以及它们的依赖项都有安装的话,解压就可以直接运行了。

我猜你们没有,如果只想看看程序的话:

界面程序

  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. MNIST 手写数字识别 Demo
  5. @Yuan Sheng
  6. 这就是一个 PyQT5 的界面程序,
  7. 目的是把你用鼠标画出来的数字给mnist_model.py去识别。
  8. 识别率很低,具体原因没你的事,留给我自己思考。
  9. 本程序受开源软件协议 @DWFYWDEC 保护。
  10. @DWFYWDEC:Do What the Fuck You Want to Do Except Cheating.
  11. FYI:
  12. @DWFYWDEC 协议是本程序作者基于 @DWFYWD 协议胡编乱造的(就在几秒钟之前)。
  13. @DWFYWD 协议是另外一个程序的作者胡编乱造的。
  14. FYI2:
  15. 虽然是胡编乱造的但你一旦违反本协议会遭到本人疯狂的报复。
  16. 我是个疯狂的人以至于我都不知道一旦你违反本协议我会怎么报复你。
  17. 不过这个程序P用都没有所以你不用担心。
  18. """
  19. import sys
  20. from PyQt5 import QtCore, QtGui, QtWidgets
  21. from PyQt5.QtCore import Qt, QRect
  22. from PyQt5.QtWidgets import QWidget, QApplication, QPushButton, QLCDNumber
  23. import numpy as np
  24. from mnist_model import *
  25. class Example(QWidget):
  26. def __init__(self):
  27. super().__init__()
  28. self.InitInputData()
  29. self.initUI()
  30. def initUI(self):
  31. self.setGeometry(300, 300, 640, 480)
  32. self.setWindowTitle(u'MNIST 演示')
  33. self.initInputWidget()
  34. self.show()
  35. def InitInputData(self):
  36. # 鼠标左键是不是按下了
  37. self.mouse_pressed = False
  38. # 目前识别的结果,为了在LCD控件中显示数字
  39. self.predicted_nbr = -1
  40. # 输入窗口中,每 scale*scale 个像素视作传给MNIST模型的一个像素
  41. self.scale = 10
  42. # 28*28的数组:输入窗口显示数据,
  43. # 以及传递给MNIST模型进行Predict
  44. self.pt_arr = np.zeros((28, 28))
  45. # 把输入控件的面积分成 28*28 个rect。
  46. # 每个rect的尺寸为 scale*scale
  47. # 作用是检测手写输入时需要绘制成黑色的部分,
  48. # 以及推算逻辑上的28*28的输入图像哪些像素是黑色的
  49. self.rect_arr = np.array([QRect(0, 0, self.scale, self.scale)] * 28 * 28).reshape((28, 28))
  50. for i in range(28):
  51. for j in range(28):
  52. rect = QRect(i * self.scale, j * self.scale, self.scale, self.scale)
  53. self.rect_arr[i][j] = rect
  54. # 读取模型
  55. self.model = MnistModel('models/mnist5/mnist5.tfl')
  56. def initInputWidget(self):
  57. self.widget = QtWidgets.QWidget(self)
  58. self.widget.setGeometry(QtCore.QRect(10, 10, 280, 280))
  59. # -> 下面这一大段就是为了让手写输入的widget有个白色的背景而已
  60. palette = QtGui.QPalette()
  61. brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
  62. brush.setStyle(QtCore.Qt.SolidPattern)
  63. palette.setBrush(QtGui.QPalette.Active, QtGui.QPalette.Base, brush)
  64. brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
  65. brush.setStyle(QtCore.Qt.SolidPattern)
  66. palette.setBrush(QtGui.QPalette.Active, QtGui.QPalette.Window, brush)
  67. brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
  68. brush.setStyle(QtCore.Qt.SolidPattern)
  69. palette.setBrush(QtGui.QPalette.Inactive, QtGui.QPalette.Base, brush)
  70. brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
  71. brush.setStyle(QtCore.Qt.SolidPattern)
  72. palette.setBrush(QtGui.QPalette.Inactive, QtGui.QPalette.Window, brush)
  73. brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
  74. brush.setStyle(QtCore.Qt.SolidPattern)
  75. palette.setBrush(QtGui.QPalette.Disabled, QtGui.QPalette.Base, brush)
  76. brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
  77. brush.setStyle(QtCore.Qt.SolidPattern)
  78. palette.setBrush(QtGui.QPalette.Disabled, QtGui.QPalette.Window, brush)
  79. self.widget.setPalette(palette)
  80. self.widget.setAutoFillBackground(True)
  81. # <- 到这为止
  82. self.widget.setObjectName("input")
  83. self.widget.paintEvent = self.paintEvent
  84. self.widget.mousePressEvent = self.inputMousePressed
  85. self.widget.mouseReleaseEvent = self.inputMouseReleased
  86. self.widget.mouseMoveEvent = self.inputMouseMove
  87. # LCD Number
  88. self.lcd = QLCDNumber(self)
  89. self.lcd.setGeometry(QRect(300, 10, 240, 100))
  90. self.lcd.display("")
  91. # 按钮
  92. self.btn1 = QPushButton(u"识别", self)
  93. self.btn1.move(10, 300)
  94. self.btn2 = QPushButton(u"清除", self)
  95. self.btn2.move(120, 300)
  96. # 按钮事件
  97. self.btn1.clicked.connect(self.btnPredictClicked)
  98. self.btn2.clicked.connect(self.btnClearClicked)
  99. def paintEvent(self, event):
  100. "OnPaint回调,把需要画成黑色的rect画黑"
  101. paint=QtGui.QPainter()
  102. paint.begin(self.widget)
  103. paint.setPen(QtCore.Qt.black)
  104. for i in range(28):
  105. for j in range(28):
  106. if self.pt_arr[i][j] > 0.5:
  107. paint.fillRect(self.rect_arr[i][j], Qt.black)
  108. paint.end()
  109. def keyPressEvent(self, e):
  110. if e.key() == Qt.Key_Escape:
  111. self.close()
  112. def inputMouseReleased(self, e):
  113. "鼠标左键是否弹起"
  114. if e.button() == Qt.LeftButton:
  115. self.mouse_pressed = False
  116. self.widget.repaint()
  117. def inputMousePressed(self, e):
  118. "鼠标左键是否按下"
  119. if e.button() == Qt.LeftButton:
  120. self.mouse_pressed = True
  121. else:
  122. self.mouse_pressed = False
  123. self.widget.repaint()
  124. def inputMouseMove(self, e):
  125. """
  126. 手写输入时,按下鼠标左键后的事件响应。
  127. 即:如果鼠标按下了,那么鼠标移动的过程中,
  128. 检测rect_arr中哪些rect需要被画成黑色
  129. """
  130. if self.mouse_pressed:
  131. for i in range(28):
  132. for j in range(28):
  133. rect = self.rect_arr[i][j]
  134. if rect.contains(e.pos()):
  135. self.pt_arr[i][j] = 1.0
  136. self.widget.repaint()
  137. break
  138. def btnPredictClicked(self):
  139. "识别按钮事件响应函数"
  140. arr = self.pt_arr.transpose().reshape((1, 28, 28, 1))
  141. result = self.model.predict(arr)[0]
  142. self.predicted_nbr = result.index(max(result))
  143. self.lcd.display(self.predicted_nbr)
  144. self.lcd.repaint()
  145. def btnClearClicked(self):
  146. "清除按钮事件响应函数"
  147. self.pt_arr = np.zeros((28, 28))
  148. self.widget.repaint()
  149. self.predicted_nbr = -1
  150. self.lcd.display("")
  151. if __name__ == '__main__':
  152. app = QApplication(sys.argv)
  153. ex = Example()
  154. sys.exit(app.exec_())

用训练好的卷积神经网络模型来 Predict 输入图像的程序(作为上面程序的头文件)

  1. """
  2. MNIST Predict
  3. @Yuan Sheng
  4. FYI:
  5. 这个程序就是TFLearn的例子简单的修改。
  6. 因为那个例子里没有声明什么吓人的东西,
  7. 所以我也不知道一旦你拿这个程序做些为非作歹的事情,
  8. TFLearn 和 Y. LeCun 会怎么报复你。
  9. """
  10. """ Convolutional Neural Network for MNIST dataset classification task.
  11. References:
  12. Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based
  13. learning applied to document recognition." Proceedings of the IEEE,
  14. 86(11):2278-2324, November 1998.
  15. Links:
  16. [MNIST Dataset] http://yann.lecun.com/exdb/mnist/
  17. """
  18. from __future__ import division, print_function, absolute_import
  19. import tflearn
  20. from tflearn.layers.core import input_data, dropout, fully_connected
  21. from tflearn.layers.conv import conv_2d, max_pool_2d
  22. from tflearn.layers.normalization import local_response_normalization
  23. from tflearn.layers.estimator import regression
  24. class MnistModel:
  25. def __init__(self, _model_path):
  26. network = input_data(shape=[None, 28, 28, 1], name='input')
  27. network = conv_2d(network, 32, 3, activation='relu', regularizer="L2")
  28. network = max_pool_2d(network, 2)
  29. network = local_response_normalization(network)
  30. network = conv_2d(network, 64, 3, activation='relu', regularizer="L2")
  31. network = max_pool_2d(network, 2)
  32. network = local_response_normalization(network)
  33. network = fully_connected(network, 128, activation='tanh')
  34. network = dropout(network, 0.8)
  35. network = fully_connected(network, 256, activation='tanh')
  36. network = dropout(network, 0.8)
  37. network = fully_connected(network, 10, activation='softmax')
  38. network = regression(network, optimizer='adam', learning_rate=0.01,
  39. loss='categorical_crossentropy', name='target')
  40. self.model = tflearn.DNN(network)
  41. self.model.load(_model_path)
  42. def predict(self, _input_tensor):
  43. return self.model.predict(_input_tensor)
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注