【代码学习】多标签分类 multilabel classfication | loss如何计算? | 衡量指标如何计算?

loss计算 | BCELoss(), 最后+sigmoid映射为0-1区间值

gpt解释

import torch
import torch.nn as nn

# 创建模型输出和目标标签
output = torch.tensor([0.7, 0.4, 0.2, 0.8])  # 模型的输出(概率值)
target = torch.tensor([1, 0, 0, 1])  # 实际的目标标签

# 创建 Binary Cross-Entropy Loss 对象
criterion = nn.BCELoss()

# 计算损失值
loss = criterion(torch.sigmoid(output), target.float())  # 注意需要使用 sigmoid 函数将概率值映射到 [0, 1] 区间,并将目标标签转换为浮点数
print("Binary Cross-Entropy Loss:", loss.item())

衡量指标计算 sklearn.metrics

在这里插入图片描述

代码 MultiLabelClassifier /CelebA_Classification_PyTorch_Github.ipynb

from: https://github.com/vatsalsaglani/MultiLabelClassifier/blob/master/CelebA_Classification_PyTorch_Github.ipynb

多标签label是这样的
在这里插入图片描述
最后是用sigmoid()激活函数,loss用 nn.BCELoss()
acc是这样计算的?

在这里插入图片描述

相关推荐

  1. 如何衡量机器学习分类模型(python)

    2023-12-27 19:40:04       26 阅读
  2. 如何学习计算机视觉

    2023-12-27 19:40:04       51 阅读
  3. 计算机如何学习

    2023-12-27 19:40:04       22 阅读
  4. 计算机如何学习

    2023-12-27 19:40:04       20 阅读
  5. 如何学习计算机

    2023-12-27 19:40:04       22 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2023-12-27 19:40:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2023-12-27 19:40:04       101 阅读
  3. 在Django里面运行非项目文件

    2023-12-27 19:40:04       82 阅读
  4. Python语言-面向对象

    2023-12-27 19:40:04       91 阅读

热门阅读

  1. OCC服务器和BCC服务器中文件同步

    2023-12-27 19:40:04       54 阅读
  2. golang发送邮件

    2023-12-27 19:40:04       45 阅读
  3. 享元设计模式

    2023-12-27 19:40:04       58 阅读