[关闭]
@cleardusk 2015-11-27T11:38:23.000000Z 字数 3513 阅读 1756

DL 在线书籍源码阅读(一)

GjzCVCode


IO 部分

黑盒测试

结合 doc strings,进行一些简单的测试。

  1. training_data, validation_data, test_data = mnist_loader.load_data_wrapper()

测试代码

  1. In [15]: training_data[0][0][783]
  2. Out[15]: array([ 0.], dtype=float32)
  3. In [16]: training_data[0][0][784]
  4. ---------------------------------------------------------------------------
  5. IndexError Traceback (most recent call last)
  6. <ipython-input-16-dd65c757aa33> in <module>()
  7. ----> 1 training_data[0][0][784]
  8. IndexError: index 784 is out of bounds for axis 0 with size 784

结论
training_data: [(x,y),(x,y),(x,y),(x,y),(x,y)...(x,y)]
training_data 是一个长度为 50000 的 list,element 是 tuple (x, y),x 是 28*28=784 dimensions 的 image,在 Python 中是一维的 like-array,y 是 10 dimensions,比如 [0,1,0,0,0,0,0,0,0,0],代表 1,可依此类推。

validation_data, test_data 与 training_data 结构相同,除了数量不同,三者分别是 50000, 10000, 10000。强调一点,训练的时候应该拿 validation_data 做测试,不能拿 test_data,用 test_data 容易 overfit,结果不靠谱。

代码部分

load_wrapper() 调用了 load_data(),结合 doc strings 对 load_data()进行简单测试。

  1. tr_d, va_d, te_d = mnist_loader.load_data()
  2. In [26]: tr_d[0][49999][783]
  3. Out[26]: 0.0
  4. In [27]: tr_d[1]
  5. Out[27]: array([5, 0, 4, ..., 8, 4, 8])
  6. In [28]: tr_d[1][0]
  7. Out[28]: 5

与 load_wrapper() 不同的是,image 和数字 label 没有组成一个 tuple 放在 list 内,此外,image, digit label 不是通过 np.array 类型来存储的。其原因待会分析。

load_data()

  1. def load_data():
  2. f = gzip.open('../data/mnist.pkl.gz', 'rb')
  3. training_data, validation_data, test_data = cPickle.load(f)
  4. f.close()
  5. return (training_data, validation_data, test_data)

用了 cPickle, gzip 两个标准库。

  1. In [34]: import cPickle
  2. In [35]: help(cPickle)
  3. NAME
  4. cPickle - C implementation and optimization of the Python pickle module.
  5. FUNCTIONS
  6. Pickler(...)
  7. Pickler(file, protocol=0) -- Create a pickler.
  8. Unpickler(...)
  9. Unpickler(file) -- Create an unpickler.
  10. dump(...)
  11. dump(obj, file, protocol=0) -- Write an object in pickle format to the given file.
  12. See the Pickler docstring for the meaning of optional argument proto.
  13. dumps(...)
  14. dumps(obj, protocol=0) -- Return a string containing an object in pickle format.
  15. See the Pickler docstring for the meaning of optional argument proto.
  16. load(...)
  17. load(file) -- Load a pickle from the given file
  18. loads(...)
  19. loads(string) -- Load a pickle from the given string

我贴了部分 cPickle 的帮助文档,cPickle 其实是读写文件的一个 module,支持文本、二进制 mode,但是是用 c 实现的,c 的读写效率在高级语言中应该是最高的,我曾对比过 c 和 c++ 的文件读写函数效率,c 大概是 c++ 十倍以上,matlab 中读数据速度最快的函数其实就是 c 的函数。Python 自带的 open 相对 c 肯定很慢。

为了压缩数据,这里用了 gzip 格式。

  1. In [36]: import gzip
  2. In [37]: help(gzip)
  3. NAME
  4. gzip - Functions that read and write gzipped files.
  5. FUNCTIONS
  6. open(filename, mode='rb', compresslevel=9)
  7. Shorthand for GzipFile(filename, mode, compresslevel).
  8. The filename argument is required; mode defaults to 'rb'
  9. and compresslevel defaults to 9.

所以这几行代码就很清楚了:用 gzip 库打开 gzip(*.gz) 压缩格式数据,返回 file instance f,再用 cPickle 库 load。

  1. f = gzip.open('../data/mnist.pkl.gz', 'rb')
  2. training_data, validation_data, test_data = cPickle.load(f)
  3. f.close()
  4. return (training_data, validation_data, test_data)

load_data_wrapper()
这个函数的作用也变得清楚了,就是将 cPickle 格式的数据转换成 numpy.array 格式存储,并放到一个 list 里面。Python 的数据结构及语法细节就不啰嗦了。有一点值得注意的是 vectorized_result() 这个函数,代码我没贴出来,作用是将一个 digit,转换为一个拓展后的 array,比如将数字 1,转换为 array([0,1,0,0,0,0,0,0,0,0])

  1. tr_d, va_d, te_d = load_data()
  2. training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
  3. training_results = [vectorized_result(y) for y in tr_d[1]]
  4. training_data = zip(training_inputs, training_results)
  5. validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
  6. validation_data = zip(validation_inputs, va_d[1])
  7. test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
  8. test_data = zip(test_inputs, te_d[1])
  9. return (training_data, validation_data, test_data)

小结

以上的代码最终的目标是将数据转换为适合 NN 读取以及人便于理解的格式,其实就是在效率、抽象性、可读性中找到一个平衡。比如 vectorized_result() 函数,将 digit 转换为 numpy.array 格式,是为了适应 NN 的 output layer。数据 IO 部分必须考虑到 network model 的架构,在这方面,caffe 应该是做了更高级的抽象,caffe 用 protobuf 库,考虑的应该也是它的效率、抽象性、可读性。

还有一点,我分析的只是读取部分,要完全理解掌握 cPickle, gzip,还需要测试 write data 的部分。这个地方我没花时间,因为需要花时间的细节部分实在太多了,但我大概知道 how it works,也知道怎么去理解它如何精确地工作。


添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注