InstanceNorm && LayerNorm
@author: SUFEHeisenberg
@date: 2023/01/26
先说结论:
- 将Transformer类比于RNN:一个token就是一层layer,对一整句不如token有意义
- 原生Bert代码或huggingface中用的都是
InstanceNorm
instead ofLayerNorm
,但都是torch.nn.LayerNorm
实现的。
1. 对NLP数据的理解
NLP input data的为[batch_size, sequence_len, dim]
表示为[K, N, D]
关键就是这个形如
[K, N, D]
的 tensor 它其实不是一层,而是N 个形如 [K, D] 的层拼接的结果。用 RNN 来想就很明白了,计算完时间步 t 以后才能计算时间步 t+1,比如 h_{t+1}=tanh(Wh_t+b),h_{t+1} 和 h_t 在计算图上的深度都不同,显然第 t 个词的 D 维向量和第 t+1 个词的 D维向量属于两个不同的层。只不过为了方便使用,会把所有的 h_{1:T} 都拼接起来组成一个 tensor 返回。在 xfmr 里,因为各个时间步可以同时计算,所以这一点不够不明显了。
简单来说,由于RNN在每个时间步都共享同一套参数(其实Transformer也是一样,同一层的不同token共享同一套QKV),BatchNorm是跨时间步进行的(换句话说就是跨token进行的,因为一个batch中所有句子在同一位置的token属于同一个时间步),而LayerNorm是只取决于当前时间步(或者说当前这个token)。
从这样的视角来看,或者说从网络的实际计算流程来看,对于一批文本输入
[K, N, D]
,实际上可以看作是由N个[K, d]
的输入拼接而来的,其中每个[K, D]
代表的是一个batch中所有句子在某一位置(或说某个时间步)的token嵌入组成的。在RNN中,[K, D]
按时序依次输入网络,在Transformer中则是通过并行计算,但本质上都是通过同一套参数来计算。 因此,在[K, N, D]
上进行LN,就是在SN个[K, D]
这样的批数据上进行LN,只不过这N个LN共享同一套 gain( γ \gamma γ)和bias( β \beta β)。
在此直接照搬知乎大佬的讲解,通过举例已经非常浅显透彻了。
2. 结合公式举个栗子🌰
2.1 生成demo data
import torch
K, N, D = 2, 3, 4
# 生成demo data
embedding = torch.randn(K, N, D)
Out[1]:
tensor([[[ 2.3833, 0.1780, 1.0667, 0.2227],
[ 0.2482, -0.3889, 0.7117, 0.9091],
[ 0.4513, 1.6905, 0.5648, -1.2175]],
[[ 0.1469, -0.9727, 2.5195, -1.3820],
[-0.0406, 0.4197, 1.8440, 1.2459],
[ 0.0238, 0.4803, -1.0974, -0.3951]]])
2.2 验证LayerNorm
LayerNorm是K*N*D
固定了K,在每一个batch会生成K个mean均值 μ 1 , ⋯ , μ K \mu_1,\cdots,\mu_K μ1,⋯,μK,每个batch中的得到标准差 σ 1 , ⋯ , σ K \sigma_1,\cdots, \sigma_K σ1,⋯,σK。
所以,对于第k个batch ∈ R N × D \in\mathbb{R}^{N\times D} ∈RN×D而言(对于bert而言是 R N + 2 , D \mathbb{R}^{N+2,D} RN+2,D, 在此不细究讨论):
X n d ( k ) ′ = X n d ( k ) − μ k σ k X_{nd}^{(k)\prime} = \frac{X^{(k)}_{nd}-\mu_k}{\sigma_k} Xnd(k)′=σkXnd(k)−μk
μ \mu μ和 σ \sigma σ是每个batch中N*D
的均值方差。
# layer_normalization
layer_norm = torch.nn.LayerNorm([N,dim], elementwise_affine = False)
print("layer_norm: ", layer_norm(embedding))
layer_norm: tensor([[[ 2.0472, -0.4403, 0.5621, -0.3898],
[-0.3610, -1.0797, 0.1617, 0.3843],
[-0.1320, 1.2658, -0.0039, -2.0143]],
[[-0.0760, -1.0675, 2.0254, -1.4301],
[-0.2420, 0.1656, 1.4271, 0.8974],
[-0.1850, 0.2193, -1.1780, -0.5560]]])
验证第一行元素:
mean = embedding.mean(dim=(1,2))
# tensor([0.5683, 0.2327])
std = embedding.std(dim=(1,2), unbiased=False) #一定要记得unbiased=False
# tensor([0.8866, 1.1291])
# or 用Var的数学期望定义
var = torch.square(embedding-mean).mean(dim=(1,2))
#tensor([0.7860, 1.2748])
(embedding[0][0]-mean[0])/std[0]
# Out[189]: tensor([[ 2.0472, -0.4403, 0.5621, -0.3898]])
验证所有元素
eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-2,-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-2,-1), keepdim=True)
print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))
2.3 验证InstanceNorm
LayerNorm是K*N*D
固定了K*N
,在每一个token会生成K*N
个mean均值 μ 11 , ⋯ , μ k n , ⋯ , μ K N \mu_{11},\cdots,\mu_{kn},\cdots,\mu_{KN} μ11,⋯,μkn,⋯,μKN,每个batch中的得到标准差 σ 1 , ⋯ , σ k n , ⋯ , σ K N \sigma_1,\cdots,\sigma_{}kn,\cdots, \sigma_{KN} σ1,⋯,σkn,⋯,σKN。
所以,对于第kn
个token ∈ R 1 × D \in\mathbb{R}^{1\times D} ∈R1×D而言:
X d ( k n ) ′ = X d ( k n ) − μ k n σ k n X_{d}^{(kn)\prime} = \frac{X^{(kn)}_{d}-\mu_{kn}}{\sigma_{kn}} Xd(kn)′=σknXd(kn)−μkn
μ \mu μ和 σ \sigma σ是每个batch中每个seq_len的token中D
里面的均值方差。
# instance_normalization,可以看到二者其实都是通过nn.LayerNorm实现的
instance_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)
(embedding)
print("instance_norm: ", instance_norm(embedding))
instance_norm: tensor([[[ 1.5902, -0.8784, 0.1164, -0.8283],
[-0.2438, -1.5193, 0.6840, 1.0791],
[ 0.0761, 1.2702, 0.1855, -1.5318]],
[[ 0.0454, -0.6927, 1.6098, -0.9626],
[-1.2464, -0.6145, 1.3411, 0.5199],
[ 0.4668, 1.2533, -1.4650, -0.2550]]])
验证第一行元素:
(embedding[0][0]-embedding[0][0].mean())/embedding[0][0].std(unbiased=False)
# tensor([ 1.5902, -0.8784, 0.1164, -0.8283])
验证所有元素
eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)
print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))
Reference
震惊!BERT用LayerNorm的可能不是你认为的那个Layer Norm?
转载:https://blog.csdn.net/weixin_43557139/article/details/128765886