李宏毅机器学习课程笔记-9.2如何训练RNN

RNN的损失函数

仍然以Slot Filling为例,如下图所示。

img

对于1个word$x^i$,RNN输出1个one-hot编码的vector $y^i$,求$y^i$和对应label的交叉熵损失(Cross Entropy Loss),将多个word的loss求和即为RNN的损失函数。需要注意的是不能打乱word的语序,$x^{i+1}$要紧接着$x^i$输入。

确定RNN的损失函数后,RNN的训练其实也是用的梯度下降。训练前馈神经网络时我们使用有效的反向传播算法,为了方便地训练RNN,我们使用BPTT。基于BP,BPTT(Backpropagation Through Time)考虑了时间维度的信息。

RNN的Error Surface

RNN的Error Surface如下图所示,其中$z$轴代表loss,$x$轴和$y$轴代表两个参数$w_1$和$w_2$。可以看出,RNN的Error Surface在某些地方非常平坦,在某些地方又非常的陡峭。这样的Error Surface导致在训练RNN时loss剧烈变化

img

问题出现的原因

既然RNN的Error Surface中有这么平滑的地方,那会不会是sigmoid激活函数造成的梯度消失呢?原因并不是sigmoid,如果是的话,那换成ReLU就可以,但把sigmoid换成ReLU之后,效果反而更差了。那到底为什么会有非常陡峭和非常平滑的地方呢?

img

如上图所示,假设某RNN只含1个神经元,并且该神经元是Linear的,input和output的weight都是1,没有bias,memory传递的weight是$w$,输入序列为[1, 0, 0, 0, …, 0],所以$y^{1000}=w^{999}$。

现在我们考虑loss关于参数$w$的梯度,当$w:\ 1\ =>\ 1.01$时,可知$y^{1000}:\ 1\ =>\ 20000$,此时梯度很大;当$w:\ 0.99\ =>\ 0.01$时,可知$y^{1000}$几乎没有变化,此时梯度很小。

从该例中可知,RNN的Error Surface中的“悬崖”出现的原因是,关于memory的参数$w$的作用随着时间增加不断增强,导致RNN出现梯度消失或梯度爆炸的问题

处理方法

如何解决RNN梯度消失或梯度爆炸的问题?可以通过Clipping进行处理,Clipping的效果是使梯度不超过某个阈值,即当梯度即将超过某个阈值(比如15)时,就将梯度赋值为该阈值。

LSTM

有什么更好的方法可以解决RNN的Error Surface中的问题呢?LSTM就是使用最广泛的技巧,它可以“删除”Error Surface中比较平坦的部分,也就解决了梯度消失的问题,但它无法解决梯度爆炸的问题。正因如此,训练LSTM时需要将学习率调得特别小。

LSTM为什么可以解决RNN中梯度消失的问题呢,因为RNN和LSTM对memory的处理是不同的(LSTM有forget gate)。在RNN中,每个时间点memory中的旧值都会被新值覆盖,导致参数$w$对memory的影响每次都被清除,进而引发梯度消失。在LSTM中,每个时间点memory里的旧值会乘以$f(g(f))$再与新值相加,只有在forget gate被关闭时参数$w$对memory的影响才会被清除,在forget gate被打开时参数$w$对memory的影响就会通过累加得到保留,因此不会出现梯度消失的问题。

LSTM在1997年被提出,第1版的LSTM被提出就是为了解决梯度消失的问题,但这1版本是没有forget gate的,forget gate是后来才加上去的。也有1种说法是,在训练LSTM时需要给forget gate特别大的bias,以确保forget gate在多数情况下是开启的。

GRU(Gated Recurrent Unit, Cho, EMNLP’14)

GRU比LSTM更简单,GRU只有2个gate,因此需要更少的参数量、鲁棒性更好、更不容易过拟合。GRU的基本思路是“旧的不去,新的不来”,GRU把input和forget gate联动起来,当forget gate把memory中的值清空时,input gate才会打开然后放入新的值。

Clockwise RNN(Jan Koutnik, JMLR’14)

SCRN(Structrally Constrained Recurrent Network, Tomas Mikolov, ICLR’15)

Vanilla RNN Initialized with Identity Matrix + ReLU(Quoc V.Le, arXiv’15)


Github(github.com):@chouxianyu

Github Pages(github.io):@臭咸鱼

知乎(zhihu.com):@臭咸鱼

博客园(cnblogs.com):@臭咸鱼

B站(bilibili.com):@绝版臭咸鱼

微信公众号:@臭咸鱼

转载请注明出处,欢迎讨论和交流!