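Introduction to the ASPP structure
ASPP stands for Atrous Spatial Pyramid Pooling, a pooling pyramid built from atrous (dilated) convolutions.
Put simply, it is a souped-up pooling layer. Its goal is the same as an ordinary pooling layer's, extracting as much feature information as possible, but it does so at several scales at once by running dilated convolutions with different rates in parallel.
The backbone feature extraction network produces one shallow feature and one deep feature. This post focuses on how the deeper feature is passed through enhanced feature extraction, which is the part shown in the Encoder.
That part is called ASPP, and it has five main components: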
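- a 1x1 convolution
- a 3x3 convolution with dilation rate 6
- a 3x3 convolution with dilation rate 12
- a 3x3 convolution with dilation rate 18
- pooling of the input feature layer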
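The five outputs are then stacked (concatenated along the channel dimension), and a 1x1 convolution adjusts the number of channels, which yields the green feature in the figure above.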
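The sizes involved in the five branches can be sanity-checked with a little arithmetic. The sketch below assumes the standard convolution output-size formula, stride 1, and padding equal to the dilation rate for the 3x3 branches (as in the code later in this post). The helper names `effective_kernel` and `conv_out_size` and the 32x32 input size are made up for illustration.

```python
def effective_kernel(kernel_size: int, dilation: int) -> int:
    """Span (in pixels) covered by a dilated kernel along one axis."""
    return dilation * (kernel_size - 1) + 1


def conv_out_size(size: int, kernel_size: int, dilation: int,
                  padding: int, stride: int = 1) -> int:
    """Standard convolution output-size formula along one axis."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Hypothetical input: a 32x32 deep feature map.
size = 32
for d in (6, 12, 18):
    span = effective_kernel(3, d)
    out = conv_out_size(size, kernel_size=3, dilation=d, padding=d)
    print(f"dilation={d}: kernel spans {span} pixels, output size {out}")

# The 1x1 branch needs no padding and also keeps the spatial size.
print(conv_out_size(size, kernel_size=1, dilation=1, padding=0))
```

Under these assumptions every branch keeps the input height and width, which is what allows the five outputs to be concatenated along the channel axis.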
Building ASPP in code
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class ASPP(nn.Module):
        def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
            super(ASPP, self).__init__()
            # Branch 1: 1x1 convolution
            self.branch1 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel_size=(1, 1), stride=(1, 1), padding=0, dilation=rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
            )
            # Branch 2: 3x3 convolution, dilation rate 6
            self.branch2 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel_size=(3, 3), stride=(1, 1), padding=6 * rate, dilation=6 * rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
            )
            # Branch 3: 3x3 convolution, dilation rate 12
            self.branch3 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel_size=(3, 3), stride=(1, 1), padding=12 * rate, dilation=12 * rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
            )
            # Branch 4: 3x3 convolution, dilation rate 18
            self.branch4 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel_size=(3, 3), stride=(1, 1), padding=18 * rate, dilation=18 * rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
            )
            # Branch 5: applied to the globally pooled feature in forward()
            self.branch5_conv = nn.Conv2d(dim_in, dim_out, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=True)
            self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
            self.branch5_relu = nn.ReLU(inplace=True)

            # 1x1 convolution that fuses the five concatenated branches
            self.conv_cat = nn.Sequential(
                nn.Conv2d(dim_out * 5, dim_out, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
            )

        def forward(self, x):
            b, c, row, col = x.size()

            # The four convolutional branches
            conv1x1 = self.branch1(x)
            conv3x3_1 = self.branch2(x)
            conv3x3_2 = self.branch3(x)
            conv3x3_3 = self.branch4(x)

            # Fifth branch: global average pooling followed by a convolution
            global_feature = torch.mean(x, 2, True)
            global_feature = torch.mean(global_feature, 3, True)
            global_feature = self.branch5_conv(global_feature)
            global_feature = self.branch5_bn(global_feature)
            global_feature = self.branch5_relu(global_feature)
            global_feature = F.interpolate(global_feature, size=(row, col), mode='bilinear', align_corners=True)

            # Stack the five branches, then fuse the features with a 1x1 convolution.
            feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
            result = self.conv_cat(feature_cat)
            return result


    if __name__ == "__main__":
        model = ASPP(dim_in=320, dim_out=256, rate=16 // 16)
        print(model)
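Seen this way, the structure is fairly clear: branch1 through branch5 correspond to the five parts of ASPP and are defined in __init__(). Each one consists of a convolution, batch normalization, and an activation function.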
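Since every branch repeats the same convolution, normalization, activation pattern, it can be factored into a helper. The following is only a condensed sketch of the same idea, not the original author's code: `conv_bn_relu` and `CompactASPP` are hypothetical names, and the 2x320x32x32 input is an arbitrary test tensor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn_relu(dim_in, dim_out, kernel_size, dilation=1, bn_mom=0.1):
    """Convolution -> batch normalization -> ReLU, the pattern shared by every branch."""
    # For a 3x3 kernel, padding == dilation keeps height and width unchanged.
    padding = 0 if kernel_size == 1 else dilation
    return nn.Sequential(
        nn.Conv2d(dim_in, dim_out, kernel_size, stride=1,
                  padding=padding, dilation=dilation, bias=True),
        nn.BatchNorm2d(dim_out, momentum=bn_mom),
        nn.ReLU(inplace=True),
    )


class CompactASPP(nn.Module):
    """Hypothetical condensed variant of the ASPP module, for illustration only."""

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super().__init__()
        # 1x1 branch plus three dilated 3x3 branches (rates 6, 12, 18).
        self.branches = nn.ModuleList(
            [conv_bn_relu(dim_in, dim_out, 1, bn_mom=bn_mom)]
            + [conv_bn_relu(dim_in, dim_out, 3, dilation=d * rate, bn_mom=bn_mom)
               for d in (6, 12, 18)]
        )
        # Image-level branch: applied after global average pooling in forward().
        self.pool_branch = conv_bn_relu(dim_in, dim_out, 1, bn_mom=bn_mom)
        # Fuse the five concatenated branches back to dim_out channels.
        self.conv_cat = conv_bn_relu(dim_out * 5, dim_out, 1, bn_mom=bn_mom)

    def forward(self, x):
        row, col = x.shape[2:]
        outputs = [branch(x) for branch in self.branches]
        pooled = self.pool_branch(x.mean(dim=(2, 3), keepdim=True))
        outputs.append(F.interpolate(pooled, size=(row, col),
                                     mode="bilinear", align_corners=True))
        return self.conv_cat(torch.cat(outputs, dim=1))


model = CompactASPP(dim_in=320, dim_out=256).eval()  # eval(): BatchNorm uses running statistics
with torch.no_grad():
    out = model(torch.randn(2, 320, 32, 32))
print(out.shape)
```

One thing to watch for when trying this in training mode: the pooled feature is only 1x1 per channel, so with a batch size of 1 the batch normalization in the pooling branch has a single value per channel and PyTorch is expected to raise an error. Using a batch size of at least 2, or eval mode as above, avoids that.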
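The fifth part is visible in forward(). It first applies global average pooling, then a 1x1 convolution to adjust the channel count, followed by batch normalization and the activation function. The result is then upsampled to the same spatial size as the other four branches, so that all five parts can be stacked with torch.cat(). Finally, a 1x1 convolution adjusts the number of channels of the concatenated feature layer, producing the green part in the figure above. This effective feature layer, which carries high-level semantic information, is then passed to the Decoder.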
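The two less obvious steps of the fifth branch, the double torch.mean and the upsampling, can be checked in isolation. This is a minimal sketch on a random tensor of assumed shape 2x320x32x32; it deliberately leaves out the 1x1 convolution, batch normalization, and ReLU that sit between the two steps in the real module.

```python
import torch
import torch.nn.functional as F

# Hypothetical deep feature map: (batch, channels, rows, cols).
x = torch.randn(2, 320, 32, 32)
row, col = x.shape[2:]

# Averaging over rows, then over columns, is a global average pool.
pooled = torch.mean(torch.mean(x, 2, True), 3, True)
assert pooled.shape == (2, 320, 1, 1)
assert torch.allclose(pooled, F.adaptive_avg_pool2d(x, 1), atol=1e-6)

# Bilinear upsampling of a 1x1 map has only one source pixel to draw from,
# so every output position receives the same per-channel value.
upsampled = F.interpolate(pooled, size=(row, col), mode="bilinear", align_corners=True)
assert upsampled.shape == (2, 320, 32, 32)
assert torch.allclose(upsampled, pooled.expand(-1, -1, row, col), atol=1e-6)
```

In other words, the fifth branch contributes one image-level descriptor per channel, repeated at every spatial position, alongside the four locally computed branches.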
Reposted from: https://blog.csdn.net/m0_62919535/article/details/128876615