【深度学习基础(4)】pytorch 里的log_softmax, nll_loss, cross_entropy的关系

一、常用的函数有: log_softmax,nll_loss, cross_entropy

1.log_softmax

log_softmax就是log和softmax合并在一起执行,log_softmax=log+softmax

2. nll_loss

nll_loss函数全称是negative log likelihood loss, 函数表达式为:f(x,class)=−x[class]
例如:假设x=[5,6,9], class=1, 则f(x,class)=−x[1]=−6

3. cross_entropy交叉熵

cross_entropy=log+softmax+nll_loss

二、代码实现

import torch
import torch.nn.functional as F

preds = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.1, 0.1, 0.1, 0.1]])
target = torch.tensor([2, 3])

print('三种方式实现交叉熵损失')
print('----------------手动实现------------------------------')
one_hot = F.one_hot(target).float() # 对标签作 one_hot 编码
print('[1]one_hot编码target:\n', one_hot)
exp = torch.exp(preds)
print('[2]对网络预测preds求指数:\n', exp)
sum_ = torch.sum(exp, dim=1).reshape(-1, 1)  # 按行求和
softmax = exp / sum_  # 计算 softmax()
print('[3]softmax操作:\n', softmax)
log_softmax = torch.log(softmax) # 计算 log_softmax()
print('[4]softmax后取对数:\n', log_softmax)
nllloss = -torch.sum(one_hot * log_softmax) / target.shape[0]  # 标签乘以激活后的数据,求平均值,取反
print("[5]手动计算交叉熵:", nllloss)

print('----------------调用log_softmax+nll_loss实现------------------------------')
# 调用 NLLLoss() 函数计算
Log_Softmax = F.log_softmax(preds, dim=1)  # log_softmax() 激活
Nllloss = F.nll_loss(Log_Softmax, target)  # 无需对标签作 one_hot 编码
print("函数使用Nllloss计算交叉熵:", Nllloss)

print('------------------调用cross_entropy实现----------------------------')
# 直接使用交叉熵损失函数 CrossEntropy_Loss()
cross_entropy = F.cross_entropy(preds, target)  # 无需对标签作 one_hot 编码
print('函数交叉熵cross_entropy:', cross_entropy)
   

查看结果,可以看到三种方式计算的结果是一样的。
在这里插入图片描述

最近更新

  1. TCP协议是安全的吗?

    2024-03-28 06:20:04       14 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-28 06:20:04       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-28 06:20:04       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-28 06:20:04       18 阅读

热门阅读

  1. pytorch笔记篇:pandas之数据预处理(更新中)

    2024-03-28 06:20:04       16 阅读
  2. Android数据存储:SQLite、Room

    2024-03-28 06:20:04       17 阅读
  3. 基于Python的旅游网站数据爬虫分析

    2024-03-28 06:20:04       17 阅读
  4. docker安装postgresql数据库包含postgis扩张

    2024-03-28 06:20:04       19 阅读
  5. [XG] HTTP

    [XG] HTTP

    2024-03-28 06:20:04      18 阅读
  6. 前端学习-CSS基础-Day2

    2024-03-28 06:20:04       16 阅读
  7. 机器学习:理论、方法与应用实践

    2024-03-28 06:20:04       19 阅读
  8. 机器学习(复试)

    2024-03-28 06:20:04       18 阅读
  9. TensorFlow 的基本概念和使用场景

    2024-03-28 06:20:04       19 阅读
  10. 逆流而上的选择-积极生活,逆流而上

    2024-03-28 06:20:04       17 阅读