飞道的博客

损失函数:DIOU loss手写实现

553人阅读  评论(0)

下面是纯diou代码


  
  1. '''
  2. 计算两个box的中心点距离d
  3. '''
  4. # d = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
  5. d = math.sqrt((pred[:, - 1] - target[:, - 1]) ** 2 + (pred[:, - 2] - target[:, - 2]) ** 2)
  6. # 左边x
  7. pred_l = pred[:, - 1] - pred[:, - 1] / 2
  8. target_l = target[:, - 1] - target[:, - 1] / 2
  9. # 上边y
  10. pred_t = pred[:, - 2] - pred[:, - 2] / 2
  11. target_t = target[:, - 2] - target[:, - 2] / 2
  12. # 右边x
  13. pred_r = pred[:, - 1] + pred[:, - 1] / 2
  14. target_r = target[:, - 1] + target[:, - 1] / 2
  15. # 下边y
  16. pred_b = pred[:, - 2] + pred[:, - 2] / 2
  17. target_b = target[:, - 2] + target[:, - 2] / 2
  18. '''
  19. 计算两个box的bound的对角线距离
  20. '''
  21. bound_l = torch. min(pred_l, target_l) # left
  22. bound_r = torch. max(pred_r, target_r) # right
  23. bound_t = torch. min(pred_t, target_t) # top
  24. bound_b = torch. max(pred_b, target_b) # bottom
  25. c = math.sqrt((bound_r - bound_l) ** 2 + (bound_b - bound_t) ** 2)
  26. dloss = iou - (d ** 2) / (c ** 2)
  27. loss = 1 - dloss.clamp( min=- 1.0, max= 1.0)

第一步 计算两个box的中心点距离d

首先要知道pred和target的输出结果是什么
pred[:,:2]第一个:表示多个图片,第二个:2表示前两个数值,代表矩形框中心点(Y,X)
pred[:,2:]第一个:表示多个图片,第二个2:表示两个数值,代表矩形框长宽(H,W)
target[:,:2]同理,
d = √((x₂ − x₁)² + (y₂ − y₁)²),即两个矩形框中心点之间的欧式距离

根据上面的分析来计算左右上下坐标lrtb

 然后计算内部2个矩形的最小外接矩形的对角线长度c

 d是两个预测矩形中心点的距离

 下面介绍各种极端情况
A 两个框中心对齐时候,d/c=0,iou可能0-1

 B 两个框相距很远时,d/c趋近于1,iou=0

 所以d/c属于0-1
dloss=iou-d/c属于-1到1
因此设置loss=1-dloss属于0-2

 

展示iou\giou\diou代码,这是YOLOX自带的损失函数,其中dloss是我自己写的
YOLOX是下载自
GitHub - Megvii-BaseDetection/YOLOX: YOLOX is a high-performance anchor-free YOLO, exceeding yolov3~v5 with MegEngine, ONNX, TensorRT, ncnn, and OpenVINO supported. Documentation: https://yolox.readthedocs.io/YOLOX is a high-performance anchor-free YOLO, exceeding yolov3~v5 with MegEngine, ONNX, TensorRT, ncnn, and OpenVINO supported. Documentation: https://yolox.readthedocs.io/ - GitHub - Megvii-BaseDetection/YOLOX: YOLOX is a high-performance anchor-free YOLO, exceeding yolov3~v5 with MegEngine, ONNX, TensorRT, ncnn, and OpenVINO supported. Documentation: https://yolox.readthedocs.io/https://github.com/Megvii-BaseDetection/YOLOX


  
  1. class IOUloss(nn.Module):
  2. def __init__( self, reduction="none", loss_type="iou"):
  3. super(IOUloss, self).__init__()
  4. self.reduction = reduction
  5. self.loss_type = loss_type
  6. def forward( self, pred, target):
  7. assert pred.shape[ 0] == target.shape[ 0]
  8. pred = pred.view(- 1, 4)
  9. target = target.view(- 1, 4)
  10. tl = torch. max(
  11. (pred[:, : 2] - pred[:, 2:] / 2), (target[:, : 2] - target[:, 2:] / 2)
  12. )
  13. # pred target都是[H,W,Y,X]
  14. # (Y,X)-(H,W) 左上角
  15. br = torch. min(
  16. (pred[:, : 2] + pred[:, 2:] / 2), (target[:, : 2] + target[:, 2:] / 2)
  17. )
  18. # (X,Y)+(H,W) 右下角
  19. area_p = torch.prod(pred[:, 2:], 1) # HxW
  20. area_g = torch.prod(target[:, 2:], 1)
  21. en = (tl < br). type(tl. type()).prod(dim= 1)
  22. area_i = torch.prod(br - tl, 1) * en
  23. area_u = area_p + area_g - area_i
  24. iou = (area_i) / (area_u + 1e-16)
  25. if self.loss_type == "iou":
  26. loss = 1 - iou ** 2
  27. elif self.loss_type == "giou":
  28. c_tl = torch. min(
  29. (pred[:, : 2] - pred[:, 2:] / 2), (target[:, : 2] - target[:, 2:] / 2)
  30. )
  31. c_br = torch. max(
  32. (pred[:, : 2] + pred[:, 2:] / 2), (target[:, : 2] + target[:, 2:] / 2)
  33. )
  34. area_c = torch.prod(c_br - c_tl, 1)
  35. giou = iou - (area_c - area_u) / area_c.clamp( 1e-16)
  36. loss = 1 - giou.clamp( min=- 1.0, max= 1.0)
  37. # pred[:, :2] pred[:, 2:]
  38. # (Y,X) (H,W)
  39. # target[:, :2] target[:, 2:]
  40. # (Y,X) (H,W)
  41. elif self.loss_type == "diou":
  42. '''
  43. 计算两个box的中心点距离d
  44. '''
  45. # d = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
  46. d = math.sqrt((pred[:, - 1] - target[:, - 1]) ** 2 + (pred[:, - 2] - target[:, - 2]) ** 2)
  47. # 左边x
  48. pred_l = pred[:, - 1] - pred[:, - 1] / 2
  49. target_l = target[:, - 1] - target[:, - 1] / 2
  50. # 上边y
  51. pred_t = pred[:, - 2] - pred[:, - 2] / 2
  52. target_t = target[:, - 2] - target[:, - 2] / 2
  53. # 右边x
  54. pred_r = pred[:, - 1] + pred[:, - 1] / 2
  55. target_r = target[:, - 1] + target[:, - 1] / 2
  56. # 下边y
  57. pred_b = pred[:, - 2] + pred[:, - 2] / 2
  58. target_b = target[:, - 2] + target[:, - 2] / 2
  59. '''
  60. 计算两个box的bound的对角线距离
  61. '''
  62. bound_l = torch. min(pred_l, target_l) # left
  63. bound_r = torch. max(pred_r, target_r) # right
  64. bound_t = torch. min(pred_t, target_t) # top
  65. bound_b = torch. max(pred_b, target_b) # bottom
  66. c = math.sqrt((bound_r - bound_l) ** 2 + (bound_b - bound_t) ** 2)
  67. dloss = iou - (d ** 2) / (c ** 2)
  68. loss = 1 - dloss.clamp( min=- 1.0, max= 1.0)
  69. # Step1
  70. # def DIoU(a, b):
  71. # d = a.center_distance(b)
  72. # c = a.bound_diagonal_distance(b)
  73. # return IoU(a, b) - (d ** 2) / (c ** 2)
  74. # Step2-1
  75. # def center_distance(self, other):
  76. # '''
  77. # 计算两个box的中心点距离
  78. # '''
  79. # return euclidean_distance(self.center, other.center)
  80. # Step2-2
  81. # def euclidean_distance(p1, p2):
  82. # '''
  83. # 计算两个点的欧式距离
  84. # '''
  85. # x1, y1 = p1
  86. # x2, y2 = p2
  87. # return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
  88. # Step3
  89. # def bound_diagonal_distance(self, other):
  90. # '''
  91. # 计算两个box的bound的对角线距离
  92. # '''
  93. # bound = self.boundof(other)
  94. # return euclidean_distance((bound.x, bound.y), (bound.r, bound.b))
  95. # Step3-2
  96. # def boundof(self, other):
  97. # '''
  98. # 计算box和other的边缘外包框,使得2个box都在框内的最小矩形
  99. # '''
  100. # xmin = min(self.x, other.x)
  101. # ymin = min(self.y, other.y)
  102. # xmax = max(self.r, other.r)
  103. # ymax = max(self.b, other.b)
  104. # return BBox(xmin, ymin, xmax, ymax)
  105. # Step3-3
  106. # def euclidean_distance(p1, p2):
  107. # '''
  108. # 计算两个点的欧式距离
  109. # '''
  110. # x1, y1 = p1
  111. # x2, y2 = p2
  112. # return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
  113. if self.reduction == "mean":
  114. loss = loss.mean()
  115. elif self.reduction == "sum":
  116. loss = loss. sum()
  117. return loss

GitHub - Megvii-BaseDetection/YOLOX: YOLOX is a high-performance anchor-free YOLO, exceeding yolov3~v5 with MegEngine, ONNX, TensorRT, ncnn, and OpenVINO supported. Documentation: https://yolox.readthedocs.io/


转载:https://blog.csdn.net/zjc910997316/article/details/125500138
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场