NLP(5)-softmax和交叉熵

前言

仅记录学习过程,有问题欢迎讨论

感觉全连接层就像一个中间层转换数据的形态的,或者说预处理数据?

代码

softmax就是把输出的y 归一化,把结果转化为概率值!,在分类问题中很常见。
而交叉熵是一种损失函数,也是在分类问题中使用,通常搭配着softmax使用;可以计算分布概率之间的差异;期望是两个概率分布的更相似。

# softMAX --概率归一化
# 主要把输出的y 的可能性 sum = 1
# 比如 判断是否为 猫狗 y = 【猫,狗,都不是】

# 【123====[ e^1/(e^1+e^2+e^3) ,e^2/(e^1+e^2+e^3) ,e^3/(e^1+e^2+e^3) ]

# 损失函数:
# 均方差 (y-y_pre)**2/n
# 交叉熵:y_true=[1,0,0] y_pred=[0.5,0.4,0.1] =loss==>| 1*log0.5+0*log0.4+0 | = 0.3

import torch
import torch.nn as nn
import numpy as np

##使用torch计算交叉熵
ce_loss = nn.CrossEntropyLoss()
# 假设有3个样本,每个都在做3分类
pred = torch.FloatTensor([[0.3, 0.1, 0.3],
                          [0.9, 0.2, 0.9],
                          [0.5, 0.4, 0.2]])  # n*class_num
# 正确的类别 == [0,1,0][0,0,1][1,0,0]
target = torch.LongTensor([1, 2, 0])
print(target)
loss = ce_loss(pred, target)
print(loss, "torch输出交叉熵")


# 实现softMax函数 x为矩阵格式!!
def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)


# 验证softmax函数
# print(torch.softmax(pred, dim=1))
# print(softmax(pred.numpy()))

# 将输入转化为onehot矩阵 == [0,1,0][0,0,1][1,0,0]
def to_one_hot(target, shape):
    one_hot_target = np.zeros(shape)
    # enumerate 也是一种循环 格式为【index,value】
    for i, t in enumerate(target):
        one_hot_target[i][t] = 1
    return one_hot_target


# target_matrix =  to_one_hot(target,pred.shape)
# print(target_matrix)

# 实现交叉熵
def cross_entropy(pred, target):
    batch_size, class_num = pred.shape
    # 先归一化
    pred = softmax(pred)
    # 变为矩阵格式
    target = to_one_hot(target,pred.shape)
    # 每一列求
    entropy = -np.sum(target * np.log(pred), axis=1)
    return sum(entropy) / batch_size

print(cross_entropy(pred.numpy(), target.numpy()), "手动实现交叉熵")



相关推荐

  1. NLP5)-softmax交叉

    2024-04-23 20:46:03       36 阅读
  2. 深度学习 - softmax交叉损失计算

    2024-04-23 20:46:03       27 阅读
  3. 深入理解交叉损失CrossEntropyLoss - Softmax

    2024-04-23 20:46:03       32 阅读
  4. NLP - Softmax与层次Softmax对比

    2024-04-23 20:46:03       22 阅读
  5. PyTorch交叉理解

    2024-04-23 20:46:03       29 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-23 20:46:03       91 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-23 20:46:03       97 阅读
  3. 在Django里面运行非项目文件

    2024-04-23 20:46:03       78 阅读
  4. Python语言-面向对象

    2024-04-23 20:46:03       88 阅读

热门阅读

  1. web server apache tomcat11-15-proxy

    2024-04-23 20:46:03       37 阅读
  2. wsl ubuntu18.04升级为cmake-3.15.3

    2024-04-23 20:46:03       34 阅读
  3. 前端宝藏图:寻找技术之旅的星辰大海

    2024-04-23 20:46:03       35 阅读
  4. Python爬取网易云平台

    2024-04-23 20:46:03       29 阅读
  5. 【Linux开发 第十一篇】rpm和yum

    2024-04-23 20:46:03       30 阅读
  6. 傅立叶变换与拉普拉斯变换的区别与联系?

    2024-04-23 20:46:03       28 阅读