飞道的博客

YOLOv5小目标切图检测

762人阅读  评论(0)

在检测高分辨率图片时,小目标的检测效果往往较差,常见的改进方法有以下几种:

  1. 使用更大的输入尺寸(高分辨率)进行训练(想法:显存不够,搞不来)
  2. 添加小目标检测头(想法:P5模型还有点用,P6模型完全没用)
  3. 添加一些检测模型和玄学机制(想法:你要是写论文就去看看知*吧,只需要在最后面加一句:已达到工业检测要求)
  4. 切图检测(想法:比较耗时,过程也比较繁琐,可以尝试)

切图检测

思路:

  1. 将原图切成你想要的数量
  2. 将切成的小图进行训练,得到模型
  3. 将你需要检测的图片切成小图,用模型检测,得到每张小图中目标位置的信息,并保存到与该小图同名的txt文件中
  4. 将属于同一张原图的所有小图txt文件融合成1个txt文件(坐标换算回原图),并在原图上显示检测框。注意:融合代码按文件名 `原图名_序号.txt` 解析小图位置,序号需从0开始、按行优先排列,且融合时的 rows/cols 要与切图时的行列数一致

一:切块


  
  1. # -*- coding:utf-8 -*-
  2. import os
  3. import matplotlib.pyplot as plt
  4. import cv2
  5. import numpy as np
  6. def divide_img( img_path, img_name, save_path):
  7. imgg = img_path + img_name
  8. img = cv2.imread(imgg)
  9. # img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
  10. h = img.shape[ 0]
  11. w = img.shape[ 1]
  12. n = int(np.floor(h * 1.0 / 1000)) + 1
  13. m = int(np.floor(w * 1.0 / 1000)) + 1
  14. print( 'h={},w={},n={},m={}'. format(h, w, n, m))
  15. dis_h = int(np.floor(h / n))
  16. dis_w = int(np.floor(w / m))
  17. num = 0
  18. for i in range(n):
  19. for j in range(m):
  20. num += 1
  21. print( 'i,j={}{}'. format(i, j))
  22. sub = img[dis_h * i:dis_h * (i + 1), dis_w * j:dis_w * (j + 1), :]
  23. cv2.imwrite(save_path + '{}_{}.bmp'. format(name, num), sub)
  24. if __name__ == '__main__':
  25. img_path = r'G:\1/'
  26. save_path = r'G:\3/'
  27. img_list = os.listdir(img_path)
  28. for name in img_list:
  29. divide_img(img_path, name, save_path)

 

 使用模型对小图检测后,每张小图会得到一个同名的txt标签文件(YOLO格式,每行为:类别 x_center y_center w h,坐标均已归一化):

二:融合txt文件


  
  1. import os
  2. from cv2 import cv2
  3. # 保存所有图片的宽高
  4. # todo: img_info={'name': [w_h, child_w_h, mix_row_w_h, mix_col_w_h]}
  5. img_info = {}
  6. all_info = {}
  7. # 初始化img_info
  8. def init( big_images_path, mix_percent, rows, cols):
  9. image_names = os.listdir(big_images_path)
  10. for img_name in image_names:
  11. big_path = big_images_path + '\\' + img_name
  12. # print(big_path)
  13. img = cv2.imread(big_path)
  14. size = img.shape[ 0: 2]
  15. w = size[ 1]
  16. h = size[ 0]
  17. child_width = int(w) // cols
  18. child_height = int(h) // rows
  19. mix_row_width = int(child_width * mix_percent * 2)
  20. mix_row_height = child_height
  21. mix_col_width = child_width
  22. mix_col_height = int(child_height * mix_percent * 2)
  23. # 根据img保存w和h
  24. img_info[img_name.split( '.')[ 0]] = [w, h, child_width, child_height, mix_row_width, mix_row_height,
  25. mix_col_width, mix_col_height]
  26. # 读取所有检测出来的 小图片的label
  27. def get_label_info( labels_path, mix_percent, rows, cols):
  28. labels = os.listdir(labels_path)
  29. for label in labels:
  30. # print(label)
  31. # todo: type: 0正常, 1row, 2col
  32. # 判断该label属于哪一张图片
  33. cur_label_belong = label.split( '_')[ 0]
  34. cur_big_width = img_info[cur_label_belong][ 0]
  35. cur_big_height = img_info[cur_label_belong][ 1]
  36. # 融合区域距离边界的一小部分宽高
  37. cur_row_width_step = img_info[cur_label_belong][ 2] * ( 1 - mix_percent)
  38. cur_col_height_step = img_info[cur_label_belong][ 3] * ( 1 - mix_percent)
  39. # 文件名给予数据
  40. # child_type = []
  41. # child_num = []
  42. # label内容给予数据
  43. child_class_index = []
  44. child_x = []
  45. child_y = []
  46. child_width = []
  47. child_height = []
  48. type = - 1
  49. num = - 1
  50. class_index = - 1
  51. x = 0.0
  52. y = 0.0
  53. width = 0.0
  54. height = 0.0
  55. # print(f'{label}')
  56. # 读取所有需要的数据
  57. f = open(labels_path + '\\' + label, 'r')
  58. lines = f.read()
  59. # print(lines)
  60. f.close()
  61. contents = lines.split( '\n')[:- 1]
  62. # print(contents)
  63. for content in contents:
  64. content = content.split( ' ')
  65. # print(content)
  66. class_index = int(content[ 0])
  67. x = float(content[ 1])
  68. y = float(content[ 2])
  69. width = float(content[ 3])
  70. height = float(content[ 4])
  71. pass
  72. # print(class_index, x, y, width, height)
  73. assert class_index != - 1 or x != - 1.0 or y != - 1.0 or width != - 1.0 or height != - 1.0, \
  74. f'class_index:{class_index}, x:{x}, y:{y}, width:{width}, height:{height}'
  75. # 转换成 数据 坐标, 并根据不同的num进行处理
  76. num = label.split( '_')[- 1].split( '.')[ 0] # 图片尾号 命名: xxxx_x.jpg xxxx_mix_row_xx.jpg xxxx_mix_col_xx.jpg
  77. cur_img_width = 0
  78. cur_img_height = 0
  79. distance_x = 0
  80. distance_y = 0
  81. small_image_width = img_info[cur_label_belong][ 2]
  82. small_image_height = img_info[cur_label_belong][ 3]
  83. if label.find( 'mix_row') != - 1:
  84. # type = 1.
  85. distance_x = int(num) % (cols- 1)
  86. distance_y = int(num) // (rows- 1)
  87. cur_img_width = img_info[cur_label_belong][ 4]
  88. cur_img_height = img_info[cur_label_belong][ 5]
  89. # row x 加上step
  90. x = x * cur_img_width + cur_row_width_step + distance_x * small_image_width
  91. y = y * cur_img_height + distance_y * cur_img_height
  92. elif label.find( 'mix_col') != - 1:
  93. # type = 2
  94. distance_x = int(num) % cols
  95. distance_y = int(num) // rows
  96. cur_img_width = img_info[cur_label_belong][ 6]
  97. cur_img_height = img_info[cur_label_belong][ 7]
  98. # col y 加上step
  99. print( f'x:{x}, y:{y}, cur_img_width:{cur_img_width}, cur_img_height:{cur_img_height}')
  100. x = x * cur_img_width + distance_x * cur_img_width
  101. y = y * cur_img_height + cur_col_height_step + distance_y * small_image_height
  102. print( f'x:{x}, y:{y}, height:{cur_col_height_step}')
  103. else:
  104. # type = 0
  105. distance_x = int(num) % cols
  106. distance_y = int(num) // rows
  107. cur_img_width = img_info[cur_label_belong][ 2]
  108. cur_img_height = img_info[cur_label_belong][ 3]
  109. # 小图片内, 无需加上 step
  110. x = x * cur_img_width + distance_x * cur_img_width
  111. y = y * cur_img_height + distance_y * cur_img_height
  112. assert cur_img_width != 0 or cur_img_height != 0 or distance_x != 0 or distance_y != 0, \
  113. f'cur_img_width:{cur_img_width}, cur_img_height:{cur_img_height}, distance_x:{distance_x}, distance_y:{distance_y}'
  114. assert x < cur_big_width and y < cur_big_height, f'{label}, {content}\nw:{cur_big_width}, h:{cur_big_height}, x:{x}, y:{y}'
  115. width = width * cur_img_width
  116. height = height * cur_img_height
  117. assert x != 0.0 or y != 0.0 or width != 0.0 or height != 0.0, f'x:{x}, y:{y}, width:{width}, height:{height}'
  118. # child_type.append(type)
  119. # child_num.append(num)
  120. child_class_index.append(class_index)
  121. child_x.append(x)
  122. child_y.append(y)
  123. child_width.append(width)
  124. child_height.append(height)
  125. # todo: 所有信息 根据 cur_label_belong 存储在all_info中
  126. for index, x, y, width, height in zip(child_class_index, child_x, child_y, child_width, child_height):
  127. if cur_label_belong not in all_info:
  128. all_info[cur_label_belong] = [[index, x, y, width, height]]
  129. else:
  130. all_info[cur_label_belong].append([index, x, y, width, height])
  131. child_class_index.clear()
  132. child_x.clear()
  133. child_y.clear()
  134. child_width.clear()
  135. child_height.clear()
  136. # print((all_info['0342']))
  137. # todo: 转成 yolo 格式, 保存
  138. def save_yolo_label( yolo_labels_path):
  139. for key in all_info:
  140. # img_path = r'G:\Unity\code_project\other_project\data\joint\big_images' + '\\' + key + '.JPG'
  141. # img = cv2.imread(img_path)
  142. yolo_label_path = yolo_labels_path + '\\' + key + '.txt'
  143. cur_big_width = img_info[key][ 0]
  144. cur_big_height = img_info[key][ 1]
  145. content = ''
  146. i = 0
  147. for index, x, y, width, height in all_info[key]:
  148. # print(all_info[key][i])
  149. x = x / cur_big_width
  150. y = y / cur_big_height
  151. width = width / cur_big_width
  152. height = height / cur_big_height
  153. assert x < 1.0 and y < 1.0 and width < 1.0 and height < 1.0, f'{key} {i}\n{all_info[key][i]}\nx:{x}, y:{y}, width:{width}, height:{height}'
  154. content += f'{index} {x} {y} {width} {height}\n'
  155. i += 1
  156. with open(yolo_label_path, 'w') as f:
  157. f.write(content)
  158. def joint_main( big_images_path=r'G:\3',
  159. labels_path=r'G:\5',
  160. yolo_labels_path=r'G:\6',
  161. mix_percent=0.2,
  162. rows=4,
  163. cols=4):
  164. print( f'融合图片, 原图片路径:{big_images_path}\n小图检测的txt结果路径:{labels_path}\n数据融合后txt结果路径:{yolo_labels_path}')
  165. init(big_images_path, mix_percent, rows, cols)
  166. get_label_info(labels_path, mix_percent, rows, cols)
  167. save_yolo_label(yolo_labels_path)
  168. joint_main()

三:原图显示


  
  1. # -*- coding: utf-8 -*-
  2. import os
  3. from PIL import Image
  4. from PIL import ImageDraw, ImageFont
  5. from cv2 import cv2
  6. def draw_images( images_dir, txt_dir, box_dir, font_type_path):
  7. font = ImageFont.truetype(font_type_path, 50)
  8. if not os.path.exists(box_dir):
  9. os.makedirs(box_dir)
  10. # num = 0
  11. # 设置颜色
  12. all_colors = [ 'red', 'green', 'yellow', 'blue', 'pink', 'black', 'skyblue', 'brown', 'orange', 'purple', 'gray',
  13. 'lightpink', 'gold', 'brown', 'black']
  14. colors = {}
  15. for file in os.listdir(txt_dir):
  16. print(file)
  17. image = os.path.splitext(file)[ 0].replace( 'xml', 'bmp') + '.bmp'
  18. # 转换成cv2读取,防止图片载入错误
  19. img = cv2.imread(images_dir + '/' + image)
  20. TURN = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  21. img = Image.fromarray(TURN)
  22. # img.show()
  23. if img.mode == "P":
  24. img = img.convert( 'RGB')
  25. w, h = img.size
  26. tag_path = txt_dir + '/' + file
  27. with open(tag_path) as f:
  28. for line in f:
  29. line_parts = line.split( ' ')
  30. # 根据不同的 label 保存颜色
  31. if line_parts[ 0] not in colors.keys():
  32. colors[line_parts[ 0]] = all_colors[ len(colors.keys())]
  33. color = colors[line_parts[ 0]]
  34. draw = ImageDraw.Draw(img)
  35. x = ( float(line_parts[ 1]) - 0.5 * float(line_parts[ 3])) * w
  36. y = ( float(line_parts[ 2]) - 0.5 * float(line_parts[ 4])) * h
  37. xx = ( float(line_parts[ 1]) + 0.5 * float(line_parts[ 3])) * w
  38. yy = ( float(line_parts[ 2]) + 0.5 * float(line_parts[ 4])) * h
  39. draw.rectangle([x - 10, y - 10, xx, yy], fill= None, outline=color, width= 5)
  40. # num += 1
  41. del draw
  42. img.save(box_dir + '/' + image)
  43. # print(file, num)
  44. # print(colors)
  45. def draw_main( box_dir=r'G:\5',
  46. txt_dir=r'G:\6',
  47. image_source_dir=r'G:\3'):
  48. font_type_path = 'C:/Windows/Fonts/simsun.ttc'
  49. print( f'标注框, 数据来源: {txt_dir}\n 被标注图片: {image_source_dir}\n 结果保存路径: {box_dir}')
  50. draw_images(image_source_dir, txt_dir, box_dir, font_type_path)
  51. draw_main()

 

效果对比:(左YOLOv5检测,右YOLOv5+切图检测)

 

参考:

https://blog.csdn.net/qq_43622870/article/details/124984295?ops_request_misc=&request_id=&biz_id=102&utm_term=yolov5%E5%B0%8F%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-0-124984295.142^v68^control,201^v4^add_ask,213^v2^t3_control2&spm=1018.2226.3001.4187


转载:https://blog.csdn.net/qq_58355216/article/details/128318604
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场