Bert模型输出:last_hidden_state转换为pooler_output

1. BERT模型的输出

在BERT模型中,last_hidden_statepooler_output是两个不同的输出。

(1) last_hidden_state:

    last_hidden_state是指BERT模型中最后一个隐藏层的隐藏状态。它是一个三维张量,其形状为[batch_size, sequence_length, hidden_size]。其中,batch_size是输入序列的批量大小,sequence_length是输入序列的长度,hidden_size是BERT模型的隐藏层大小(通常为768)。
 last_hidden_state保存了输入序列中每个token对应的隐藏状态,这些隐藏状态经过多层的Transformer编码器处理得到。在多数任务中,可以直接使用这个张量进行下游任务的训练或者特征提取。

(2) pooler_output:
     pooler_output是指BERT模型中经过一个特殊的池化层后得到的句子级别表示。它是一个二维张量,其形状为[batch_size, hidden_size]。
pooler_output是通过对BERT模型最后一个隐藏层的第一个token([CLS] token)的隐藏状态应用一个全连接层得到的。这个全连接层的参数在预训练过程中被学习得到。pooler_output可以看作是整个输入序列的压缩表示,通常用于句子级别的任务,如文本分类。

       总的来说,last_hidden_state提供了序列中每个token的隐藏状态信息,而pooler_output提供了整个句子的语义表示。

2. last_hidden_state转换为pooler_output

     在BERT模型中,last_hidden_state是最后一个隐藏层的隐藏状态,而pooler_output是通过应用一个全连接层(通常是一个线性变换加上激活函数)到last_hidden_state中的特殊token([CLS] token)得到的。

      首先从last_hidden_state中提取出每个样本的第一个token(即[CLS] token)的隐藏状态。然后,我们定义了一个线性层pooler_layer,将隐藏状态映射到与BERT模型的隐藏大小相同的空间。最后,我们应用了tanh激活函数,得到 pooler_output,这是整个句子的语义表示。这个pooler_output可以用于句子级别的任务,例如文本分类。

      请确保poor_layer的权重是正确初始化的。通常情况下,应该使用预训练的BERT模型的权重来初始化它。可以在实例化poor_layer时进行这样的初始化。如果使用的是transformers库,它提供了加载预训练BERT模型并提取pooler_output的方法。要使用预训练的BERT模型的权重来初始化线性层 pooler_layer,可以从预训练的BERT模型中加载权重,并将这些权重用作 pooler_layer的初始权重。通常情况下,会使用Hugging Face的 transformers库来加载预训练的BERT模型。

       以下是一个示例代码,演示如何使用transformers库来加载预训练的BERT模型,并使用其中的权重来初始化 pooler_layer:

from transformers import BertModel, BertTokenizer

#加载预训练的Bert模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
extractor = BertModel.from_pretrained('bert-base-uncased')

#text是原始文本数据
x = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors="pt").to(
            device)

x = extractor(**x)

#获取hidden_state
x1= x['last_hidden_state']

# 定义一个线性层,将最后一个隐藏层的第一个token的隐藏状态映射到pooler_output
pooler_layer = nn.Linear(768, 768).to(device)

# 使用BERT模型的权重来初始化pooler_layer的权重
with torch.no_grad():
  pooler_layer.weight.copy_(extractor.pooler.dense.weight)
  pooler_layer.bias.copy_(extractor.pooler.dense.bias)

# 获取CLS token的隐藏状态(最后隐藏层的第一个token),取出每个样本的第一个token的隐藏状态
cls_token_state = x1[:, 0, :].to(device)

## 应用线性层并使用激活函数
x1 = torch.tanh(pooler_layer(cls_token_state)).to(device)

#直接获取pooler_output
x2=x['pooler_output'].to(device)

       在这个示例中,我们首先从预训练的BERT模型中加载了tokenizer和BERT模型。然后,我们创建了一个与BERT模型隐藏大小相同的线性层 pooler_layer。最后,我们使用`bert_model.pooler.dense`中的权重来初始化`pooler_layer`的权重。这样,`pooler_layer`就被正确初始化了,并可以用于将`last_hidden_state`变换为`pooler_output`。最后x1和x2的结果相同。

相关推荐

  1. Bert模型输出:last_hidden_state转换pooler_output

    2024-03-21 07:46:06       37 阅读
  2. bert pytorch模型转onnx,并改变输入输出

    2024-03-21 07:46:06       45 阅读
  3. 使用poi-tl填充word模板,并转化pdf输出

    2024-03-21 07:46:06       35 阅读
  4. c++ 模拟 三维数组输入 string转化int

    2024-03-21 07:46:06       42 阅读
  5. 深度学习的模型转换(.pt转换.engine)

    2024-03-21 07:46:06       37 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-21 07:46:06       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-21 07:46:06       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-21 07:46:06       82 阅读
  4. Python语言-面向对象

    2024-03-21 07:46:06       91 阅读

热门阅读

  1. 【工具】mac 环境配置

    2024-03-21 07:46:06       44 阅读
  2. 啥是大语言模型LLM

    2024-03-21 07:46:06       44 阅读
  3. mongodb进阶聚合查询各种写法

    2024-03-21 07:46:06       41 阅读
  4. 多数据源 - dynamic-datasource | 事务支持

    2024-03-21 07:46:06       43 阅读
  5. 面试常问问题

    2024-03-21 07:46:06       42 阅读
  6. 洛谷P6866 [COCI2019-2020#5] Emacs

    2024-03-21 07:46:06       38 阅读
  7. gitee上传存储文件、下载文件

    2024-03-21 07:46:06       40 阅读
  8. 开源IT自动化运维工具Ansible Playbook介绍

    2024-03-21 07:46:06       35 阅读
  9. OpenCV特征检测与描述符模块

    2024-03-21 07:46:06       42 阅读
  10. C++_opencv中图像深度、通道和对应数据类型

    2024-03-21 07:46:06       40 阅读
  11. 【MySql】SQLite和MySQL的区别

    2024-03-21 07:46:06       39 阅读
  12. AWS Sagemaker详解

    2024-03-21 07:46:06       36 阅读
  13. Hive面试重点

    2024-03-21 07:46:06       46 阅读
  14. Hive自定义UDF函数

    2024-03-21 07:46:06       41 阅读
  15. 【面试自测】Spring

    2024-03-21 07:46:06       39 阅读