飞道的博客

pytorch几种损失函数CrossEntropyLoss、NLLLoss、BCELoss、BCEWithLogitsLoss、focal_loss、heatmap_loss

959人阅读  评论(0)

分类问题常用的几种损失,记录下来备忘,后续不断完善。

nn.CrossEntropyLoss()交叉熵损失

常用于多分类问题

CE = nn.CrossEntropyLoss()
loss = CE(input,target)

Input: (N, C) , dtype: float, N是样本数量,在批次计算时通常就是batch_size
target: (N), dtype: long,是类别号,0 ≤ targets[i] ≤ C−1
pytorch中的交叉熵损失就是softmax和NLL损失的组合,即

nn.CrossEntropyLoss()(input,target) == nn.NLLLoss()(torch.log(nn.Softmax()(input)),target)

nn.NLLLoss()

NLL = nn.NLLLoss()
loss = NLL(input,target)

Input: (N, C) , dtype: float, N是样本数量,在批次计算时通常就是batch_size
target: (N), dtype: long,是类别号,0 ≤ targets[i] ≤ C−1

nn.BCELoss() 二元交叉熵损失

常用于二分类或多标签分类

BCE = nn.BCELoss()
loss = BCE(input,target)

Input: (N, x) , dtype: float, N是样本数量,在批次计算时通常就是batch_size,x是标签数
target: (N, x), dtype: float,通常是标签的独热码形式,注意需改成float格式

nn.BCEWithLogitsLoss()

相当于BCE加上sigmoid

nn.BCEWithLogitsLoss()(input,target) == nn.BCELoss()(torch.sigmoid(input),target)

focal_loss

focal loss在pytorch中没有,它常用在目标检测问题中,公式和曲线见论文中的图:

带平衡参数的focal loss公式如下:

代码:(待后补)

heatmap_loss

heatmap_loss出现在anchor-free的目标检测网络centernet和conernet中,它在focal loss的基础上进一步改进,加入了对热点区域的损失减小的措施,以使模型输出可以较容易的收敛到检测点附件区域。(否则,必须收敛到检测点的话,难度太大,收敛速度慢)

注意,它只是在otherwise情况下多加了一个 ( 1 − Y x y c ) β (1-Y_{xyc})^\beta (1Yxyc)β 除此之外,就是focal loss


转载:https://blog.csdn.net/Brikie/article/details/116171023
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场