An introduction to transformers DataCollator

This post walks through the usage of the DataCollator classes in transformers.

from transformers import AutoTokenizer, AutoModel, \
    DataCollatorForSeq2Seq, DataCollatorWithPadding, \
    DataCollatorForTokenClassification, DefaultDataCollator, DataCollatorForLanguageModeling

PRETRAIN_MODEL = r"E:\pythonWork\models\chinese-roberta-wwm-ext"  # raw string: avoid backslash escapes in the Windows path
tokenizer = AutoTokenizer.from_pretrained(PRETRAIN_MODEL)
model = AutoModel.from_pretrained(PRETRAIN_MODEL)

texts = ['今天天气真好。', "我爱你"]
encodings = tokenizer(texts)


# toy labels: one integer label per character of the raw text
# (note: shorter than the tokenized input_ids, which add [CLS]/[SEP])
labels = [list(range(len(each))) for each in texts]

inputs = [{"input_ids":t, "labels": l} for t,l in zip(encodings['input_ids'], labels)]



dc1 = DefaultDataCollator()
dc2 = DataCollatorForTokenClassification(tokenizer=tokenizer)
dc3 = DataCollatorWithPadding(tokenizer=tokenizer)
dc4 = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
dc5 = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)  # causal LM: labels are a copy of input_ids
dc6 = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)  # masked LM


print('DataCollatorForTokenClassification')
print(dc2(inputs))



print('DataCollatorWithPadding')
print(dc3(encodings))

print('DataCollatorForSeq2Seq')
print(dc4(inputs))


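The two DataCollatorForLanguageModeling instances created above are not printed, because their output is random: with mlm=True the collator masks tokens on the fly each time it is called. A hedged pure-Python sketch of the standard BERT-style 80/10/10 masking rule it implements (the mask id, vocab size, and the fact that special tokens are not excluded here are simplifying assumptions; the real collator reads these from the tokenizer and skips special tokens):

```python
import random

def mask_tokens(input_ids, mask_id=103, vocab_size=21128, mlm_probability=0.15, seed=0):
    """BERT-style masking: select ~15% of tokens; of those,
    80% become [MASK], 10% become a random token, 10% stay unchanged.
    Unselected positions get label -100 so the loss ignores them."""
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in input_ids:
        if rng.random() < mlm_probability:
            labels.append(tok)                            # predict the original token here
            r = rng.random()
            if r < 0.8:
                masked.append(mask_id)                    # 80%: replace with [MASK]
            elif r < 0.9:
                masked.append(rng.randrange(vocab_size))  # 10%: random token
            else:
                masked.append(tok)                        # 10%: keep unchanged
        else:
            masked.append(tok)
            labels.append(-100)                           # ignored by the loss
    return masked, labels

ids = list(range(1000, 1100))
masked, labels = mask_tokens(ids)
```

With mlm=False the collator instead copies input_ids into labels unchanged (padding positions set to -100), which is the causal-LM setup.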

DataCollatorForTokenClassification

Observe the output below. In a token classification task every token needs its own label, so after collation the lengths match: len(input_ids) == len(labels) for each example.

input_ids are padded (with the pad token, id 0),

labels are padded (with -100, which the loss ignores),

attention_mask is padded (with 0)

DataCollatorForTokenClassification
{'input_ids': tensor([[ 101,  791, 1921, 1921, 3698, 4696, 1962,  511,  102],
        [ 101, 2769, 4263,  872,  102,    0,    0,    0,    0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0, 0, 0]]), 'labels': tensor([[   0,    1,    2,    3,    4,    5,    6, -100, -100],
        [   0,    1,    2, -100, -100, -100, -100, -100, -100]])}
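The padding rule above can be sketched in plain Python. This is a minimal illustration, not the real implementation: the pad id 0 and label pad -100 are assumptions matching the output shown; the actual collator reads them from the tokenizer.

```python
def collate_token_classification(features, pad_id=0, label_pad=-100):
    """Pad input_ids with pad_id and labels with label_pad to the batch max length."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n = len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_id] * (max_len - n))
        batch["attention_mask"].append([1] * n + [0] * (max_len - n))
        # labels may be shorter than input_ids; pad them to the same max length
        batch["labels"].append(f["labels"] + [label_pad] * (max_len - len(f["labels"])))
    return batch

features = [
    {"input_ids": [101, 2769, 4263, 872, 102], "labels": [0, 1, 2]},
    {"input_ids": [101, 791, 1921, 102], "labels": [0, 1]},
]
out = collate_token_classification(features)
# every row of input_ids and labels ends up with the same length
```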

DataCollatorWithPadding

input_ids are padded,

token_type_ids are padded,

attention_mask is padded

Note that no labels appear in the output: DataCollatorWithPadding only pads the keys the tokenizer produced, which is why it was fed encodings (which has no labels) rather than inputs.

DataCollatorWithPadding
{'input_ids': tensor([[ 101,  791, 1921, 1921, 3698, 4696, 1962,  511,  102],
        [ 101, 2769, 4263,  872,  102,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0, 0, 0]])}
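DataCollatorWithPadding essentially delegates to tokenizer.pad. A hedged pure-Python sketch of the behaviour seen above (pad id 0 assumed), padding only the three tokenizer-produced keys and never inventing a labels key:

```python
def collate_with_padding(encodings, pad_id=0):
    """Pad input_ids / token_type_ids / attention_mask to the batch max length.
    No 'labels' key is produced: this collator only pads tokenizer output."""
    max_len = max(len(ids) for ids in encodings["input_ids"])
    out = {"input_ids": [], "token_type_ids": [], "attention_mask": []}
    for ids, tt, mask in zip(encodings["input_ids"],
                             encodings["token_type_ids"],
                             encodings["attention_mask"]):
        pad = max_len - len(ids)
        out["input_ids"].append(ids + [pad_id] * pad)
        out["token_type_ids"].append(tt + [0] * pad)
        out["attention_mask"].append(mask + [0] * pad)
    return out

enc = {
    "input_ids": [[101, 791, 102], [101, 2769, 4263, 872, 102]],
    "token_type_ids": [[0, 0, 0], [0, 0, 0, 0, 0]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1]],
}
batch = collate_with_padding(enc)
```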

DataCollatorForSeq2Seq

len(input_ids) != len(labels) here (contrast with DataCollatorForTokenClassification): input_ids are padded to the longest input in the batch (9), while labels are padded with -100 to the longest label sequence (7), independently of the input length.

input_ids are padded,

labels are padded,

attention_mask is padded

DataCollatorForSeq2Seq
{'input_ids': tensor([[ 101,  791, 1921, 1921, 3698, 4696, 1962,  511,  102],
        [ 101, 2769, 4263,  872,  102,    0,    0,    0,    0]]), 'labels': tensor([[   0,    1,    2,    3,    4,    5,    6],
        [   0,    1,    2, -100, -100, -100, -100]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0, 0, 0]])}
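That key difference can be sketched in plain Python. Assumptions as before (pad id 0, label pad -100); the real collator additionally supports pad_to_multiple_of and, when the model supports it, derives decoder_input_ids from the labels.

```python
def collate_seq2seq(features, pad_id=0, label_pad=-100):
    """input_ids padded to the longest input; labels padded to the longest *label*."""
    max_in = max(len(f["input_ids"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n = len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_id] * (max_in - n))
        batch["attention_mask"].append([1] * n + [0] * (max_in - n))
        batch["labels"].append(f["labels"] + [label_pad] * (max_lab - len(f["labels"])))
    return batch

features = [
    {"input_ids": [101, 1, 2, 3, 4, 5, 6, 7, 102], "labels": [0, 1, 2, 3, 4, 5, 6]},
    {"input_ids": [101, 8, 9, 102], "labels": [0, 1, 2]},
]
out = collate_seq2seq(features)
# input_ids rows have length 9 while labels rows have length 7,
# matching the shapes in the printed output above
```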
