飞道的博客

1014长短期记忆网络(LSTM)

364人阅读  评论(0)

长短期记忆网络(LSTM)

  • 长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题,解决这个问题最早的方法之一就是 LSTM
  • 发明于90年代
  • 使用的效果和 GRU 相差不大,但是使用的东西更加复杂

  • 长短期记忆网络的设计灵感来自于计算机的逻辑门
  • 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)(有些文献认为记忆元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息)
  • 长短期记忆网络有三个门:忘记门(重置单元的内容,通过专用机制决定什么时候记忆或忽略隐状态中的输入)、输入门(决定何时将数据读入单元)、输出门(从单元中输出条目),门的计算和 GRU 中相同,但是命名不同
  • 忘记门(forget gate):将值朝 0 减少
  • 输入门(input gate):决定是否忽略掉输入数据
  • 输出门(output gate):决定是否使用隐状态

  • 类似于门控循环单元,当前时间步的输入前一个时间步的隐状态作为数据送入长短期记忆网络的门中,由三个具有 sigmoid 激活函数的全连接层处理,以计算输入门、遗忘门和输出门的值(这三个门的值都在 0~1 的范围内)

候选记忆单元(candidate memory cell)

  • 候选记忆元的计算与输入门、遗忘门、输出门的计算类似,但是使用了 tanh 函数作为激活函数,函数的值在 -1~1 之间

记忆单元

  • 在长短期记忆网络中,通过输入门遗忘门来控制输入和遗忘(或跳过):输入门 It 控制采用多少来自 Ct tilde 的新数据,而遗忘门 Ft 控制保留多少过去的记忆元 C(t-1) 的内容
  • 如果遗忘门始终为 1 且输入门始终为 0 ,则过去的记忆元 C(t-1) 将随时间被保存并传递到当前时间步(引入这种设计是为了缓解梯度消失的问题,并更好地捕获序列中的长距离依赖关系)
  • 上一时刻的记忆单元会作为状态输入到模型中
  • LSTM 和 RNN/GRU 的不同之处在于: LSTM 中的状态有两个, C 和 H

隐状态

  • 在长短期记忆网络中,隐状态 Ht 仅仅是记忆元 Ct 的 tanh 的门控版本,因此确保了 Ht 的值始终在 -1~1 之间
  • tanh 的作用:将 Ct 的值限制在 -1 和 1 之间
  • Ot 控制是否输出, Ot 接近 1 ,则能有效地将所有记忆信息传递给预测部分; Ot 接近 0 ,表示丢弃当前的 Xt 和过去所有的信息,只保留记忆元内的所有信息,而不需要更新隐状态

总结

1 LSTM 和 GRU 所想要实现的效果是差不多的,但是结构更加复杂

  • C :一个数值可能比较大的辅助记忆单元
  • C 中包含两项: 当前的 Xt过去的状态(在 GRU 中只能二选一,这里可以实现两个都选)

2 长短期记忆网络包含三种类型的门:输入门遗忘门输出门

3 长短期记忆网络的隐藏层输出包括“隐状态”“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息

4 长短期记忆网络可以缓解梯度消失和梯度爆炸

5 长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。多年来已经提出了其他许多变体,例如,多层、残差连接、不同类型的正则化。但是由于序列的长距离依赖性,训练长短期记忆网络和其他序列模型(如门控循环单元)的成本较高


代码:


  
  1. import torch
  2. from torch import nn
  3. from d2l import torch as d2l
  4. batch_size, num_steps = 32, 35
  5. train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
  6. def get_lstm_params( vocab_size, num_hiddens, device):
  7. num_inputs = num_outputs = vocab_size
  8. def normal( shape):
  9. return torch.randn(size=shape, device=device) * 0.01
  10. def three():
  11. return (normal(
  12. (num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),
  13. torch.zeros(num_hiddens, device=device))
  14. W_xi, W_hi, b_i = three()
  15. W_xf, W_hf, b_f = three()
  16. W_xo, W_ho, b_o = three()
  17. W_xc, W_hc, b_c = three()
  18. W_hq = normal((num_hiddens, num_outputs))
  19. b_q = torch.zeros(num_outputs, device=device)
  20. params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]
  21. for param in params:
  22. param.requires_grad_( True)
  23. return params
  24. # 初始化函数
  25. def init_lstm_state( batch_size, num_hiddens, device):
  26. return (torch.zeros((batch_size, num_hiddens), device=device),
  27. torch.zeros((batch_size, num_hiddens),device=device))
  28. # 实际模型
  29. def lstm( inputs, state, params):
  30. [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params
  31. (H, C) = state
  32. outputs = []
  33. for X in inputs:
  34. I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
  35. F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
  36. O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
  37. C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
  38. C = F * C + I * C_tilda
  39. H = O * torch.tanh(C)
  40. Y = (H @ W_hq) + b_q
  41. outputs.append(Y)
  42. return torch.cat(outputs, dim= 0), (H, C)
  43. # 训练
  44. vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
  45. num_epochs, lr = 500, 1
  46. model = d2l.RNNModelScratch( len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
  47. d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)


  
  1. # 简洁实现
  2. num_inputs = vocab_size
  3. lstm_layer = nn.LSTM(num_inputs, num_hiddens)
  4. model = d2l.RNNModel(lstm_layer, len(vocab))
  5. mode = model.to(device)
  6. d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 


  
  1. # 简洁实现
  2. num_inputs = vocab_size
  3. lstm_layer = nn.GRU(num_inputs, num_hiddens)
  4. model = d2l.RNNModel(lstm_layer, len(vocab))
  5. mode = model.to(device)
  6. d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 


  
  1. # 简洁实现
  2. num_inputs = vocab_size
  3. lstm_layer = nn.RNN(num_inputs, num_hiddens)
  4. model = d2l.RNNModel(lstm_layer, len(vocab))
  5. mode = model.to(device)
  6. d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 


转载:https://blog.csdn.net/m0_54028213/article/details/127319407
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场