PyTorch实现标签到One-Hot编码的步骤解析

代码:

import torch

class_num = 10

batch_size = 4

label = torch.LongTensor(batch_size, 1).random_() % class_num

print(label.size())

one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)

print(one_hot)

输出:

torch.Size([4, 1])

tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],

[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],

[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],

[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

注意:

label的形状必须是[n,1]的,也就是必须是二维的,且第二个维度长度为1,如果是一维度的,则需要升维度,代码如下:

import torch

class_num = 10

batch_size = 4

label = torch.LongTensor(batch_size).random_() % class_num

print(label.size())

label = torch.unsqueeze(label,dim=1)

print(label.size())

相关推荐

  1. PyTorch实现标签One-Hot编码步骤

    2024-07-22 10:42:03       20 阅读
  2. Pytorch标签转为One-Hot编码

    2024-07-22 10:42:03       42 阅读
  3. One-hot编码

    2024-07-22 10:42:03       41 阅读
  4. 机器学习 - one-hot编码技术

    2024-07-22 10:42:03       18 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-22 10:42:03       52 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-22 10:42:03       54 阅读
  3. 在Django里面运行非项目文件

    2024-07-22 10:42:03       45 阅读
  4. Python语言-面向对象

    2024-07-22 10:42:03       55 阅读

热门阅读

  1. DM数据库导出excel表结构

    2024-07-22 10:42:03       14 阅读
  2. 探索Python元类的奥秘:定义与实用应用

    2024-07-22 10:42:03       15 阅读
  3. 经常进行工作总结,有何重要作用呢?

    2024-07-22 10:42:03       15 阅读
  4. C++:istream、ostream和fstream类

    2024-07-22 10:42:03       18 阅读
  5. 4 DAY

    2024-07-22 10:42:03       14 阅读
  6. 数仓中主题域还是数据域?

    2024-07-22 10:42:03       15 阅读
  7. (day21)leecode hot100字母异位词分组

    2024-07-22 10:42:03       17 阅读
  8. WireGuard 编译安装

    2024-07-22 10:42:03       16 阅读