神经网络的损失函数——nn.CrossEntropyLoss()

1.参数

loss_func_none = nn.CrossEntropyLoss(reduction="none")
loss_func_mean = nn.CrossEntropyLoss(reduction="mean")
loss_func_sum = nn.CrossEntropyLoss(reduction="sum")

默认是“mean”  也就是说当loss_func_none = nn.CrossEntropyLoss()时 会输出一组batch 的损失平均值

import torch
import torch.nn as nn
loss_func = nn.CrossEntropyLoss(reduction="none")

pre = torch.tensor([[0.8, 0.5, 0.2, 0.5],
                         [0.2, 0.9, 0.3, 0.2],
                         [0.4, 0.3, 0.7, 0.1],
                         [0.1, 0.2, 0.4, 0.8]], dtype=torch.float)
tgt_index = torch.tensor([0,1,2,3], dtype=torch.long)
print(loss_func(pre, tgt_index))

输出如下

import torch
import torch.nn as nn
loss_func = nn.CrossEntropyLoss()

pre = torch.tensor([[0.8, 0.5, 0.2, 0.5],
                         [0.2, 0.9, 0.3, 0.2],
                         [0.4, 0.3, 0.7, 0.1],
                         [0.1, 0.2, 0.4, 0.8]], dtype=torch.float)
tgt_index = torch.tensor([0,1,2,3], dtype=torch.long)
print(loss_func(pre, tgt_index))

输出

tgt表示样本类别的真实值,有两种表示形式,一种是类别的index,另一种是one-hot形式。

tgt_index_data = torch.tensor([0,
                               1,
                               2,
                               3], dtype=torch.long)
tgt_onehot_data = torch.tensor([[1, 0, 0, 0],
                                [0, 1, 0, 0],
                                [0, 0, 1, 0],
                                [0, 0, 0, 1]], dtype=torch.float)

 损失函数|交叉熵损失函数 (zhihu.com)

 

 2.计算过程

loss_func = nn.CrossEntropyLoss()
pre = torch.tensor([0.8, 0.5, 0.2, 0.5], dtype=torch.float)
tgt = torch.tensor([1, 0, 0, 0], dtype=torch.float)
print("手动计算:")
print("1.softmax")
print(torch.softmax(pre, dim=-1))
print("2.取对数")
print(torch.log(torch.softmax(pre, dim=-1)))
print("3.与真实值相乘")
print(-torch.sum(torch.mul(torch.log(torch.softmax(pre, dim=-1)), tgt), dim=-1))
print()
print("调用损失函数:")
print(loss_func(pre, tgt))

 

 交叉熵损失函数会自动对输入模型的预测值进行softmax。因此在多分类问题中,如果使用nn.CrossEntropyLoss(),则预测模型的输出层无需添加softmax层。

参考torch.nn.CrossEntropyLoss() 参数、计算过程以及及输入Tensor形状 - 知乎 (zhihu.com)

相关推荐

  1. pytorch-10 神经网络损失函数

    2024-04-10 06:14:05       44 阅读
  2. 神经网络损失函数(上)——回归任务

    2024-04-10 06:14:05       45 阅读
  3. 神经网络损失函数(下)——分类任务

    2024-04-10 06:14:05       45 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-10 06:14:05       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-10 06:14:05       100 阅读
  3. 在Django里面运行非项目文件

    2024-04-10 06:14:05       82 阅读
  4. Python语言-面向对象

    2024-04-10 06:14:05       91 阅读

热门阅读

  1. Android Camera API 1打开相机失败

    2024-04-10 06:14:05       32 阅读
  2. Hadoop简介

    2024-04-10 06:14:05       35 阅读
  3. 数据仓库理论与实战

    2024-04-10 06:14:05       29 阅读
  4. 高并发环境下的实现与优化策略

    2024-04-10 06:14:05       42 阅读
  5. 百度机器学习算法春招一二三面面经

    2024-04-10 06:14:05       32 阅读
  6. 基于Flask测试深度学习模型预测

    2024-04-10 06:14:05       39 阅读
  7. Vscode使用教程

    2024-04-10 06:14:05       34 阅读
  8. 【hive】单节点搭建hadoop和hive

    2024-04-10 06:14:05       32 阅读
  9. Hadoop 源码中使用ServiceLoader

    2024-04-10 06:14:05       39 阅读
  10. vscode 关键字记录

    2024-04-10 06:14:05       30 阅读
  11. Ajax、Fetch、Axios三者的区别

    2024-04-10 06:14:05       41 阅读