Hey, this is why LSTM works!
Introduction
首先谈一谈我对于神经网络结构改造的想法,其实和物理、化学这些理科不一样,很难说有什么支撑性的理论基础来引导模型的设计,深度学习更看重的是实际工程效果怎么样。据我所知对于文本和图像这种数据,人们还并没有那种完全式的掌握和理解,其实也是通过各种模型工具来窥知一二。所以深度学习的理论分析大多都是很泛的,因为针对某类具体问题压根没有对应的完备理论性知识。
深度学习的思维是针对已有模型的问题,通过引入一些归纳偏置(可能来源于数据的特性,可能来源于对模型的理解)来增加结构,去尝试work不work,如果确实有效再看看是不是真的符合预期假设的那样。所以这个过程更类似于engineer,而不是theory analysis。因此就会导致有些有效的网络模型看起来就感觉没有物理学中的美感,更像是各种naive idea和trick叠加的产物,研究的流程就是先跑个好结果然后再编个好故事,老是被其他专业诟病成炼丹。(摊手┑( ̄- ̄)┍)
平时接触最多的是把已有的网络模型魔改去解决具体问题,而设计一个全新的网络结构去解决一类问题看起来显然更有趣也是难度最大的,目前已有的网络大致可以分为MLP、CNN、RNN、Transformer、GNN、GAN、VAE 这几类,这篇blog将会详细讨论RNN中最具有代表性的LSTM网络设计。
不知道其他人有没有这种感觉,当接触除LSTM以外的模型时,学起来都感觉很自然,学LSTM的时候,就像进入了magical world一样(cmu经典配图meme),完全不觉得make sense,以至于我隔了一段时间没看网络结构,再回过头一看还是觉得很怪。每当我有这种感觉,我就知道是时候该和相应部分知识正义切割了,于是在搜集大量资料paper后,我将我的思考呈现在这篇博客中,不仅仅只是理解how lstm works,更是why lstm works。
So What’s wrong with RNN?
RNN是一个图灵完备模型,理论上是可以模拟任意可计算函数的,详细的解释可以看(link),这和MLP是通用函数拟合器很类似,上限都是天花板,但是就训练不出那么好的效果。这主要是由两部分因素引起的:
- 第一是采用BPTT的训练算法对于普通RNN梯度更新会出现问题
- 第二是模型没有特别针对时序任务的特性加结构,并不能很好的捕捉到特征,而且利用已有的梯度对参数更新时,会出现weight conflict,feature capture和memorization冲突的问题。
关于第二点再来解释一下,RNN这个模型本身就是利用不断迭代来模拟时序,也就是把时序不同远近对于输出/参数更新的影响反映到了这种耦合了很多因素的过程中,实际这样没有明确结构设计的策略对于某些具体的任务表现还是很弱的。
这里值得提一下的是,不显示的加结构还想获得足够好的效果,在实践中一般是不太可能的,cmu的教授也是这么说的。毕竟No Free Lunch :)
Vanishing and Exploding Gradients
首先可以给出a mathematical proof of a sufficient condition for vanishing sensitivity in vanilla RNNs: (link),总结起来就是权重矩阵的二范数不能太小,由此可以引出一个weight initialization的技术,也是上面link中提到的,用来缓解一下梯度消失的问题。
对于DNN,梯度消失和爆炸作为反向传播链式法则的产物同样也会出现在RNN中,而RNN的shared parameter结构更加深了梯度更新的弊病,因为会出现weight conflict的现象,也就是序列前后对同一参数的更新作用可能是相反的,比如memory和forget操作。
对于梯度爆炸可以采用clip的方法,而梯度消失的解决方法则要困难得多,比较常用的方法是增加regularization term(batch norm等),但是在RNN实际问题中我们并不希望梯度是一直受regularizer的限制,有时我们希望梯度表现为vanish或者increase,而仅仅利用regularization的效果也确实不是特别好。所以问题就变得tricky起来了。
Information morphing
首先需要思考的是,在RNN中information morphing是被怎么定义的,我认为至少可以从两个角度来看,也就是forward和backward part。
在cmu的课上有对RNN进行stability analysis,同时也提到了RNN的记忆问题,也就是在forward的过程中,由于activation和weight的不断迭代,网络会逐渐遗忘input的信息,最后的输出表现为只和weight、activation相关,而不是和input相关。所以一个优化的思路是:既然如此干脆就直接不加weight和activation直接造一条memory line,然后针对不同的功能进行解耦,增加新的结构。
此外在cmu的课上,professor还强调tanh作为激活函数是可以memorize最久的,也就是Bipolar activation function,所以后面关于"write"信息增量的时候,采用的就是tanh函数。
在backward part中会出现经常听说的weight conflict现象,对不同方向的变化并不能很好分开处理,但在这里我更推荐看这篇论文的例子:(link),比较直观反映由于模型本身的缺陷造成的information morphing。
More detail explanation
上面说的一些概念可能看这篇博客的人会不熟,包括我自己后面大概率也会忘,所以这里来举个例子:
引用 Quokka 大佬的回答,问题来源:(link)
input weight conflict大概意思就是说普通的RNN无法长期保持信息。因为不同时间步里输入到隐层之间的变换矩阵是共享的(相同的),但是不同时间步的输入一般不同,所以对某个隐层神经元 j 来说,它不太可能在很多时间步里都保持激活的状态,所以就没法长期保存信息。
举个例子,假设我们想用RNN来检测引用(其实就是一个序列标注问题:输入一句话,把引号里的部分标记为正类,引号外的部分标记为负类),当遇到左引号时,某个神经元被激活,此后我们希望这个神经元一直能保持激活状态,因为它的激活可以告诉我们此时模型正在处理左引号后面的内容(也就是在引号里面)。但是因为每一个时间步都有输入,所以它的状态就会被这些输入影响,而左引号对它的影响慢慢减弱,相当于它“忘了”自己曾经遇到过一个左引号。
而如果使用LSTM的话,因为有一个 input gate 的存在,在左引号之后的输入都被 input gate 屏蔽掉,直到右引号出现才把 input gate 打开,写入新信息,熄灭这个神经元,从而输出状态就可以由这个神经元的状态直接对应得到(神经元点亮说明当前单词是被引用的话,熄灭说明当前单词不在引文里),实现准确的标注。
Core idea
大部分的教程都在强调从梯度的角度来改进RNN,并把这点当成work的根本原因。很显然这是很naive的想法,正确的思路是发现了RNN出现这种问题,并通过增加information invariant结构时考虑到gradient的传播问题。
Core idea:模型需要做到具有information invariant的能力,同时也能根据输入的变化进行相应的调整,也就是我自己总结的如下两点:
- 
information invariant = invariant + selectivity 
- 
learning ability = valid gradient update 
第一点为了保持信息不变性,采取的策略是write,也就是增量机制:
The fundamental principle: Write it down.
To ensure the integrity of our messages in the real world, we write them down. Writing is an incremental change that can be additive (pen on paper) or subtractive (carving in rock), and which remains unchanged absent outside interference. In LSTMs, everything is written down and, assuming no interference from other state units or external inputs, carries its prior state forward.
Practically speaking, this means that any state changes are incremental, so that
但是在实际的问题中,仅仅是把所有变化以增量的方式记录也是不够的:
The fundamental challenge: Uncontrolled and uncoordinated writing.
Uncontrolled and uncoordinated writes, particularly at the start of training when writes are completely random, create a chaotic state that leads to bad results and from which it can be difficult to recover.
事实证明效果也确实不好,因为我们想要模型做到对于输入和state传递的信息有特征选择和记忆保持的功能,但实际还是会出现“input weight conflict”, “output weight conflict”, the “abuse problem”, and “internal state drift”这几个问题(这几个术语来自lstm原始论文(link)),下面我搬了原论文的解释。
Input weight conflict: For simplicity, let us focus on a single additional input weight . Assume that the total error can be reduced by switching on unit j in response to a certain input and keeping it active for a long time (until it helps to compute a desired output). Provided i is nonzero, since the same incoming weight has to be used for both storing certain inputs and ignoring others, 
 will often receive conflicting weight update signals during this time (recall that j is linear). These signals will attempt to make 
 participate in (1) storing the input (by switching on j) and (2) protecting the input (by preventing j from being switched off by irrelevant later inputs). This conflict makes learning difficult and calls for a more context-sensitive mechanism for controlling write operations through input weights.
Output weight conflict: Assume j is switched on and currently stores some previous input. For simplicity, let us focus on a single additional outgoing weight . The same 
 has to be used for both retrieving j’s content at certain times and preventing j from disturbing k at other times. As long as unit j is nonzero, 
 will attract conflicting weight update signals generated during sequence processing. These signals will attempt to make 
 participate in accessing the information stored in j and—at different times—protecting unit k from being perturbed by j. For instance, with many tasks there are certain short time-lag errors that can be reduced in early training stages. However,at later training stages, j may suddenly start to cause avoidable errors in situations that already seemed under control by attempting to participate in reducing more difficult long-time-lag errors. Again, this conflict makes learning difficult and calls for a more context-sensitive mechanism for controlling read operations through output weights.
Input and output weight conflicts are not specific for long time lags; they occur for short time lags as well. Their effects, however, become particularly pronounced in the long-time-lag case. As the time lag increases, stored information must be protected against perturbation for longer and longer periods, and, especially in advanced stages of learning, more and more already correct outputs also require protection against perturbation
所以采取的策略是把不同的功能解耦,通过加入新的结构来解决这个冲突问题,也就是下文将详细介绍的具有选择性的门机制。
先要获得更多insights,可以看这篇文章:RNN 中学习长期依赖的三种机制 - Quokka的文章 - 知乎 https://zhuanlan.zhihu.com/p/34490114
Building up LSTM
Cambridge按照原始论文的思路整理的很好的PPT:(link)
Math of LSTM
主要介绍Backprop的part,比较好的资料推荐:(link 1),(link 2) 以及cmu 21年的lecture (link)
此外就是为什么LSTM通过门机制能解决梯度消失的问题:(link) and (link)
Explanation
对于模型的表达能力,要从两方面去看:理论上分析和实际的效果。
Theory Explanation
我们聚焦于Gate机制的表达能力,
暂时没找到很好的资料,鸽一会
word -> 信号?
Experiments Explanation
我想强调的是即使上面我们分析并构建LSTM的idea都很合理,那也是先经过实验不断尝试得到的结果,不仅仅靠理论上的intuition,此外还需验证结果确实符合我们的假设。
首先介绍一个toy example,是李宏毅老师的,大概在28分钟的时候:(link)
其次是Harvard nlp组的LSTM可视化工作:(link),中文解释:(link)
以及Inner-process visualization of hidden states in recurrent neural networks (link) or (link)
此外补充关于RNN解释的工作:(link) and (link)
Improvements
介绍一下LSTM变体的改进思路。
https://zhuanlan.zhihu.com/p/34500721
Code of GRU
来自CMU的hw,给出的是GRU的例子,LSTM类似:
值得说明的是,博客介绍反向传播的部分虽然是给出了完整的数学推导,但是实际写代码的时候还是采取了迭代的策略,没有复杂的公式,只有精简的逻辑。
| 1 | import numpy as np | 
再贴一个GRU使用的案例:
| 1 | import numpy as np | 
Reference
cmu 11-785
知乎很多大佬的回答
https://r2rt.com/written-memories-understanding-deriving-and-extending-the-lstm.html


.png)







