摘要
尽管 GAN 已经能够生成非常逼真的高分辨率图片了, 但是要确保生成的图片和文本语义的一致还是一个很有挑战性的问题. 为了解决这个问题, 作者提出了 MirroGAN (a novel global-local attentive and semantic-preserving text-to-image-to-text framework) 这个模型. 这个模型主要由三个部分组成:
-
文本语义嵌入模块 (STEM, semantic text embedding module):
STEM 模块主要是为了生成单词和句子级别的嵌入语义.
-
全局和局部的注意力协同模块 (GLAM, global-local collaborative attentive module):
在这个模块主要是使用 STEM 模块生成的嵌入语义作为全局和局部的注意力逐步生成语义一致且逼真的图片.
-
语义文本再生和对齐模块 (STREAM, semantic text regeneration and alignment module):
在这个模块会使用一个 RNN 网络对生成的图片重新生成描述, 然后与原来的文本进行语义上的对齐.
模型的结构示意如下:
1. 注意力机制的应用
之前的文本到图像任务都是只通过一个判别器去判断文本和生成的图片对是否逼真以及图片语义是否和文本语义一致, 但是由于文本和图像之间的语义鸿沟, 单纯的依赖判别器是很难判断的而且是很不高效. 最近注意力模块开始被用来解决这个问题, 在 AttnGAN 中就使用了单词层次的注意力. 但是作者认为仅仅使用单词层次的注意力是不够的, 例如在 CUB 和 COCO 数据集中同一个语义的描述分别有 10 个和 5 个. 作者认为对于多阶段训练的生成器, 语义平滑是很重要的, 一次全局的句子级别的注意力也是应该要考虑的.
1.1 STEM 模块
在 MirroGAN 中 STEM 模块就是为了提取描述的单词特征和句子特征的模块. 作者使用了 RNN 提取这些特征:
w,s=RNN(T)
w,s=RNN(T)
其中
T={Tl∣l=0,⋯,L}
T={Tl∣l=0,⋯,L},
L
L 表示句子的长度.
w={wl∣l=0,⋯,L}∈RD×L
w={wl∣l=0,⋯,L}∈RD×L 表示单词级别的特征,
s∈RD
s∈RD 表示句子级别特征. 由于不同的单词的排序可能表示相同的语义. 因此, 为了提高模型的鲁棒性, 这里作者还使用了StackGAN 中提出的 conditioning augmentation method, 从而产生更多的图像-文本对数据, 增强对条件文本流形上的小扰动的鲁棒性.
sca=Fca(s)
sca=Fca(s)
其中
sca∈RD′
sca∈RD′,
D′
D′ 是增强后的维度.
2. 文本到图像与图像到文本
虽然从图像到文本是另外一个任务, 但是这两个任务都需要在两个域 (文本域和图像域)对齐语义. 所以作者就想到构建一个包含这两个任务的模型, 在这个模型下就能够使用对称的约束. 下图展示了这种约束:
2.1 GLAM 模块
GLAM 模块是一个级联的生成网络. 借鉴了 AttnGAN 的结构:
f0=F0(z,sca)
f0=F0(z,sca)
fi=Fi(fi−1,Fatti(fi−1,w,sca)),i∈{1,2,⋯,m−1}
fi=Fi(fi−1,Fatti(fi−1,w,sca)),i∈{1,2,⋯,m−1}
Ii=Gi(fi),i∈{1,2,⋯,m−1}
Ii=Gi(fi),i∈{1,2,⋯,m−1}
其中
Fatti
Fatti 就是全局-局部注意力协同模块, 包含了两个部分
Attwi−1
Atti−1w 和
Attsi−1
Atti−1s,
Fatti(fi−1,w,sca)=concat(Attwi−1,Attsi−1)
Fatti(fi−1,w,sca)=concat(Atti−1w,Atti−1s), 其中
Attwi−1=∑L−1l=0(Ui−1wl)(softmax(fTi−1(Ui−1wl)))T
Atti−1w=l=0∑L−1(Ui−1wl)(softmax(fi−1T(Ui−1wl)))T
其中
Ui−1∈RMi−1×D
Ui−1∈RMi−1×D 是一个视觉感知层, 计算完之后
Attwi−1
Atti−1w 和
fi−1
fi−1 有相同的大小.
Attsi−1=(Vi−1sca)∘(softmax(fi−1∘(Vi−1sca)))
Atti−1s=(Vi−1sca)∘(softmax(fi−1∘(Vi−1sca)))
其中
∘
∘ 表示逐元素相乘
Vi−1
Vi−1 也是一个视觉感知层, 计算完之后会和
Attwi−1
Atti−1w 进行拼接.
2.2 STREAM 模块
STREAM 模块是从生成的图像得到文本描述, 作者使用了一个应用广泛的 Encoder-Decoder 结构的框架. 编码器是一个在 ImageNet 上预训练的 CNN 网络, 解码器是一个 RNN 网络. 最后一个生成器生成的图片
Im−1
Im−1 会被送到这个网络中.
x−1=CNN(Im−1)
x−1=CNN(Im−1)
xt=WeTt,t∈{0,⋯,L−1}
xt=WeTt,t∈{0,⋯,L−1}
pt+1=RNN(xt),t∈{0,⋯,L−1}
pt+1=RNN(xt),t∈{0,⋯,L−1}
x−1∈RMm−1
x−1∈RMm−1 是送入 RNN 的图像特征,
We∈RMm−1×D
We∈RMm−1×D 表示词嵌入向量.
3. 目标函数
为了能够端到端的训练模型, 作者还是使用了两个判别损失: 一个是判别图像是否真实的损失, 另一个是判别图像和文本对是否一致的判别损失. 然后针对重新生成的文本, 作者基于交叉熵损失设计了一个文本重建损失.
生成器的目标函数如下:
LGi=−12EIi∼pIi[log(Di(Ii))]−12EIi∼pIi[log(Di(Ii,s))]
LGi=−21EIi∼pIi[log(Di(Ii))]−21EIi∼pIi[log(Di(Ii,s))]
STREAM 模块的损失函数为:
Lstream=−∑L−1t=0logpt(Tt)
Lstream=−t=0∑L−1logpt(Tt)
这个损失函数还是 STREAM 网络预训练时的损失函数.
所以总的损失函数为:
LG=∑m−1i=0LGi+λLstream
LG=i=0∑m−1LGi+λLstream
判别器的损失函数为:
LDi=−12EIGTi∼pGTIi[log(Di(IGTi))]−12EIi∼pIi[log(1−Di(Ii))]−12EIGTi∼pGTIi[log(Di(IGTi,s))]−12EIi∼pIi[log(1−Di(Ii,s))]
LDi=−21EIiGT∼pIiGT[log(Di(IiGT))]−21EIi∼pIi[log(1−Di(Ii))]−21EIiGT∼pIiGT[log(Di(IiGT,s))]−21EIi∼pIi[log(1−Di(Ii,s))]
所以:
LD=∑m−1i=0LDi
LD=i=0∑m−1LDi
4. 实验
4.1 对比实验
作者在 COCO 和 CUB 数据集上面做了对比实验:
然后作者还专门做了和 AttnGAN 的对比实验:
然后作者还做实验验证全局注意力(GA, global attention)的作用以及
λ
λ 的影响:
转载:
https://blog.csdn.net/zh20166666/article/details/105737158