LSTM模型

pwd`×续缘` 2024-08-17 16:31:03 阅读 76

目录

前言

1 、传统RNN优缺点

1 传统RNN的优势:

2 传统RNN的缺点:

3 梯度消失或爆炸介绍

 2、LSTM介绍

3、LSTM的内部结构

3.1 LSTM结构分析

3.2 使用Pytorch构建LSTM模型

3.3 LSTM优缺点

4、小节 



前言

🏷️在介绍LSTM模型之前,我们再次见一下CNN是什么?RNN主要用于序列处理,比如机器翻译,这种输入输出序列之间具有高度的相关性,RNN可以model这种关系,总结一下,按照输入输出的类型,RNN可以做以下几个事情:

one-to-one: CNNone-to-many: Image Captionmany-to-one: MNIST(glimpse输入)字符分类many-to-many: 机器翻译

🏷️接下来我们先简单介绍传统的RNN模型,了解其优缺点


1 、传统RNN优缺点

1 传统RNN的优势:

由于内部结构简单, 对计算资源要求低, 相比之后我们要学习的RNN变体:LSTM和GRU模型参数总量少了很多, 在短序列任务上性能和效果都表现优异.

2 传统RNN的缺点:

传统RNN在解决长序列之间的关联时, 通过实践,证明经典RNN表现很差, 原因是在进行反向传播的时候, 过长的序列导致梯度的计算异常, 发生梯度消失或爆炸.

3 梯度消失或爆炸介绍

根据反向传播算法和链式法则, 梯度的计算可以简化为以下公式

Dn=σ′(z1)w1⋅σ′(z2)w2⋅⋯⋅σ′(zn)wn𝐷𝑛=𝜎′(𝑧1)𝑤1⋅𝜎′(𝑧2)𝑤2⋅⋯⋅𝜎′(𝑧𝑛)𝑤𝑛

其中sigmoid的导数值域是固定的, 在[0, 0.25]之间, 而一旦公式中的w也小于1, 那么通过这样的公式连乘后, 最终的梯度就会变得非常非常小, 这种现象称作梯度消失. 反之, 如果我们人为的增大w的值, 使其大于1, 那么连乘够就可能造成梯度过大, 称作梯度爆炸.

梯度消失或爆炸的危害:

如果在训练过程中发生了梯度消失,权重无法被更新(梯度消失概念以及权重的跟更新的知识在机器学习中已经涉及),最终导致训练失败; 梯度爆炸所带来的梯度过大,大幅度更新网络参数,在极端情况下,结果会溢出(NaN值).

 2、LSTM介绍

LSTM(Long Short-Term Memory)也称长短时记忆结构, 它是传统RNN的变体, 与经典RNN相比能够有效捕捉长序列之间的语义关联, 缓解梯度消失或爆炸现象. 同时LSTM的结构更复杂, 它的核心结构可以分为四个部分去解析:

遗忘门输入门细胞状态输出门

3、LSTM的内部结构

3.1 LSTM结构分析

结构解释图:

黄色方块:表示一个神经网络层(Neural Network Layer);

粉色圆圈:表示按位操作或逐点操作(pointwise operation),例如向量加和、向量乘积等;

单箭头:表示信号传递(向量传递);

合流箭头:表示两个信号的连接(向量拼接);

分流箭头:表示信号被复制后传递到2个不同的地方

遗忘门部分结构图与计算公式:

遗忘门结构分析:

与传统RNN的内部结构计算非常相似, 首先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接, 得到[x(t), h(t-1)], 然后通过一个全连接层做变换, 最后通过sigmoid函数进行激活得到f(t), 我们可以将f(t)看作是门值, 好比一扇门开合的大小程度, 门值都将作用在通过该扇门的张量, 遗忘门门值将作用的上一层的细胞状态上, 代表遗忘过去的多少信息, 又因为遗忘门门值是由x(t), h(t-1)计算得来的, 因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息.

🏷️这里面的计算公式,包括接下来我们也要介绍的,有很多与RNN的计算公式相似,我们也可以通过RNN的思想去一步一步理解每一个结构的含义以及作用

遗忘门内部结构过程演示:

激活函数sigmiod的作用:

用于帮助调节流经网络的值, sigmoid函数将值压缩在0和1之间.输入门部分结构图与计算公式:

输入门结构分析:

我们看到输入门的计算公式有两个, 第一个就是产生输入门门值的公式, 它和遗忘门公式几乎相同, 区别只是在于它们之后要作用的目标上. 这个公式意味着输入信息有多少需要进行过滤. 输入门的第二个公式是与传统RNN的内部结构计算相同. 对于LSTM来讲, 它得到的是当前的细胞状态, 而不是像经典RNN一样得到的是隐含状态.

输入门内部结构过程演示:

细胞状态更新图与计算公式:

细胞状态更新分析:

细胞更新的结构与计算公式非常容易理解, 这里没有全连接层, 只是将刚刚得到的遗忘门门值与上一个时间步得到的C(t-1)相乘, 再加上输入门门值与当前时间步得到的未更新C(t)相乘的结果. 最终得到更新后的C(t)作为下一个时间步输入的一部分. 整个细胞状态更新过程就是对遗忘门和输入门的应用.

细胞状态更新过程演示:

输出门部分结构图与计算公式:

输出门结构分析:

输出门部分的公式也是两个, 第一个即是计算输出门的门值, 它和遗忘门,输入门计算方式相同. 第二个即是使用这个门值产生隐含状态h(t), 他将作用在更新后的细胞状态C(t)上, 并做tanh激活, 最终得到h(t)作为下一时间步输入的一部分. 整个输出门的过程, 就是为了产生隐含状态h(t).

输出门内部结构过程演示:

3.2 使用Pytorch构建LSTM模型

位置: 在torch.nn工具包之中, 通过torch.nn.LSTM可调用.

nn.LSTM类初始化主要参数解释:

input_size: 输入张量x中特征维度的大小.hidden_size: 隐层张量h中特征维度的大小.num_layers: 隐含层的数量.bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用.

nn.LSTM类实例化对象主要参数解释:

input: 输入张量x.h0: 初始化的隐层张量h.c0: 初始化的细胞状态张量c.

nn.LSTM使用示例:

<code># 定义LSTM的参数含义: (input_size, hidden_size, num_layers)

# 定义输入张量的参数含义: (sequence_length, batch_size, input_size)

# 定义隐藏层初始张量和细胞初始状态张量的参数含义:

# (num_layers * num_directions, batch_size, hidden_size)

>>> import torch.nn as nn

>>> import torch

>>> rnn = nn.LSTM(5, 6, 2)

>>> input = torch.randn(1, 3, 5)

>>> h0 = torch.randn(2, 3, 6)

>>> c0 = torch.randn(2, 3, 6)

>>> output, (hn, cn) = rnn(input, (h0, c0))

>>> output

tensor([[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],

[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],

[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],

grad_fn=<StackBackward>)

>>> hn

tensor([[[ 0.4647, -0.2364, 0.0645, -0.3996, -0.0500, -0.0152],

[ 0.3852, 0.0704, 0.2103, -0.2524, 0.0243, 0.0477],

[ 0.2571, 0.0608, 0.2322, 0.1815, -0.0513, -0.0291]],

[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],

[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],

[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],

grad_fn=<StackBackward>)

>>> cn

tensor([[[ 0.8083, -0.5500, 0.1009, -0.5806, -0.0668, -0.1161],

[ 0.7438, 0.0957, 0.5509, -0.7725, 0.0824, 0.0626],

[ 0.3131, 0.0920, 0.8359, 0.9187, -0.4826, -0.0717]],

[[ 0.1240, -0.0526, 0.3035, 0.1099, 0.5915, 0.0828],

[ 0.0203, 0.8367, 0.9832, -0.4454, 0.3917, -0.1983],

[-0.2976, 0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],

grad_fn=<StackBackward>)

3.3 LSTM优缺点

LSTM优势:

LSTM的门结构能够有效减缓长序列问题中可能出现的梯度消失或爆炸, 虽然并不能杜绝这种现象, 但在更长的序列问题上表现优于传统RNN.

LSTM缺点:

由于内部结构相对较复杂, 因此训练效率在同等算力下较传统RNN低很多.

4、小节 

LSTM的内部结构可能只通过文字讲述会有些抽象,内部结构相对来说复杂,我们可以通过将其拆分一一分析,我们不难发现他和RNN算法的相同之处,本质都是相同,下节我们介绍复杂度相对来说没有那么复杂的GRU模型 



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。