Pytorch将标签转为One-Hot编码

一、标签映射与One-Hot编码过程

先进行标签映射,要为每个分类建立一个整数索引,对于每个样本的标签,使用整数索引创建一个长度为类别总数的二进制向量。这个向量的所有元素都是0,除了与整数索引相对应的位置,该位置的值为1。

二、pytorch的官方实现

在pytorch中实现了one hot编码,就在torch.nn.functional里面,下面是它的注释当中的示例,我们开看看:

Examples:
    >>> F.one_hot(torch.arange(0, 5) % 3)
    tensor([[1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
            [0, 1, 0]])
    >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
    tensor([[1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0]])
    >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
    tensor([[[1, 0, 0],
             [0, 1, 0]],
            [[0, 0, 1],
             [1, 0, 0]],
            [[0, 1, 0],
             [0, 0, 1]]])

我们可以根据那自己实现的与它给出的这个示例进行比对,一样就当然没问题了。

三、手写实现

首先,在原先的函数(one_hot)当中numclass=-1,类别当然不能为1,说明这里是自动进行了计算,大家普遍使用的方式都是创建一个全零矩阵,使用 scatter_ 函数进行独热编码,作用是按照给定的索引,在指定的维度上进行赋值。

def one_hot(labels, num_classes=-1):
    """
    将标签转为独热编码, 经过测试与torch.nn.functional里面的函数测试相同
    :param labels: 标签
    :param num_classes: 默认为-1, 表示进行自动计算类别最大的那个
    Examples:
        >>> label_1 = torch.arange(0, 5) % 3
        # tensor([0, 1, 2, 0, 1])
        >>> label_2 = torch.arange(0, 6).view(3, 2) % 3
        # tensor([[0, 1], [2, 0], [1, 2]])
        >>> print(one_hot(label_1))
        tensor([[1, 0, 0],
                [0, 1, 0],
                [0, 0, 1],
                [1, 0, 0],
                [0, 1, 0]])
        >>> print(one_hot(label_1, 5))
        tensor([[1, 0, 0, 0, 0],
                [0, 1, 0, 0, 0],
                [0, 0, 1, 0, 0],
                [1, 0, 0, 0, 0],
                [0, 1, 0, 0, 0]])
        >>> print(one_hot(label_2))
        tensor([[[1, 0, 0],
                 [0, 1, 0]],
                [[0, 0, 1],
                 [1, 0, 0]],
                [[0, 1, 0],
                 [0, 0, 1]]])
    """
    if num_classes == -1:
        num_classes = int(labels.max()) + 1
    one_hot_tensor = torch.zeros(labels.size() + (num_classes,), dtype=torch.int64)
    one_hot_tensor.scatter_(-1, labels.unsqueeze(-1).to(torch.int64), 1)
    return one_hot_tensor


label_1 = torch.arange(0, 5) % 3
# tensor([0, 1, 2, 0, 1])
label_2 = torch.arange(0, 6).view(3, 2) % 3
# tensor([[0, 1], [2, 0], [1, 2]])
print(one_hot(label_1))
print(one_hot(label_1, 5))
print(one_hot(label_2))

首先是判断分类数是不是为-1,如果是就根据其中的最大值+1进行自动计算。然后创建一个契合分类数量的全零矩阵。

在这里,labels.unsqueeze(-1)用于在标签的最后一个维度上添加一个维度,以便与独热编码张量进行广播操作。

假设原始的 labels 张量的形状为 (batch_size,),那么经过 unsqueeze(-1) 操作后,形状变为 (batch_size, 1)。这样,每个样本的标签都被表示为一个列向量,而不再是一个标量。scatter_函数在最后一个维度进行操作,也就是对类别总数的维度进行操作,而 1 是要赋给相应位置的值。

labels.unsqueeze(-1) 已经确保了与 one_hot_tensor 的形状匹配,所以在这里能够正确地进行广播和赋值操作。

下面这一种是应用于分割网络当中,在保留输入标签张量形状的同时,将独热编码张量的最后一个维度设置为分类数num_classes,确保独热编码张量与输入标签张量具有相同的形状。

def get_one_hot(labels, num_classes=-1):
    """用于分割网络的one hot"""
    labels = torch.as_tensor(labels)
    ones = one_hot(labels, num_classes)
    return ones.view(*labels.size(), num_classes)


if __name__=="__main__":
    seg_labels = torch.randint(0, 3, size=[512, 512])
    print(get_one_hot(seg_labels))
    print(get_one_hot(seg_labels).shape)   # torch.Size([512, 512, 3])

你可以将这里应用于自定义dataset部分。

相关推荐

  1. Pytorch标签转为One-Hot编码

    2024-01-12 08:22:04       49 阅读
  2. PyTorch实现标签One-Hot编码的步骤解析

    2024-01-12 08:22:04       27 阅读
  3. One-hot编码

    2024-01-12 08:22:04       48 阅读
  4. 机器学习 - one-hot编码技术

    2024-01-12 08:22:04       26 阅读
  5. PyTorch 稀疏函数解析:embedding 、one_hot详解

    2024-01-12 08:22:04       61 阅读
  6. 机器学习之独热编码One-Hot

    2024-01-12 08:22:04       58 阅读
  7. PHPHTML标签转化为图片

    2024-01-12 08:22:04       41 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-01-12 08:22:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-01-12 08:22:04       101 阅读
  3. 在Django里面运行非项目文件

    2024-01-12 08:22:04       82 阅读
  4. Python语言-面向对象

    2024-01-12 08:22:04       91 阅读

热门阅读

  1. selenium无法定位元素问题

    2024-01-12 08:22:04       63 阅读
  2. 树莓派ubuntu:hdmi与wifi冲突问题

    2024-01-12 08:22:04       47 阅读
  3. 架构师常用的ChatGPT通用提示词模板

    2024-01-12 08:22:04       54 阅读
  4. flutter base64图片保存到相册

    2024-01-12 08:22:04       61 阅读
  5. ubuntu18.04安装部署环境

    2024-01-12 08:22:04       48 阅读
  6. 油烟净化器电源安全,保障健康餐饮生活

    2024-01-12 08:22:04       51 阅读
  7. JVM初识

    JVM初识

    2024-01-12 08:22:04      62 阅读
  8. Django的模板语言

    2024-01-12 08:22:04       38 阅读