Pytorch中的nn.Embedding()

模块的输入是一个索引列表,输出是相应的词嵌入。

Embedding.weight(Tensor)–形状模块(num_embeddings,Embedding_dim)的可学习权重,初始化自(0,1)。
也就是说,pytorch的nn.Embedding()是可以自动学习每个词向量对应的w权重的。

import torch
import torch.nn as nn
embedding = nn.Embedding(9, 3)
# a batch of 2 samples of 4 indices each
input = torch.LongTensor([[1,2,4,5,6,7,8,1,1,1,6,7,5],[4,3,2,1,6,7,8,1,1,1,6,7,5]])
#这里的input可以里的数字可以表示为embedding的索引.索引数据的shape是没有限制的,但是input中的数值不能超过nn.Embedding(9,3)中的9的.
a = embedding(input)
print(a)

相关推荐

  1. pytorch@作用

    2024-04-07 06:02:02       38 阅读
  2. pytorchwheel文件

    2024-04-07 06:02:02       45 阅读
  3. PyTorch 批量规范化

    2024-04-07 06:02:02       46 阅读
  4. PyTorchFX图

    2024-04-07 06:02:02       47 阅读
  5. 谈谈Pytorchdataset

    2024-04-07 06:02:02       47 阅读
  6. pytorch梯度裁剪

    2024-04-07 06:02:02       41 阅读
  7. Pytorchnn.Embedding()

    2024-04-07 06:02:02       38 阅读
  8. Pytorchself.parameters()

    2024-04-07 06:02:02       35 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-07 06:02:02       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-07 06:02:02       101 阅读
  3. 在Django里面运行非项目文件

    2024-04-07 06:02:02       82 阅读
  4. Python语言-面向对象

    2024-04-07 06:02:02       91 阅读

热门阅读

  1. Redis过期删除策略和内存淘汰机制

    2024-04-07 06:02:02       45 阅读
  2. 前端node使用WebSocket实现实时通信例子

    2024-04-07 06:02:02       33 阅读
  3. Android ContentProvider基础知识学习笔记

    2024-04-07 06:02:02       39 阅读
  4. vue 生命周期

    2024-04-07 06:02:02       38 阅读
  5. [蓝桥杯 2023 国 B] 双子数

    2024-04-07 06:02:02       39 阅读
  6. ARXML处理 - C#的解析代码(一)

    2024-04-07 06:02:02       32 阅读
  7. Python常用算法--排序算法【附源码】

    2024-04-07 06:02:02       42 阅读
  8. 沐瞳科技一面 客户端开发(45min)

    2024-04-07 06:02:02       43 阅读