飞道的博客

Transformer变体层出不穷,它们都长什么样?

406人阅读  评论(0)

©PaperWeekly 原创 · 作者|上杉翔二

单位|悠闲会

研究方向|信息检索


不知不觉 Transformer 已经逐步渗透到了各个领域,就其本身也产生了相当多的变体,如上图。本篇文章想大致按照这个图,选一些比较精彩的变体整理,话不多说直接开始。

Transformer-XL

论文标题:

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

收录会议:

ACL 2019

论文链接:

https://arxiv.org/abs/1901.02860

代码链接:

https://github.com/kimiyoung/transformer-xl

上图上标的是“Recurrence”,首先看看这篇文章聚焦的 2 个问题:

  • 虽然 Transformer 可以学习到输入文本的长距离依赖关系和全局特性,但是!需要事先设定输入长度,这导致了其对于长程关系的捕捉有了一定限制。

  • 出于效率的考虑,需要对输入的整个文档进行分割(固定的),那么每个序列的计算相互独立,所以只能够学习到同个序列内的语义联系,整体上看,这将会导致文档语意上下文的碎片化(context fragmentation)。

那么如何学习更长语义联系?

segment-level Recurrence

segment-level 循环机制。如上图左边为原始 Transformer,右边为 Transformer-XL,Transformer-XL 模型的计算当中加入绿色连线,使得当层的输入取决于本序列和上一个序列前一层的输出。这样每个序列计算后的隐状态会参与到下一个序列的计算当中,使得模型能够学习到跨序列的语义联系(看动图可能更好理解)。

是第 个 segment 的第 n 层隐向量,那么第 r+1 个的第 n 层的隐向量的计算,就是上面这套公式。

  • 其中 SG 是是 stop-gradient,不再对 的隐向量做反向传播(这样虽然在计算中运用了前一个序列的计算结果,但是在反向传播中并不对其进行梯度的更新,毕竟前一个梯度肯定不受影响)。

  • 是对两个隐向量序列沿长度 L 方向的拼接 。3 个 W 分别对应 query,key 和 value 的转化矩阵,需要注意的是!k 和 v 的 W 用的是 ,而 q 是用的 ,即 kv 是用的拼接之后的 h,而 q 用的是原始序列的信息。感觉可以理解为以原始序列查拼接序列,这样可以得到一些前一个序列的部分信息以实现跨语义。

  • 最后的公式是标准的 Transformer。

还有一点设计是,在评估预测模型的时候它是会连续计算前 L 个长度的隐向量的(训练的时候只有前一个,缓存在内存中)。

即每一个位置的隐向量,除了自己的位置,都跟下一层中前(L-1)个位置的 token 存在依赖关系,而且每往下走一层,依赖关系长度会增加(L-1),这样能使跨语义更加的深入。


只看看 XL 多头注意力的 forward 的不同地方吧。


   
  1. def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
  2.              #w是上一层的输出,r是相对位置嵌入(在下一节),r_w_bias是u,r_r_bias是v向量
  3.             qlen, rlen, bsz = w.size( 0), r.size( 0), w.size( 1)
  4.              if mems is not None: #mems就是前一些序列的向量,不为空
  5.                 cat = torch.cat([mems, w],  0) #就拼起来
  6.                  if self.pre_lnorm: #如果有正则化
  7.                     w_heads = self.qkv_net(self.layer_norm(cat)) #这个net是nn.Linear,即qkv的变换矩阵W参数
  8.                  else:
  9.                     w_heads = self.qkv_net(cat)#没有正则就直接投影一下
  10.                 r_head_k = self.r_net(r)#也是nn.Linear
  11.                 w_head_q, w_head_k, w_head_v = torch.chunk(w_heads,  3, dim= -1) #复制 3
  12.                 w_head_q = w_head_q[-qlen:] #q的W不要拼接的mems
  13.              else:#没有mems,就正常的计算
  14.                  if self.pre_lnorm:
  15.                     w_heads = self.qkv_net(self.layer_norm(w))
  16.                  else:
  17.                     w_heads = self.qkv_net(w)
  18.                 r_head_k = self.r_net(r)
  19.                 w_head_q, w_head_k, w_head_v = torch.chunk(w_heads,  3, dim= -1)
  20.             klen = w_head_k.size( 0)
  21.             #qlen是序列长度,bsz是batch size,n_head是注意力头数,d_head是每个头的隐层维度
  22.             w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
  23.             w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
  24.             w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
  25.             r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # qlen x n_head x d_head
  26.             ####计算注意力的四个部分
  27.             #AC是指相对位置的公式里的a和c两个部分,相对位置在下一节做笔记
  28.             rw_head_q = w_head_q + r_w_bias                                         # qlen x bsz x n_head x d_head
  29.             #爱因斯坦简记法求和sum,统一的方式表示各种各样的张量运算
  30.             AC = torch.einsum( 'ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
  31.             #BD是指相对位置的公式里的b和d两个部分
  32.             rr_head_q = w_head_q + r_r_bias
  33.             BD = torch.einsum( 'ibnd,jnd->ijbn', (rr_head_q, r_head_k))              # qlen x klen x bsz x n_head
  34.             BD = self._rel_shift(BD)
  35.             # [qlen x klen x bsz x n_head]
  36.             attn_score = AC + BD #最后的结果
  37.             attn_score.mul_(self.scale)#进行放缩

Relative Position Encodings

相对位置编码。原始 Transformer 采用了正弦/余弦函数来编码绝对位置信息。然而因为 Transformer-XL 会有多个句子,所以还是绝对位置,那么两个句子的相同位置是同样的编码。

比如 [0, 1, 2, 3] 在两个句子 concat 之后就变成了 [0, 1, 2, 3, 0, 1, 2, 3],句子不连续,而且每次拼的句子会不一样,也不能找到适合的绝对位置编码。所以这里使用相对位置编码。

上图是原始 Transformer 和 Transformer-XL 的比较,其中 E 表示词的 Embedding,而 U 表示绝对位置编码。这大一堆看起来奇奇怪怪,实际上 Transformer 的注意力计算是 的分解,即先编码 Q(当前词 i)和 K(其他的词 j)然后算内积,位置编码是直接 add 在词嵌入上面的。

而 Transformer-XL 的改变是:

  • 把 j 的绝对位置 U 换成了相对位置 R,该相对位置表示也是一个正弦函数表示(i 和 j 的相对位置向量,j 是之前的序列,所以相减一定是正数)。R 不是通过学习得到的,好处是预测时,可以使用比训练距离更长的位置向量。

  • 使用两个可学习参数 u 和 v 替代了中的 query i 的位置映射。这里是由于每次计算 query 向量是固定的,不需要编码。

  • 每一层的 Attention 计算都要相对位置编码。Transformer 里面只有 input 的时候会加,而 XL 需要每层。

细细思考,这 attention 的四个部分各有玄机:

  • a. 基于内容的“寻址”,即没有添加原始位置编码的原始向量,

  • b. 基于内容的位置偏置,即相对于当前内容的位置偏差,

  • c. 全局的内容偏置,用于衡量 key 的重要性,query 固定查

  • d. 全局的位置偏置,根据 query 和 key 之间的距离调整重要性,query 固定查

相对位置编码的代码为:


   
  1. class PositionalEmbedding(nn.Module):
  2.     def __init__(self, demb):
  3.         super(PositionalEmbedding, self).__init__()
  4.         self.demb = demb #编码维度
  5.         inv_freq =  1 / ( 10000 ** (torch.arange( 0.0, demb,  2.0) / demb)) #间隔频率
  6.     def forward(self, pos_seq):
  7.         sinusoid_inp = torch.ger(pos_seq, self.inv_freq) #序列的位置向量 operation 间隔
  8.         pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim= -1) #正弦余弦
  9.          return pos_emb[:,None,:] #直接返回R,非学习矩阵R

简单把编码维度设置为 10,查询向量也是 10 个,存储之前的序列也是 10,有以下结果:


   
  1. >>>  import torch
  2. >>> inv_freq =  1 / ( 10000 ** (torch.arange( 0.0102.0) /  10))
  3. >>> inv_freq
  4. tensor([ 1.0000e+001.5849e-012.5119e-023.9811e-036.3096e-04])
  5. >>> pos_seq=torch.arange( 20 -1-1-1.0) #qlen+mlen,即 10+ 10的维度然后逆序
  6. >>> pos_seq
  7. tensor([ 19.18.17.16.15.14.13.12.11.10.,   9.,   8.,   7.,   6.,
  8.           5.,   4.,   3.,   2.,   1.,   0.])
  9. >>> sinusoid_inp = torch.ger(pos_seq,inv_freq)
  10. >>> sinusoid_inp
  11. tensor([[ 1.9000e+013.0113e+004.7726e-017.5640e-021.1988e-02],
  12.         [ 1.8000e+012.8528e+004.5214e-017.1659e-021.1357e-02],
  13.         [ 1.7000e+012.6943e+004.2702e-016.7678e-021.0726e-02],
  14.         [ 1.6000e+012.5358e+004.0190e-016.3697e-021.0095e-02],
  15.         [ 1.5000e+012.3773e+003.7678e-015.9716e-029.4644e-03],
  16.         [ 1.4000e+012.2189e+003.5166e-015.5735e-028.8334e-03],
  17.         [ 1.3000e+012.0604e+003.2655e-015.1754e-028.2024e-03],
  18.         [ 1.2000e+011.9019e+003.0143e-014.7773e-027.5715e-03],
  19.         [ 1.1000e+011.7434e+002.7631e-014.3792e-026.9405e-03],
  20.         [ 1.0000e+011.5849e+002.5119e-013.9811e-026.3096e-03],
  21.         [ 9.0000e+001.4264e+002.2607e-013.5830e-025.6786e-03],
  22.         [ 8.0000e+001.2679e+002.0095e-013.1849e-025.0477e-03],
  23.         [ 7.0000e+001.1094e+001.7583e-012.7867e-024.4167e-03],
  24.         [ 6.0000e+009.5094e-011.5071e-012.3886e-023.7857e-03],
  25.         [ 5.0000e+007.9245e-011.2559e-011.9905e-023.1548e-03],
  26.         [ 4.0000e+006.3396e-011.0048e-011.5924e-022.5238e-03],
  27.         [ 3.0000e+004.7547e-017.5357e-021.1943e-021.8929e-03],
  28.         [ 2.0000e+003.1698e-015.0238e-027.9621e-031.2619e-03],
  29.         [ 1.0000e+001.5849e-012.5119e-023.9811e-036.3096e-04],
  30.         [ 0.0000e+000.0000e+000.0000e+000.0000e+000.0000e+00]])
  31. >>> sinusoid_inp.sin()
  32. tensor([[  1.4988e-01,   1.2993e-01,   4.5935e-01,   7.5568e-02,   1.1988e-02],
  33.         [ -7.5099e-01,   2.8479e-01,   4.3689e-01,   7.1598e-02,   1.1357e-02],
  34.         [ -9.6140e-01,   4.3251e-01,   4.1416e-01,   6.7627e-02,   1.0726e-02],
  35.         [ -2.8790e-01,   5.6939e-01,   3.9117e-01,   6.3654e-02,   1.0095e-02],
  36.         [  6.5029e-01,   6.9200e-01,   3.6793e-01,   5.9681e-02,   9.4642e-03],
  37.         [  9.9061e-01,   7.9726e-01,   3.4446e-01,   5.5706e-02,   8.8333e-03],
  38.         [  4.2017e-01,   8.8254e-01,   3.2077e-01,   5.1731e-02,   8.2024e-03],
  39.         [ -5.3657e-01,   9.4569e-01,   2.9688e-01,   4.7755e-02,   7.5714e-03],
  40.         [ -9.9999e-01,   9.8514e-01,   2.7281e-01,   4.3778e-02,   6.9405e-03],
  41.         [ -5.4402e-01,   9.9990e-01,   2.4856e-01,   3.9800e-02,   6.3095e-03],
  42.         [  4.1212e-01,   9.8959e-01,   2.2415e-01,   3.5822e-02,   5.6786e-03],
  43.         [  9.8936e-01,   9.5448e-01,   1.9960e-01,   3.1843e-02,   5.0476e-03],
  44.         [  6.5699e-01,   8.9544e-01,   1.7493e-01,   2.7864e-02,   4.4167e-03],
  45.         [ -2.7942e-01,   8.1396e-01,   1.5014e-01,   2.3884e-02,   3.7857e-03],
  46.         [ -9.5892e-01,   7.1207e-01,   1.2526e-01,   1.9904e-02,   3.1548e-03],
  47.         [ -7.5680e-01,   5.9234e-01,   1.0031e-01,   1.5924e-02,   2.5238e-03],
  48.         [  1.4112e-01,   4.5775e-01,   7.5285e-02,   1.1943e-02,   1.8929e-03],
  49.         [  9.0930e-01,   3.1170e-01,   5.0217e-02,   7.9621e-03,   1.2619e-03],
  50.         [  8.4147e-01,   1.5783e-01,   2.5116e-02,   3.9811e-03,   6.3096e-04],
  51.         [  0.0000e+00,   0.0000e+00,   0.0000e+00,   0.0000e+00,   0.0000e+00]])

使用 Transformer-XL 的预训练模型经典的就是 XLNet 啦,可以更好的处理较长的文本。

然后是关于 Transformer 的复杂度问题进行改进的文章。

Explicit Sparse Transformer

论文标题:

Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection

论文链接:

https://arxiv.org/abs/1912.11637

代码链接:

https://github.com/lancopku/Explicit-Sparse-Transformer

标准 Transformer 的复杂度为 ,但是否序列内的所有元素都有必要被关注到,是否有方法可以简化这个机制?所以本文的“Sparse”重点就体现在只有少量的 token 参与 attention 分布的计算,以提升注意力机制的集中度。

即本来一个词只和少量的词有关,但是标准自注意力却会给所有的词都分配权重然后聚合,一种很自然的想法就是通过显式选择,只让模型关注少数几个元素就行。

模型图如上图,最左边是标准计算注意力的路线,中间的是 Sparse 的实现,可以看到区别就在于中间多了一个手工选择的 Sparsification,最右则是它的执行示意图。简单来说就是在算 softmax 分数之前先进行 top-k 选择出少数的重要元素即可。具体来说先算内积:

然后人工按内积分数过滤 top-k 个元素,即下式中的 M 操作,其他的 P 则直接置为负无穷,这样强制只让 k 个元素被关注。

最后再把分数回乘给 V:

通过这种操作,可以使注意力更集中。这种减轻计算量的操作也让 GPT-3 等模型能够架得更大更暴力也取得了很好的但是玩不起效果。

Longformer

论文标题:

Longformer: The Long-Document Transformer

论文链接:

https://arxiv.org/abs/2004.05150

代码链接:

https://github.com/allenai/longformer


Longformer 也算是一种比较经典的 Sparse 的方法了吧。一共提出了 3 种策略:

  • Sliding Window:如上图 (b) 所示,跟 CNN 很像,给定一个固定的窗口大小 w,其两边都有个 w/2 个 token 与其做 attention。计算复杂度降为 O(n x w),即复杂度与序列长度呈线性关系。而且如果为每一层设置不同的窗口 size 可以很好地平衡模型效率和表示能力。

  • Dilated sliding window:如上图 (c) 所示,类似扩张 CNN,可以在计算复杂度不变的情况下进一步扩大接收域。同样的,如果在多头注意力机制的每个头设置不同的扩张配置,可以关注文章的不同局部上下文,特别是通过这种 Dilated 可以扩张甚至很远的地方。

  • Global Attention :如图 (d) 所示,计算全局 token 可能表征序列的整体特性。比如 BERT 中的 [CLS] 这种功能,复杂度降为 O(n)。

完整内容可以看原文。

Switch Transformer

论文标题:

Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity

论文链接:

https://arxiv.org/abs/2101.03961

代码链接:

https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py

相比起 Sparse Attention 需要用到稀疏算子而很难发挥 GPU、TPU 硬件性能的问题。Switch Transformer 不需要稀疏算子,可以更好的适应 GPU、TPU 等稠密硬件。主要的想法是简化稀疏路由。

即在自然语言 MoE (Mixture of experts)层中,只将 token 表征发送给单个专家而不是多个的表现会更好。模型架构如上图,中间蓝色部分是比价关键的部分,可以看到每次 router 都只把信息传给分数 p 最大的单个 FFN。而这个操作可以大大降低计算量。

然后另一方面,其成功的原因是有一个很秀的并行策略,其同时结合了数据并行 + 模型并行 + expert 并行。具体如下图:

其实对应模型架构来看,experts 并行是算子间的并行,其对应的 FFN 内部有算子级的模型并行,而整个 experts 在计算图上又是个多并行 FFN 分支,这是算子间模型并行,所以能够获得更低的通信开销,提高并行的效率。

Routing Transformers

论文标题:

Efficient Content-Based Sparse Attention with Routing Transformers

收录会议:

TACL 2020

论文链接:

https://arxiv.org/abs/2003.05997

代码链接:

https://github.com/google-research/google-research/tree/master/routing_transformer

和前两篇文章的目标一样,如何使标准 Transformer 的时间复杂度降低。Routing Transformer 将该问题建模为一个路由问题,目的是让模型学会选择词例的稀疏聚类,所谓的聚类簇是关于每个键和查询的内容的函数,而不仅仅与它们的绝对或相对位置相关。

简单来说就是,作用差不多的词是可以变成聚类到一个表示的,这样来加速计算。

如上图与其他模型的对比,图中的每一行代表输出,每一列代表输入,对于 a 和 b 图来说,着色的方块代表每一个输出行注意到的元素。对于路由注意力机制来说,不同的颜色代表输出词例的聚类中的成员。具体做法是先用一种公共的随机权重矩阵对键和查询的值进行投影:

然后把 R 中的向量用 k-means 聚类成 k 个簇,然后在每个簇 中加权求和上下文得到嵌入:

最后作者使用了 个簇,所以时间复杂度降维 。详细可以看原论文和代码实现。

Linformer

论文标题:

Linformer: Self-Attention with Linear Complexity

论文链接:

https://arxiv.org/abs/2006.04768

代码链接:

https://github.com/tatp22/linformer-pytorch

到 O(n)!首先作者从理论和经验上证明了自注意机制所形成的随机矩阵可以近似为低秩矩阵,所以直接多引入线性投影将原始的缩放点积关注分解为多个较小的关注,也就是说这些小关注组合是标准注意力的低秩因数分解。即如上图,在计算键 K 和值 V 时添加两个线性投影矩阵 E 和 F,即:

同时还提供三种层级的参数共享:

  • Headwise: 所有注意力头共享投影句子参数,即 Ei=E,Fi=F。

  • Key-Value: 所有的注意力头的键值映射矩阵共享参数同一参数 ,即 Ei=Fi=E。

  • Layerwise: 所有层参数都共享。即对于所有层,都共享投射矩阵 E。

完整内容可以看原文,原文有理论证明低秩和分析。

Big Bird

论文标题:

Big Bird: Transformers for Longer Sequences

论文链接:

https://arxiv.org/abs/2007.14062

也是采用稀疏注意力机制,将复杂度下降到线性,即 O(N)。如上图,big bird 主要包括三个部分的注意力:

  • Random Attention(随机注意力)。如图 a,对于每一个 token i,随机选择 r 个 token 计算注意力。

  • Window Attention(局部注意力)。如图 b,用滑动窗口表示注意力计算 token 的局部信息。

  • Global Attention(全局注意力)。如图 c,计算全局信息。这些在 Longformer 中也讲过,可以参考对应论文。

最后把这三部分注意力结合在一起得到注意力矩阵 A,如图 d 就是 BIGBIRD 的结果了,计算公式为:

H 是头数,N(i) 是所有需要计算的 token,这里就是由三部分得来的稀疏部分,QKV 则是老伙伴了。

Transformer 变体未入榜番外篇。


Star-Transformer

论文标题:

Star-Transformer

收录会议:

NAACL 2019

论文链接:

https://arxiv.org/abs/1902.09113

代码链接:

https://github.com/fastnlp/fastNLP

问题:

  • Transformer 的自注意力机制每次都要计算所有词之间的注意力,其计算复杂度为输入长度的平方,结构很重

  • 在语言序列中相邻的词往往本身就会有较强的相关性,似乎本来就不需要计算所有词之间

解决:

Star-Transformer 用星型拓扑结构代替了全连通结构如上图左边是 Transformer,而右边是 Star-Transformer。在右边的图中,所有序列中直接相邻的词可以直接相互作用,而非直接相邻的元素则通过中心节点实现间接得信息传递,因此,复杂性从二次降低到线性,同时保留捕获局部成分和长期依赖关系的能力。

  • Radical connections,捕捉非局部信息。即每两个不相邻的卫星节点都是两跳邻居,可以通过两步更新接收非局部信息。

  • Ring connections,捕捉局部信息。由于文本输入是一个序列,相邻词相连以捕捉局部成分之间的关系。值得注意的是它第一个节点和最后一个节点也连接起来,形成环形连接。

具体实现算法如下:

  • 在初始化阶段,卫星节点(周围的词节点)的初始值为各自相应的词向量 ,而中心节点(集成节点)的初始值为所有词节点词向量的平均值

  • 更新卫星节点。对于某卫星节点 ,先得到它的上下文信息 ,它由相邻节点 ,中心节点 ,和这个节点对应的 token 词嵌入 组成。然后多头注意力更新特征,最后使用层归一化。

  • 更新中心节点(relay node)。中心节点与上一时刻和所有卫星信息的交互,所以同样是多头注意力 ,H 是可学习的位置编码(它在所有时刻都是一样的)。

  • 交替更新 T 步,over。

更多阅读

#投 稿 通 道#

 让你的论文被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。

???? 来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志

???? 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。


转载:https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/114156792
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场