飞道的博客

DeepLabV3+:ASPP加强特征提取网络的搭建

587人阅读  评论(0)

目录

ASPP结构介绍

ASPP在代码中的构建

参考资料


ASPP结构介绍

ASPP:Atrous Spatial Pyramid Pooling,空洞空间卷积池化金字塔。
简单理解就是个至尊版池化层,其目的与普通的池化层一致,尽可能地去提取特征。

利用主干特征提取网络,会得到一个浅层特征和一个深层特征,这一篇主要以如何对较深层特征进行加强特征提取,也就是在Encoder中所看到的部分。

它就叫做ASPP,主要有5个部分:

  • 1x1卷积
  • 膨胀率为6的3x3卷积
  • 膨胀率为12的3x3卷积
  • 膨胀率为18的3x3卷积
  • 对输入进去的特征层进行池化

接着会对这五个部分进行一个堆叠,再利用一个1x1卷积对通道数进行调整,获得上图中绿色的特征。

ASPP在代码中的构建


  
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class ASPP(nn.Module):
  5. def __init__( self, dim_in, dim_out, rate=1, bn_mom=0.1):
  6. super(ASPP, self).__init__()
  7. self.branch1 = nn.Sequential(
  8. nn.Conv2d(dim_in, dim_out, kernel_size=( 1, 1), stride=( 1, 1), padding= 0, dilation=rate, bias= True),
  9. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  10. nn.ReLU(inplace= True),
  11. )
  12. self.branch2 = nn.Sequential(
  13. nn.Conv2d(dim_in, dim_out, kernel_size=( 3, 3), stride=( 1, 1), padding= 6 * rate, dilation= 6 * rate, bias= True),
  14. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  15. nn.ReLU(inplace= True),
  16. )
  17. self.branch3 = nn.Sequential(
  18. nn.Conv2d(dim_in, dim_out, kernel_size=( 3, 3), stride=( 1, 1), padding= 12 * rate, dilation= 12 * rate, bias= True),
  19. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  20. nn.ReLU(inplace= True),
  21. )
  22. self.branch4 = nn.Sequential(
  23. nn.Conv2d(dim_in, dim_out, kernel_size=( 3, 3), stride=( 1, 1), padding= 18 * rate, dilation= 18 * rate, bias= True),
  24. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  25. nn.ReLU(inplace= True),
  26. )
  27. self.branch5_conv = nn.Conv2d(dim_in, dim_out, kernel_size=( 1, 1), stride=( 1, 1), padding= 0, bias= True)
  28. self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
  29. self.branch5_relu = nn.ReLU(inplace= True)
  30. self.conv_cat = nn.Sequential(
  31. nn.Conv2d(dim_out * 5, dim_out ,kernel_size=( 1, 1), stride=( 1, 1), padding= 0, bias= True),
  32. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  33. nn.ReLU(inplace= True),
  34. )
  35. def forward( self, x):
  36. [b, c, row, col] = x.size()
  37. # 五个分支
  38. conv1x1 = self.branch1(x)
  39. conv3x3_1 = self.branch2(x)
  40. conv3x3_2 = self.branch3(x)
  41. conv3x3_3 = self.branch4(x)
  42. # 第五个分支,进行全局平均池化+卷积
  43. global_feature = torch.mean(x, 2, True)
  44. global_feature = torch.mean(global_feature, 3, True)
  45. global_feature = self.branch5_conv(global_feature)
  46. global_feature = self.branch5_bn(global_feature)
  47. global_feature = self.branch5_relu(global_feature)
  48. global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
  49. # 五个分支的内容堆叠起来,然后1x1卷积整合特征。
  50. feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim= 1)
  51. result = self.conv_cat(feature_cat)
  52. return result
  53. if __name__ == "__main__":
  54. model = ASPP(dim_in= 320, dim_out= 256, rate= 16// 16)
  55. print(model)

那么从这里来看的话,也是相当清晰的,branch*(1、2、3、4、5)分别代表了ASPP五个部分在def __init__()可以体现,对于每一个都是卷积、标准化、激活函数。

第五个部分可以看到def forward中,首先呢,是要进行一个全局平均池化,再用1x1卷积通道数的整合,标准化、激活函数,接着采用上采样的方法,把它的大小调整成和我们上面获得的分支一样大小的特征层,这样我们才可以将五个部分进行一个堆叠,使用的是torch.cat()函数实现,最后,利用1x1卷积,对输入进来的特征层进行一个通道数的调整,获得想上图中绿色的部分,接着就会将这个具有较高语义信息的有效特征层就会传入到Decoder当中。

参考资料

(6条消息) Pytorch-torchvision源码解读:ASPP_xiongxyowo的博客-CSDN博客_aspp代码

DeepLabV3-/deeplabv3+.pdf at main · Auorui/DeepLabV3- (github.com)


转载:https://blog.csdn.net/m0_62919535/article/details/128876615
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场