[关闭]
@NumberFairy 2017-12-07T13:13:38.000000Z 字数 1798 阅读 1079

神经网络之优化算法

神经网络


mini-batch梯度下降

当我们的training set比较小的时候,我们经常利用向量化的方法一次性把整个trainjing set都投入到训练中,当然,当数据量小的时候当然可以这么用;但是如果数据量比较大的时候呢?(这个数据量大一般指大于2000)我们如果还是采用一次性把所有的training set投入到训练过程中的话会造成计算量特别大,训练速度急剧变慢。这个时候我们可以考虑将training set进行分批处理。加入我们的training set有500000条,我们假设分割的batch-size是1000,也就是说我们将training set分割为了500个batches;对于一次迭代(one epoch),我们是要遍历整个training set的,很明显,一次迭代中会夹杂有500次循环;最后,在迭代每一个batch的时候就和梯度下降没有什么区别了,计算forward propagation,cost,backward propagation, update parameters。(what's more)在处理大批量的数据时一般用的都是mini-batch梯度下降,然后结合下面的一种或两种算法一并使用来提高训练速度和收敛速度。

注意:batch-size的大小一般都是2的n次方,这和计算机的内部存储机制,计算机制有关吧(我猜),一般的batch-size大小为:64, 128, 256, 512, 1024。一般就前四个比较常见。

指数加权平均数

公式:Vt = β * V(t-1) + (1-β)θt
指数加权平均数相比与我们平常用的avg=sum/num,要有运算上的优势,在计算机内部计算二者的值,指数加权平均数计算快,占用的内存少。这里我们之所以选择这种平均数的计算方法就是因为我们在进行神经网络运算的时候涉及到的多是向量之间的相乘,运算量大,且占内存空间,我们用这种就算平均数的方法不用记录每一个中间值,省内存。(我觉得)另外,后面的几个优化算法都用到了指数加权平均数的内容。

Momentum(动量梯度下降)

该方法核心利用指数加权平均数,用平均值来取代第i次迭代的权值(不是很准确,但是的确削弱了权值的波动幅度)。直接给出梯度计算的公式:

V_dw = β1 * V_dw + (1-β1) * dw
V_db = β1 * V_db + (1-β1) * db
W = W - α * V_dw
b = b - α * V_db

采用Momentum算法来训练数据要比不采用该算法效果好,收敛速度快;根据数据模拟出来的decided-boundary也更合理平滑,上述给出了parameter的更新方法,如果研究者比较关心模拟出来的曲线的初始值,我们需要考虑进行偏差修正(因为我们初始化V_dw=0),最终的parameters如下:

V_dw = β1 * V_dw + (1-β1) * dw
V_db = β1 * V_db + (1-β1) * db
V_dw = V_dw / (1-β1^t)  -----------t:第t次迭代
V_db = V_db / (1-β1^t)
W = W - α * V_dw
b = b - α * V_db

RMSprop算法(均方根)

直接上梯度计算公式:

S_dw = β2 * S_dw + (1-β2)dw^2
S_db = β2 * S_db + (1-β2)db^2
W = W - α * (dw/squar(S_dw))  -------squar:数学符号-根号
b = b - α * (db/squar(S_db))

注意:为防止上述计算W和b时分母出现0,我们可以在分母上加上一个 ε 。

Adam算法(Adaptive Moment Estimation)

其实是Momentum和RMSprop算法的结合使用,parameters更新如下:

V_dw = β1 * V_dw + (1-β1) * dw
V_db = β1 * V_db + (1-β1) * db
S_dw = β2 * S_dw + (1-β2)dw^2
S_db = β2 * S_db + (1-β2)db^2

偏差修正:

V_dw = V_dw / (1-β1^t)  
V_db = V_db / (1-β1^t)
S_dw = S_dw / (1-β2^t)
S_db = S_db / (1-β2^t)

最后参数如下:

W = W - α * (V_dw/squar(S_dw) + ε) --------ε可根据情况加上
b = b - α * (V_db/squar(S_db) + ε)

上述出现的β1、β2、ε 在业内有特定的缺省值,一般不会修改:

β1: 0.9
β2: 0.999
ε : 1e-8
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注