@devilloser
2017-06-01T02:54:39.000000Z
字数 2005
阅读 882
tensorflow
关于官方tensorflow的补充,出现的问题
第一步在当前文件夹下创建MNIST_data文件夹,下载数据集的四个文件
(为了不修改代码偷懒)
第二步下载input_data.py文件
input_data.py文件如下:
"""Functions for downloading and reading MNIST data."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport gzipimport osimport tempfileimport numpyfrom six.moves import urllibfrom six.moves import xrange # pylint: disable=redefined-builtinimport tensorflow as tffrom tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
第三步
import input_datamnist=input_data.read_data_sets('MNIST_data/',one_hot=True)
InteractiveSession
sess = tf.InteractiveSession()a = tf.constant(5.)b = tf.constant(6.)c = a * bprint(c.eval())sess.close()
Session
sess = tf.Session()a = tf.constant(5.)b = tf.constant(6.)c = a * bprint(c.eval(session=sess))sess.close()
with语句
with tf.Session() as sess:a = tf.constant(5.)b = tf.constant(6.)c = a * bprint(c.eval())sess.close()
argmax(f(t))取f(t)最大时的t值
tensor.eval等同于session.run(tensor)
tf.cast将布尔值列表例如 [True,False, True, True]。可以使用tf.cast()函数将其转换为[1,0,1,1],以方便准确率的计算(以上的是准确率为0.75)。
(1)
tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)
shape: 一个一维整数张量 或 一个Python数组。 这个值决定输出张量的形状。
mean: 一个零维张量或 类型属于dtype的Python值. 这个值决定正态分布片段的平均值
stddev: 一个零维张量或 类型属于dtype的Python值. 这个值决定正态分布片段的标准差。
dtype: 输出的类型.
seed: 一个Python整数. 被用来为正态分布创建一个随机种子. 详情可见set_random_seed for behavior.
name: 操作的名字 (可选参数).
(2)
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
input:[batch,in_height,in_weight,in_channels]
filter:[filter_height,filter_width,in_channels,out_channels]
strides:strides[0]=strides[3]=1
strides[1]垂直步长
strides[2]水平步长
padding:valid不padding SAME 0padding
(3)
tf.nn.max_pool(value, ksize, strides, padding, data_format='NHWC', name=None)
value : 一个形状为 [batch, height, width, channels] 且 类型为 tf.float32 的四维张量.
ksize : 一个长度大于等于 4 的整数列表. input张量的每个维度的窗格大小 .
strides : 一个长度大于等于 4 的整数列表. input的每一个维度的滑动窗格的步幅.
padding : 一个为'VALID' 或 'SAME' 的字符串. 填充算法. 详情可见注解
data_format : 字符串. 目前支持 'NHWC' 和 'NCHW'.
name : 可选参数,操作的名称.