飞道的博客

【论文笔记】图像修复Learning Joint Spatial-Temporal Transformations for Video Inpainting

393人阅读  评论(0)

论文地址:https://arxiv.org/abs/2007.10247

源码地址:GitHub - researchmm/STTN: [ECCV'2020] STTN: Learning Joint Spatial-Temporal Transformations for Video Inpainting

一、项目介绍

        当下SOTA的方法大多采用注意模型,通过搜索参考帧中缺失的内容来完成一帧,并进一步逐帧完成整个视频。然而,这些方法在空间和时间维度上的注意结果可能会不一致,这往往会导致视频中的模糊和时间伪影。

        本文提出时空转换网络STTN(Spatial-Temporal Transformer Network)。具体来说,是通过自注意机制同时填补所有输入帧中的缺失区域,并提出通过时空对抗性损失来优化STTN。为了展示该模型的优越性,我们使用标准的静止掩模和更真实的运动物体掩模进行了定量和定性的评价。

二、STTN

         模型输入是图像帧序列和masks序列,图像帧序列经过Encoder、Mask经过scale变化成原来的1/4,然后一起送入Spatial-Temporal Transformer模块;Spatial-Temporal Transformer模块由8个TransformerBlock组成;最后Decoder模块负责将特征还原成图像帧序列。STTN的整体结构图如下:

图1

1.Encoder

        Frame-Level Encoder帧级编码器,通过叠加二维卷积层来构建的,目的是为每一帧的低级别像素的深度特征,就是四个卷积层提取单帧图像特征,要素不多,结构图如下:

图2

代码如下:


  
  1. # 位置model/sttn.py
  2. self.encoder = nn.Sequential(
  3. nn.Conv2d( 3, 64, kernel_size= 3, stride= 2, padding= 1),
  4. nn.LeakyReLU( 0.2, inplace= True),
  5. nn.Conv2d( 64, 64, kernel_size= 3, stride= 1, padding= 1),
  6. nn.LeakyReLU( 0.2, inplace= True),
  7. nn.Conv2d( 64, 128, kernel_size= 3, stride= 2, padding= 1),
  8. nn.LeakyReLU( 0.2, inplace= True),
  9. nn.Conv2d( 128, channel, kernel_size= 3, stride= 1, padding= 1),
  10. nn.LeakyReLU( 0.2, inplace= True),
  11. )

2.Spatial-Temporal Transformer Network

           这是STTN的核心部分,通过一个多头 patch-based attention模块沿着空间和时间维度进行搜索。transformer的不同头部计算不同尺度上对空间patch的注意力。这样的设计允许我们处理由复杂的运动引起的外观变化。例如,对大尺寸的patch(例如,帧大小H×W)旨在修复固定的背景;对小尺寸的patch(如H/10×W/10)有助于在视频的任意位置捕捉移动的前景信息。

(1)TranformerBlock

        TransformerBlock由Embedding、MatchingAttending组成,代码中MatchingAttending被放在一起合成了MultiHeadedAttention。输入是帧序列特征和masks。

        帧序列的特征平分成四部分,每个部分经过Embedding映射为四种尺度的Key、Query、Value,从而对应不同尺度的patch。masks经过变换也变成四个尺度。将四个尺度的Key、Query、Value和四个尺度masks分别送入MultiHeadedAttention,然后将结果Concat到一起,经过FeedForward层进一步分特征融合,得到融合了时间维度上不同尺度空间patch的特征。结构图如下:

 图3

代码如下:


  
  1. # 位置model/sttn.py
  2. class TransformerBlock(nn.Module):
  3. """
  4. Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
  5. """
  6. def __init__( self, patchsize, hidden=128):
  7. super().__init__()
  8. self.attention = MultiHeadedAttention(patchsize, d_model=hidden)
  9. self.feed_forward = FeedForward(hidden)
  10. def forward( self, x):
  11. x, m, b, c = x[ 'x'], x[ 'm'], x[ 'b'], x[ 'c']
  12. x = x + self.attention(x, m, b, c)
  13. x = x + self.feed_forward(x)
  14. return { 'x': x, 'm': m, 'b': b, 'c': c}

(2)KQV Formatting

        图3中的KQV Formatting结构如下图:

图4

        TranformerBlock输入的帧序列特征,被平分成四个部分,每个部分经过变换,变成四种尺度patch的特征。

        代码如下:


  
  1. # 位置model/sttn.py
  2. query = query.view(b, t, d_k, out_h, height, out_w, width)
  3. query = query.permute( 0, 1, 3, 5, 2, 4, 6).contiguous().view(
  4. b, t*out_h*out_w, d_k*height*width)
  5. key = key.view(b, t, d_k, out_h, height, out_w, width)
  6. key = key.permute( 0, 1, 3, 5, 2, 4, 6).contiguous().view(
  7. b, t*out_h*out_w, d_k*height*width)
  8. value = value.view(b, t, d_k, out_h, height, out_w, width)
  9. value = value.permute( 0, 1, 3, 5, 2, 4, 6).contiguous().view(
  10. b, t*out_h*out_w, d_k*height*width)

(3)Mask Formatting

        KQV Formatting将帧序列变成四种尺度,masks也需要对应的变成四种尺度,结构如下:

 图5

代码如下:


  
  1. # 位置model/sttn.py
  2. mm = m.view(b, t, 1, out_h, height, out_w, width)
  3. mm = mm.permute( 0, 1, 3, 5, 2, 4, 6).contiguous().view(
  4. b, t*out_h*out_w, height*width)
  5. mm = (mm.mean(- 1) > 0.5).unsqueeze( 1).repeat( 1, t*out_h*out_w, 1)

(4)Attention

        图3中的Attention层其实包括了论文中的Matching和Attending,结构图如下:

 图6

        图6中的K*Q/sqrt(Q.size(-1))是在计算各个patch的相似性,对应论文中公式,第i个斑块与第j个patch的相似性记为::

         图6中的masked_fill(Mask, -1e9)是将图像中的损坏部分mask掉,意思是只学习图像中完整的部分,坏的就不要学习了。

         论文中的Attention对应图6中的matmul,负责计算相关patches的value加权和得到输出patch的query。公式如下:

代码如下:


  
  1. # 位置model/sttn.py
  2. class Attention(nn.Module):
  3. """
  4. Compute 'Scaled Dot Product Attention
  5. """
  6. def forward( self, query, key, value, m):
  7. scores = torch.matmul(query, key.transpose(- 2, - 1)) / math.sqrt(query.size(- 1))
  8. scores.masked_fill(m, - 1e9)
  9. p_attn = F.softmax(scores, dim=- 1)
  10. p_val = torch.matmul(p_attn, value)
  11. return p_val, p_attn

3.Decoder

         frame-level decoder: 帧级解码器,把特征解码成帧。期间特征图经过了两次的膨胀,中间穿插几个2d卷积,整体过程有点像Encoder倒过来,结构图如下:

 图7

代码如下:


  
  1. # 位置model/sttn.py
  2. self.decoder = nn.Sequential(
  3. deconv(channel, 128, kernel_size= 3, padding= 1),
  4. nn.LeakyReLU( 0.2, inplace= True),
  5. nn.Conv2d( 128, 64, kernel_size= 3, stride= 1, padding= 1),
  6. nn.LeakyReLU( 0.2, inplace= True),
  7. deconv( 64, 64, kernel_size= 3, padding= 1),
  8. nn.LeakyReLU( 0.2, inplace= True),
  9. nn.Conv2d( 64, 3, kernel_size= 3, stride= 1, padding= 1)
  10. )

三、损失函数

        本文使用GAN来对模型进行优化,G模型选择了一个像素级的重建损失即L1Loss,D网络使用T-PatchGAN来优化。

1.G模型损失函数

        G模型图像破坏区域的L1Loss:

        G模型图像有效区域的L1Loss:

        

        STTN的对抗性损失: ​​​

        上式看上去很复杂,其实就是将恢复的图像送入D模型,然后送入损失函数(可选nsgan、lsgan、hinge)

        总结上面三个式子,得出G模型的损失函数,其中三个权重官方推荐

2.D网络的损失函数

        对抗性的损失在提高视频绘制的感知质量和时空一致性方面显示出了良好的效果。公式如下:

         看山去还是很复杂,其实就是将原图和复原图分别送入损失函数(可选nsgan、lsgan、hinge),然后求和,代码中是取均值,不过应该影响不大。

三、训练流程

        下面是我根据官方代码梳理的整个训练过程:

        1.从数据集选取数据,同时为选取的数据随机带有破坏图案的masks

        2.根据masks将原图的破坏部分变成0,得到masked_frame

        3.将masked_frame和masks送入G模型(生成模型,即STTN),得出估计pred_img

        4.根据pred_img修复图像,得到comp_img

        5.将原图和comp_img分别送入D模型,分别得到输出的特征 real_vid_feat和fake_vid_feat

        6.使用real_vid_feat和fake_vid_feat对D模型进行优化(损失函数可选nsgan、lsgan、hinge)

        7.使用原图、comp_img和gen_vid_feat对G模型进行优化(L1Loss)

代码如下:


  
  1. # 位置core/trainer.py
  2. def _train_epoch( self, pbar):
  3. device = self.config[ 'device']
  4. for frames, masks in self.train_loader:
  5. self.adjust_learning_rate()
  6. self.iteration += 1
  7. frames, masks = frames.to(device), masks.to(device)
  8. b, t, c, h, w = frames.size()
  9. masked_frame = (frames * ( 1 - masks). float())
  10. # 将masked_frame和masks送入G模型(生成模型,即STTN),得出估计pred_img
  11. pred_img = self.netG(masked_frame, masks)
  12. frames = frames.view(b*t, c, h, w)
  13. masks = masks.view(b*t, 1, h, w)
  14. # 根据pred_img修复图像,得到comp_img
  15. comp_img = frames*( 1.-masks) + masks*pred_img
  16. gen_loss = 0
  17. dis_loss = 0
  18. # 将原图和comp_img分别送入D模型,分别得到输出的特征 real_vid_feat和fake_vid_feat
  19. real_vid_feat = self.netD(frames)
  20. fake_vid_feat = self.netD(comp_img.detach())
  21. # 计算D网络的损失
  22. dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
  23. dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
  24. dis_loss += (dis_real_loss + dis_fake_loss) / 2
  25. self.add_summary(
  26. self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
  27. self.add_summary(
  28. self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
  29. self.optimD.zero_grad()
  30. dis_loss.backward()
  31. # 使用real_vid_feat和fake_vid_feat对D模型进行优化
  32. self.optimD.step()
  33. # G模型的对抗性损失
  34. gen_vid_feat = self.netD(comp_img)
  35. gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
  36. gan_loss = gan_loss * self.config[ 'losses'][ 'adversarial_weight']
  37. gen_loss += gan_loss
  38. self.add_summary(
  39. self.gen_writer, 'loss/gan_loss', gan_loss.item())
  40. # G模型图像破坏区域的L1Loss
  41. hole_loss = self.l1_loss(pred_img*masks, frames*masks)
  42. hole_loss = hole_loss / torch.mean(masks) * self.config[ 'losses'][ 'hole_weight']
  43. gen_loss += hole_loss
  44. self.add_summary(
  45. self.gen_writer, 'loss/hole_loss', hole_loss.item())
  46. # G模型图像有效区域的L1Loss
  47. valid_loss = self.l1_loss(pred_img*( 1-masks), frames*( 1-masks))
  48. valid_loss = valid_loss / torch.mean( 1-masks) * self.config[ 'losses'][ 'valid_weight']
  49. gen_loss += valid_loss
  50. self.add_summary(
  51. self.gen_writer, 'loss/valid_loss', valid_loss.item())
  52. self.optimG.zero_grad()
  53. gen_loss.backward()
  54. # 使用原图、comp_img和gen_vid_feat对G模型进行优化
  55. self.optimG.step()
  56. # 日志
  57. if self.config[ 'global_rank'] == 0:
  58. pbar.update( 1)
  59. pbar.set_description((
  60. f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
  61. f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
  62. )
  63. # saving models
  64. if self.iteration % self.train_args[ 'save_freq'] == 0:
  65. self.save( int(self.iteration//self.train_args[ 'save_freq']))
  66. if self.iteration > self.train_args[ 'iterations']:
  67. break

        接下来代码中有些重点,需要简单说明一下:

1.准备数据集

        项目中用到Davisyoutube-vos数据集,两个数据集其实都是为segmentation任务设计的,代码中都只使用图像数据,不使用标注数据。我们以davis数据集为例,davis数据集由90个视频组成,每个视频已经拆帧成图片,数据集下载完每个视频一个文件夹,但是程序需要每个视频这图片打成zip文件,下面的程序可以用来完成这个工作:


  
  1. import os
  2. import zipfile
  3. def zipDir( dirpath, out_full_name):
  4. zipname = zipfile.ZipFile(out_full_name, 'w', zipfile.ZIP_DEFLATED)
  5. for path, dirnames, filenames in os.walk(dirpath):
  6. fpath= path.replace(dirpath, '')
  7. for filename in filenames:
  8. zipname.write(os.path.join(path, filename), os.path.join(fpath, filename))
  9. zipname.close()
  10. if __name__== "__main__":
  11. org_dir = r'datasets/davis/JPEGImages_org'
  12. zip_dir = r'datasets/davis/JPEGImages'
  13. g = os.walk(org_dir)
  14. for path, dir_list, file_list in g:
  15. for dir_name in dir_list:
  16. input_path = os.path.join(path, dir_name)
  17. output_path = os.path.join(zip_dir, dir_name+ '.zip')
  18. print(input_path, '\n', output_path)
  19. zipDir(input_path, output_path)

2.数据选取策略

        数据是从90个视频中随机挑一个,然后在这个视频中选取sample_length张图片,最终每个视频都会选取一个图片组,在论文中提到有两种数据选取策略,就是下面这个公式:

         其中代表以t为中心n为半径的连续帧序列,代码实现是50%概率用一个长度为sample_length的框随机滑动选取;表示从以s采样率的视频中均匀采样的远处帧,代码中并未使用这种方式,而是50%概率随机选取帧,这样也许是为了解决缓解数据不够多的问题。

        选图片组的代码如下:


  
  1. # 位置:core/dataset.py
  2. def get_ref_index( length, sample_length):
  3. # 50%概率随机选取帧
  4. if random.uniform( 0, 1) > 0.5:
  5. ref_index = random.sample( range(length), sample_length)
  6. ref_index.sort()
  7. else:
  8. # 50%概率用一个长度为sample_length的框随机滑动选取
  9. pivot = random.randint( 0, length-sample_length)
  10. ref_index = [pivot+i for i in range(sample_length)]
  11. return ref_index

3.生成随机masks

        有了图片组,还需要为每个图片组随机生成masks。其中0代表背景,1代表破坏部分。代码如下,注释已经很清楚:


  
  1. # 位置:core/utils.py
  2. def create_random_shape_with_random_motion( video_length, imageHeight=240, imageWidth=432):
  3. # 生成的破坏图案宽高占原图的1/3到100%
  4. height = random.randint(imageHeight// 3, imageHeight- 1)
  5. width = random.randint(imageWidth// 3, imageWidth- 1)
  6. # 生成不规则的破坏图案
  7. edge_num = random.randint( 6, 8)
  8. ratio = random.randint( 6, 8)/ 10
  9. region = get_random_shape(
  10. edge_num=edge_num, ratio=ratio, height=height, width=width)
  11. region_width, region_height = region.size
  12. # 随机放置破坏图案
  13. x, y = random.randint(
  14. 0, imageHeight-region_height), random.randint( 0, imageWidth-region_width)
  15. velocity = get_random_velocity(max_speed= 3)
  16. m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
  17. m.paste(region, (y, x, y+region.size[ 0], x+region.size[ 1]))
  18. masks = [m.convert( 'L')]
  19. # 50%概率所有的mask一样
  20. if random.uniform( 0, 1) > 0.5:
  21. return masks*video_length
  22. # 50%概率mask中的破坏图案会移动
  23. for _ in range(video_length- 1):
  24. x, y, velocity = random_move_control_points(
  25. x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=( 3, 0.5), maxInitSpeed= 3)
  26. m = Image.fromarray(
  27. np.zeros((imageHeight, imageWidth)).astype(np.uint8))
  28. m.paste(region, (y, x, y+region.size[ 0], x+region.size[ 1]))
  29. masks.append(m.convert( 'L'))
  30. return masks


转载:https://blog.csdn.net/xian0710830114/article/details/129037940
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场