bert_base_chinese入门

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
model_name="bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 带有语言模型头的模型
model = AutoModelForMaskedLM.from_pretrained(model_name)

input_text='人生该如何起头'
inx_token=tokenizer.encode(input_text)# 索引序列,加了开始索引101和结束索引102
#开始标记[CLS],结束标记[SEP],中间每个索引对应一个token
text_token=tokenizer.convert_ids_to_tokens(inx_token)
with torch.no_grad():
    output=model(torch.tensor([inx_token]))
print(type(output))
last_hidden=output['logits']
print(last_hidden.shape,last_hidden)# (1,9,21128)(batch_size,seq_len,d_model)
print(len(last_hidden[0,0,:]),max(last_hidden[0,0,:]),min(last_hidden[0,0,:]))
from transformers import AutoModel
# 不带语言模型头的模型
no_head_model=AutoModel.from_pretrained(model_name)
tensor_token=torch.tensor([inx_token])# tensor小写是int64,Tensor大写是float32
print(tensor_token.shape,tensor_token.dtype)
with torch.no_grad():
    no_head_output=no_head_model(tensor_token)
last_hidden=no_head_output['last_hidden_state']
print(last_hidden.shape,last_hidden)# (1,9,768)
from transformers import AutoModelForSequenceClassification
#分类模型
class_model=AutoModelForSequenceClassification.from_pretrained(model_name)
with torch.no_grad():
    class_output=class_model(tensor_token)
class_output=class_output['logits']
print(class_output.shape,class_output)# 形状:(1,2)
input_text1='我家的小狗是黑色的'
input_text2='我家的小狗是什么颜色的呢?'
#句子对,开始索引101,句子间分割102,第二个句子的结束102
# inx_token=tokenizer.encode(input_text1,input_text2)
# text_token=tokenizer.convert_ids_to_tokens(inx_token)
# print(text_token)# [CLS],[SEP],[SEP],索引对应分词字符串
# return_tensors:返回pytorch数据,padding:不够长度的填充0,truncation:过长的阶段
inx_tok=tokenizer(input_text2,input_text1,\
                     return_tensors='pt',padding=True, truncation=True)
input_ids=inx_tok['input_ids']
token_type_ids=inx_tok['token_type_ids']
print(type(input_ids),type(token_type_ids))
from transformers import AutoModelForQuestionAnswering
aq=AutoModelForQuestionAnswering.from_pretrained(model_name)
with torch.no_grad():
    aq_output=aq(input_ids,token_type_ids=token_type_ids)
start_inxes=aq_output['start_logits'][0]
end_inxes=aq_output['end_logits'][0]
print(len(start_inxes),len(end_inxes))
start=torch.argmax(start_inxes)
end=torch.argmax(end_inxes)
(input_text1+input_text2)[start:end]
# 定义问题和上下文  
question = "你好,请问今天天气怎么样?"  
context = "今天是晴天,气温适中,非常适合户外活动。" 
# 使用分词器对问题和上下文进行编码  
inputs = tokenizer(question, context, return_tensors='pt', \
                   padding=True, truncation=True)
input_ids = inputs['input_ids']  
attention_mask = inputs['attention_mask'] 
# 在Transformers库中,模型并不是通过数字本身来识别分割符的,
# 而是通过分词器(Tokenizer)对输入文本的处理来识别这些特殊标记。
# 在不计算梯度的情况下进行推理  
with torch.no_grad():  
    outputs =aq(input_ids, attention_mask=attention_mask)  
# 获取起始和结束位置的得分  
start_logits = outputs.start_logits[0] 
end_logits = outputs.end_logits[0] 
# 显示得分(通常会有很多,对应文本中的每个token)  
display(start_logits, end_logits) 
# 获取最大值的索引:使用torch.argmax或其他类似函数,
# 找到start_logits和end_logits中最大值对应的索引。
# 这些索引表示答案在输入文本中的起始和结束位置。
print(torch.argmax(start_logits),torch.argmax(end_logits))
# 获取最大值的索引  
start_index = torch.argmax(start_logits)  
end_index = torch.argmax(end_logits) 
# 确保索引在有效范围内  
max_len = input_ids.size(1) 
start_index = start_index.item() 
end_index = end_index.item() 
start_index = max(0, start_index)  
end_index = min(max_len - 1, end_index)  
# 确保起始位置在结束位置之前  
if start_index > end_index:  
    # 可以选择忽略这种情况,或者设置一个默认的答案或错误消息  
    answer = "无效的预测!"  
else:  
    # 使用分词器将token IDs转换回文本  
    tokenizer = AutoTokenizer.from_pretrained(model_name)  
    answer_tokens = tokenizer.convert_ids_to_tokens(input_ids[0, start_index:end_index+1]) 
    print(answer_tokens)
    answer = tokenizer.decode(input_ids[0, start_index:end_index+1])  
print(f"Answer: {answer}")


 

相关推荐

  1. docker从入门入土

    2024-03-22 05:54:03       39 阅读
  2. 入门 PyTorch

    2024-03-22 05:54:03       42 阅读
  3. C++<span style='color:red;'>入门</span>

    C++入门

    2024-03-22 05:54:03      35 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-03-22 05:54:03       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-22 05:54:03       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-22 05:54:03       19 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-22 05:54:03       20 阅读

热门阅读

  1. python 之 装饰器(Decorators)

    2024-03-22 05:54:03       16 阅读
  2. shell和linux的关系

    2024-03-22 05:54:03       17 阅读
  3. PostgresSQL中的死锁和锁等待

    2024-03-22 05:54:03       18 阅读
  4. 二分图试炼之棋盘覆盖

    2024-03-22 05:54:03       17 阅读
  5. 如何搭建数据中心安全架构?

    2024-03-22 05:54:03       19 阅读
  6. oracle pctfree&pctused介绍

    2024-03-22 05:54:03       18 阅读
  7. 工大智信智能听诊科技与健康

    2024-03-22 05:54:03       18 阅读
  8. List 的 Diff 功能

    2024-03-22 05:54:03       20 阅读
  9. Mysql——索引下推

    2024-03-22 05:54:03       20 阅读
  10. 如何利用ChatGPT写出高质量的学术论文

    2024-03-22 05:54:03       19 阅读
  11. RabbitMQ的安装

    2024-03-22 05:54:03       19 阅读