Pytorch的torch.nn.functional.cross_entropy的ignore_index细解

作用
ignore_index用于忽略ground-truth中某些不需要参与计算的类。假设有两类{0:背景,1:前景},若想在计算交叉熵时忽略背景(0)类,则可令ignore_index=0(同理忽略前景计算可设ignore_index=1)。

代码示例

import torch
import torch.nn.functional as F
pred = torch.Tensor(
    [
        [0.9, 0.1],
        [0.8, 0.2],
        [0.7, 0.3]
    ]
)  # shape=(N,C)=(3,2),N为样本数,C为类数
label = torch.LongTensor([1, 0, 1])  # shape=(N)=(3),3个样本的label分别为1,0,1
out = F.cross_entropy(pred, label, ignore_index=0)  # 忽略0类
print(out)


输出

tensor(1.0421)


验证
pytorch的CrossEntropy使用公式:

计算:
loss=1/2×{[−0.1+ln(e ^{0.9}+e ^{0.1} )]+[−0.3+ln(e ^{0.7}+e ^{0.3})]}= 1/2×(1.1711+0.9130)=1.0421 ​

ignore_index表示计算交叉熵时,自动忽略的标签值,example:

import torch
import torch.nn.functional as F
pred = []
pred.append([0.9, 0.1])
pred.append([0.8, 0.2])
pred = torch.Tensor(pred).view(-1,  2)

label = torch.LongTensor([[1], [-1]])  # 这里输出类别为0或1,-1表示不参与计算loss。且计算平均loss的时候,reduction只计算实际参与计算的个数,这里相当于batchsize=2,但其中第index=1行为-1不参与计算loss。

# out = F.cross_entropy(pred.view(-1, 2), label.view(-1, )) 
out = F.cross_entropy(pred.view(-1, 2), label.view(-1, ), ignore_index=-1) 
print(out)

输出结果:

tensor(1.1711)

再比如:

例如我的pred是(b,2,w,h),而label索引是(b,1,w,h)的矩阵,其中只有0,1值,0值代表从pred的第0个通道选择像素值,1值代表从pred的第1个通道选择像素值。

而此时我发现因为程序的错误,label矩阵中混入了一些-1值,这样正常的话是会报错的,因为pred矩阵没有-1通道。此时最简单的一个方法就是

loss = nn.CrossEntropyLoss(ignore_index=-1) 

上述操作就是相当于忽略-1标签值为-1的位置的对应像素值就不参与计算梯度了

torch.nn.CrossEntropyLoss 同理。

相关推荐

  1. Linux 技术深潜:top命令全方位使用教程

    2024-05-14 06:10:12       19 阅读
  2. css做一条很分割线

    2024-05-14 06:10:12       41 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-05-14 06:10:12       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-14 06:10:12       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-14 06:10:12       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-14 06:10:12       18 阅读

热门阅读

  1. MASK-RCNN自定义数据集优化思路(pytorch)

    2024-05-14 06:10:12       9 阅读
  2. ffmpeg

    2024-05-14 06:10:12       10 阅读
  3. Vue2 实现前端分页

    2024-05-14 06:10:12       8 阅读
  4. Element-UI快速入门

    2024-05-14 06:10:12       11 阅读
  5. 人大金仓参数查看和设置

    2024-05-14 06:10:12       10 阅读
  6. 记录解决问题--redis ssl连接

    2024-05-14 06:10:12       8 阅读
  7. MySQL中的多表设计

    2024-05-14 06:10:12       8 阅读
  8. 【PyTorch】torch.distributed()的含义和使用方法

    2024-05-14 06:10:12       9 阅读