小言_互联网's Blog

Object Detection Algorithms: Improving YOLOv5/YOLOv7 with CBAM


Contents

(1) Introduction

(2) Improving YOLOv5/YOLOv7 with CBAM

1. Configure the common.py file

2. Configure the yolo.py file

3. Configure the yolov7_CBAM.yaml file


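(1) Introduction

Paper title: "CBAM: Convolutional Block Attention Module"
Paper link: https://arxiv.org/pdf/1807.06521.pdf

How the CBAM attention structure works: as the architecture figure in the paper shows, CBAM consists of two independent sub-modules, the Channel Attention Module (CAM) and the Spatial Attention Module (SAM), which apply attention along the channel and spatial dimensions respectively. This keeps the added parameters and computation small and lets CBAM be integrated into existing network architectures as a plug-and-play module. The author reports that, in experiments, embedding the CBAM attention module into the YOLOv7 network helps address the original network's lack of attention preference.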
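To make the parameter-saving claim more concrete, here is a minimal sketch that counts the weights the two attention sub-modules add. It assumes the bias-free layout used in this post's code (two 1x1 convolutions with reduction ratio 16 for channel attention, and one 7x7 convolution over two pooled maps for spatial attention). The function names are illustrative only and do not come from any library.

```python
def channel_attention_params(channels, ratio=16):
    """Weights in the shared bottleneck C -> C/ratio -> C (no biases)."""
    hidden = channels // ratio
    return channels * hidden + hidden * channels


def spatial_attention_params(kernel_size=7):
    """Weights in one conv mapping 2 pooled maps to 1 attention map (no bias)."""
    return 2 * 1 * kernel_size * kernel_size


def conv_params(c_in, c_out, kernel_size):
    """Weights in a plain bias-free convolution, used here for comparison."""
    return c_in * c_out * kernel_size * kernel_size


attention = channel_attention_params(64) + spatial_attention_params(7)
baseline = conv_params(32, 64, 3)
print(attention, baseline)  # 610 18432
```

Under these assumptions, the attention part of a 64-channel CBAM layer adds a few hundred weights on top of a 3x3 convolution that already has more than eighteen thousand.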

(2) Improving YOLOv5/YOLOv7 with CBAM

The modification is essentially the same as for YOLOv5 and takes three steps:

1. Configure the common.py file

Add the CBAM code:


  
    import torch
    import torch.nn as nn

    # Intended for models/common.py, where autopad() is already defined.


    class ChannelAttention(nn.Module):
        # Channel attention: shared 1x1-conv bottleneck over avg- and max-pooled features
        def __init__(self, in_planes, ratio=16):
            super(ChannelAttention, self).__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.max_pool = nn.AdaptiveMaxPool2d(1)
            self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
            self.relu = nn.ReLU()
            self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
            self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
            max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
            out = self.sigmoid(avg_out + max_out)
            return out  # per-channel weights, shape (B, C, 1, 1)


    class SpatialAttention(nn.Module):
        # Spatial attention: one conv over the channel-wise mean and max maps
        def __init__(self, kernel_size=7):
            super(SpatialAttention, self).__init__()
            assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
            padding = 3 if kernel_size == 7 else 1
            self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
            self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            avg_out = torch.mean(x, dim=1, keepdim=True)
            max_out, _ = torch.max(x, dim=1, keepdim=True)
            x = torch.cat([avg_out, max_out], dim=1)
            x = self.conv(x)
            return self.sigmoid(x)  # per-position weights, shape (B, 1, H, W)


    class CBAM(nn.Module):
        # Standard convolution (Conv + BN + activation) followed by channel and spatial attention
        def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
            super(CBAM, self).__init__()
            self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
            self.bn = nn.BatchNorm2d(c2)
            self.act = nn.Hardswish() if act else nn.Identity()
            self.ca = ChannelAttention(c2)
            self.sa = SpatialAttention()

        def forward(self, x):
            # Assumed completion: the listing stops after __init__. Following the CBAM paper,
            # channel attention is applied first, then spatial attention.
            x = self.act(self.bn(self.conv(x)))
            x = self.ca(x) * x
            x = self.sa(x) * x
            return x
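The CBAM class calls autopad(k, p), which comes from the YOLOv5/YOLOv7 common.py rather than from the CBAM listing. As a rough sketch of what it does for integer kernel sizes, and of why a stride-2 CBAM layer halves the feature map, consider the following. Note that conv_out_size is a helper introduced here purely for illustration, and the repository version of autopad also accepts list-valued kernels.

```python
def autopad(k, p=None):
    """Return 'same'-style padding k // 2 unless p is given explicitly."""
    return k // 2 if p is None else p


def conv_out_size(n, k, s, p=None):
    """Standard convolution output length along one spatial dimension."""
    p = autopad(k, p)
    return (n + 2 * p - k) // s + 1


print(conv_out_size(640, k=3, s=1))  # 640
print(conv_out_size(640, k=3, s=2))  # 320
```

With a 3x3 kernel the automatic padding is 1, so stride 1 preserves the resolution and stride 2 halves it, which matches the P1/2 and P2/4 annotations on the CBAM rows of the YAML file.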

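2. Configure the yolo.py file

Add the CBAM module here as well, so that the model builder can construct it from the YAML file.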
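In the YOLOv5/YOLOv7 code base, parse_model in yolo.py prepends the input channel count to the arguments of the modules it knows about. The sketch below is a simplified imitation of that behaviour, not the repository code: CHANNEL_AWARE and resolve_args are hypothetical names, and details such as the special handling of the detection output width are left out. It shows why CBAM has to be treated like Conv.

```python
import math


def make_divisible(x, divisor):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


# Hypothetical stand-in for the list of modules that parse_model treats like Conv.
CHANNEL_AWARE = {"Conv", "RepConv", "SPPCSPC", "CBAM"}


def resolve_args(name, args, c_in, width_multiple=1.0, registered=CHANNEL_AWARE):
    """Build constructor arguments for one YAML row [from, number, module, args]."""
    if name in registered:
        c_out = make_divisible(args[0] * width_multiple, 8)
        return [c_in, c_out, *args[1:]]
    return list(args)  # unregistered modules receive the YAML args unchanged


print(resolve_args("CBAM", [64, 3, 2], c_in=32))
# [32, 64, 3, 2]  -> CBAM(c1=32, c2=64, k=3, s=2)
print(resolve_args("CBAM", [64, 3, 2], c_in=32, registered={"Conv"}))
# [64, 3, 2]      -> would be read as CBAM(c1=64, c2=3, k=2), which is wrong
```

If CBAM is left out of that list, the first YAML argument is interpreted as the input channel count, so the layer is built with the wrong shape.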

3. Configure the yolov7_CBAM.yaml file

CBAM can be placed flexibly, in either the Backbone or the Neck. One example:


  
    # The top-level keys nc, depth_multiple and width_multiple, which yolo.py also reads,
    # are not shown in this listing.

    # anchors
    anchors:
      - [12,16, 19,36, 40,28]  # P3/8
      - [36,75, 76,55, 72,146]  # P4/16
      - [142,110, 192,243, 459,401]  # P5/32

    # yolov7 backbone
    backbone:
      # [from, number, module, args]
      [[-1, 1, Conv, [32, 3, 1]],  # 0
       [-1, 1, CBAM, [64, 3, 2]],  # 1-P1/2
       [-1, 1, Conv, [64, 3, 1]],
       [-1, 1, CBAM, [128, 3, 2]],  # 3-P2/4
       [-1, 1, Conv, [64, 1, 1]],
       [-2, 1, Conv, [64, 1, 1]],
       [-1, 1, Conv, [64, 3, 1]],
       [-1, 1, Conv, [64, 3, 1]],
       [-1, 1, Conv, [64, 3, 1]],
       [-1, 1, Conv, [64, 3, 1]],
       [[-1, -3, -5, -6], 1, Concat, [1]],
       [-1, 1, Conv, [256, 1, 1]],  # 11
       [-1, 1, MP, []],
       [-1, 1, Conv, [128, 1, 1]],
       [-3, 1, Conv, [128, 1, 1]],
       [-1, 1, Conv, [128, 3, 2]],
       [[-1, -3], 1, Concat, [1]],  # 16-P3/8
       [-1, 1, Conv, [128, 1, 1]],
       [-2, 1, Conv, [128, 1, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [[-1, -3, -5, -6], 1, Concat, [1]],
       [-1, 1, Conv, [512, 1, 1]],  # 24
       [-1, 1, MP, []],
       [-1, 1, Conv, [256, 1, 1]],
       [-3, 1, Conv, [256, 1, 1]],
       [-1, 1, Conv, [256, 3, 2]],
       [[-1, -3], 1, Concat, [1]],  # 29-P4/16
       [-1, 1, Conv, [256, 1, 1]],
       [-2, 1, Conv, [256, 1, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [[-1, -3, -5, -6], 1, Concat, [1]],
       [-1, 1, Conv, [1024, 1, 1]],  # 37
       [-1, 1, MP, []],
       [-1, 1, Conv, [512, 1, 1]],
       [-3, 1, Conv, [512, 1, 1]],
       [-1, 1, Conv, [512, 3, 2]],
       [[-1, -3], 1, Concat, [1]],  # 42-P5/32
       [-1, 1, Conv, [256, 1, 1]],
       [-2, 1, Conv, [256, 1, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [[-1, -3, -5, -6], 1, Concat, [1]],
       [-1, 1, Conv, [1024, 1, 1]],  # 50
      ]

    # yolov7 head
    head:
      [[-1, 1, SPPCSPC, [512]],  # 51
       [-1, 1, Conv, [256, 1, 1]],
       [-1, 1, nn.Upsample, [None, 2, 'nearest']],
       [37, 1, Conv, [256, 1, 1]],  # route backbone P4
       [[-1, -2], 1, Concat, [1]],
       [-1, 1, Conv, [256, 1, 1]],
       [-2, 1, Conv, [256, 1, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
       [-1, 1, Conv, [256, 1, 1]],  # 63
       [-1, 1, Conv, [128, 1, 1]],
       [-1, 1, nn.Upsample, [None, 2, 'nearest']],
       [24, 1, Conv, [128, 1, 1]],  # route backbone P3
       [[-1, -2], 1, Concat, [1]],
       [-1, 1, Conv, [128, 1, 1]],
       [-2, 1, Conv, [128, 1, 1]],
       [-1, 1, Conv, [64, 3, 1]],
       [-1, 1, Conv, [64, 3, 1]],
       [-1, 1, Conv, [64, 3, 1]],
       [-1, 1, Conv, [64, 3, 1]],
       [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
       [-1, 1, Conv, [128, 1, 1]],  # 75
       [-1, 1, MP, []],
       [-1, 1, Conv, [128, 1, 1]],
       [-3, 1, Conv, [128, 1, 1]],
       [-1, 1, Conv, [128, 3, 2]],
       [[-1, -3, 63], 1, Concat, [1]],
       [-1, 1, Conv, [256, 1, 1]],
       [-2, 1, Conv, [256, 1, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [-1, 1, Conv, [128, 3, 1]],
       [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
       [-1, 1, Conv, [256, 1, 1]],  # 88
       [-1, 1, MP, []],
       [-1, 1, Conv, [256, 1, 1]],
       [-3, 1, Conv, [256, 1, 1]],
       [-1, 1, Conv, [256, 3, 2]],
       [[-1, -3, 51], 1, Concat, [1]],
       [-1, 1, Conv, [512, 1, 1]],
       [-2, 1, Conv, [512, 1, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [-1, 1, Conv, [256, 3, 1]],
       [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
       [-1, 1, Conv, [512, 1, 1]],  # 101
       [75, 1, RepConv, [256, 3, 1]],
       [88, 1, RepConv, [512, 3, 1]],
       [101, 1, RepConv, [1024, 3, 1]],
       [[102,103,104], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
      ]
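To see how the from column and the Concat rows fit together, the following minimal sketch traces output channels through the first twelve backbone rows of the configuration. It is an illustration only, not the YOLOv7 parser: trace_channels is a hypothetical helper, module names are plain strings, and a width multiple of 1.0 is assumed so that channel counts are used as written.

```python
# Rows copied from the start of the backbone: [from, number, module, args]
rows = [
    [-1, 1, "Conv", [32, 3, 1]],    # 0
    [-1, 1, "CBAM", [64, 3, 2]],    # 1
    [-1, 1, "Conv", [64, 3, 1]],    # 2
    [-1, 1, "CBAM", [128, 3, 2]],   # 3
    [-1, 1, "Conv", [64, 1, 1]],    # 4
    [-2, 1, "Conv", [64, 1, 1]],    # 5
    [-1, 1, "Conv", [64, 3, 1]],    # 6
    [-1, 1, "Conv", [64, 3, 1]],    # 7
    [-1, 1, "Conv", [64, 3, 1]],    # 8
    [-1, 1, "Conv", [64, 3, 1]],    # 9
    [[-1, -3, -5, -6], 1, "Concat", [1]],  # 10
    [-1, 1, "Conv", [256, 1, 1]],   # 11
]


def trace_channels(rows, c_input=3):
    """Return the output channel count of every row."""
    out = []
    for i, (frm, _number, module, args) in enumerate(rows):
        sources = frm if isinstance(frm, list) else [frm]
        # Negative indices are relative to the current row, others are absolute.
        src_idx = [i + f if f < 0 else f for f in sources]
        c_in = [c_input if j < 0 else out[j] for j in src_idx]
        if module == "Concat":
            c_out = sum(c_in)   # concatenation along the channel axis
        elif module == "MP":
            c_out = c_in[0]     # pooling keeps the channel count
        else:
            c_out = args[0]     # Conv-like rows: first arg is the output width
        out.append(c_out)
    return out


print(trace_channels(rows))
# [32, 64, 64, 128, 64, 64, 64, 64, 64, 64, 256, 256]
```

Row 10 concatenates rows 9, 7, 5 and 4, each with 64 channels, which gives the 256 channels that row 11 receives. Swapping Conv for CBAM in rows 1 and 3 leaves these channel counts unchanged.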



Reposted from: https://blog.csdn.net/m0_53578855/article/details/127585358