关于torch.nn.Embedding的浅显理解

最近在使用词嵌入向量表示我的数据标签,并且在试图理解torch.nn.Embedding函数。

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)

这里只解释我对前两个参数的理解,这也是我唯二理解的:num_embeddings(int) – size of the dictionary of embeddings,其实就是你给Embedding函数的张量里互不相同的数的个数;embedding_dim (int) – the size of each embedding vector也即生成的词嵌入向量的最后一个维度。For example:

import torch.nn as nn
import torch

known_label_lt = nn.Embedding(3, 10)

label = torch.tensor([
    [1, 0, 1, 0, 1],
    [2, 1, 0, 2, 1],
    [1, 1, 2, 1, 0],
    [1, 1, 0, 1, 2]
]).long() # without .long(), will result in an error. 

state = known_label_lt(label)
print(state.shape)

这里输入的向量label里只能包含三个不同的数:0,1,2 。或者反过来说known_label_lt的第一个参数只能是3,known_label_lt的第二个参数就决定了label的每一个数会被扩展到10维。所以最后生成的词嵌入维度是:

torch.Size([4, 5, 10])

相关推荐

  1. 关于torch.nn.Embedding浅显理解

    2023-12-09 10:40:03       38 阅读
  2. C++ 中对 const 浅显理解

    2023-12-09 10:40:03       17 阅读
  3. 关于人工智能浅见

    2023-12-09 10:40:03       34 阅读
  4. 关于指针变量理解

    2023-12-09 10:40:03       16 阅读
  5. 关于Spring Bean容器理解

    2023-12-09 10:40:03       34 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-09 10:40:03       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-09 10:40:03       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-09 10:40:03       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-09 10:40:03       20 阅读

热门阅读

  1. C++11多线程基本知识点

    2023-12-09 10:40:03       31 阅读
  2. 《C++新经典设计模式》之第21章 解释器模式

    2023-12-09 10:40:03       30 阅读
  3. 自动补全的 select antd react

    2023-12-09 10:40:03       41 阅读
  4. 机器学习实验三:支持向量机模型

    2023-12-09 10:40:03       40 阅读
  5. CSS video控件去掉视频播放条

    2023-12-09 10:40:03       34 阅读
  6. Element-UI 数字类型输入框

    2023-12-09 10:40:03       42 阅读
  7. day9 栈实现队列

    2023-12-09 10:40:03       35 阅读
  8. 华为交换机基本配置

    2023-12-09 10:40:03       41 阅读