PyTorch BERT for Text Classification

This walkthrough uses the public IMDB dataset as an example; the BERT checkpoint can be any model you pick from Hugging Face.
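The code below loads BERT from a local directory, ./bert_pretrain. If you have not prepared one yet, a minimal sketch like the following (assuming the bert-base-uncased checkpoint; any BERT variant from the Hub works) downloads a model and tokenizer and saves them there:

from transformers import BertTokenizer, BertModel

# Download from the Hugging Face Hub ('bert-base-uncased' is just one choice)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
# Save locally so the rest of the code can load from ./bert_pretrain
tokenizer.save_pretrained('./bert_pretrain')
model.save_pretrained('./bert_pretrain')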

1. Import the required libraries

import os
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import BertTokenizer, BertModel, BertConfig
from torch import nn
from torch.optim import AdamW
import numpy as np
from sklearn.metrics import accuracy_score
import pandas as pd
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

2. Load and preprocess the data: read the reviews, convert them into a format suitable for BERT, and map the original pos/neg labels onto three classes (negative reviews are randomly split between "low" and "medium"; positive reviews become "high").

import random
def load_imdb_dataset_and_create_multiclass_labels(path_to_data, split="train"):
    print(f"load start: {
     split}")
    reviews = []
    labels = []  # 0 for low, 1 for medium, 2 for high
    for label in ["pos", "neg"]:
        labeled_path = os.path.join(path_to_data, split, label)
        for file in os.listdir(labeled_path):
            if file.endswith('.txt'):
                with open(os.path.join(labeled_path, file), 'r', encoding='utf-8') as f:
                    reviews.append(f.read())
                    if label == "neg":
                        # Randomly assign negative reviews to low or medium
                        labels.append(random.choice([0, 1]))  
                    else:
                        labels.append(2)  # Assign positive reviews to high
    # Shuffle before truncating: files are read pos-first, so a plain slice of the
    # first 1000 reviews would otherwise contain only positive (class 2) samples
    combined = list(zip(reviews, labels))
    random.shuffle(combined)
    reviews, labels = map(list, zip(*combined))
    return reviews[:1000], labels[:1000]
# Load the dataset
train_texts, train_labels = load_imdb_dataset_and_create_multiclass_labels("./data/aclImdb", split="train")
test_texts, test_labels = load_imdb_dataset_and_create_multiclass_labels("./data/aclImdb", split="test")
print("load okk")
#样本数量
print("train_texts: ",len(train_texts))
print("test_texts: ",len(test_texts))

3. Convert the text into BERT's input format

tokenizer = BertTokenizer.from_pretrained('./bert_pretrain')

def encode_texts(tokenizer, texts, max_len=512):
    input_ids = []
    attention_masks = []

    for text in texts:
        encoded = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids.append(encoded['input_ids'])
        attention_masks.append(encoded['attention_mask'])
    
    return torch.cat(input_ids, dim=0), torch.cat(attention_masks, dim=0)

train_inputs, train_masks = encode_texts(tokenizer, train_texts)
test_inputs, test_masks = encode_texts(tokenizer, test_texts)
print("input transfromer encode done")

4. Create the TensorDataset and DataLoader

train_labels = torch.tensor(train_labels)
test_labels = torch.tensor(test_labels)

train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
test_dataset = TensorDataset(test_inputs, test_masks, test_labels)

# Split the dataset into train and validation sets
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Note: batch_size=128 with max_len=512 is memory-hungry; lower it (e.g., 16) if your GPU runs out of memory
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)
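
Pulling a single batch confirms the DataLoader yields (input_ids, attention_mask, labels) with the expected shapes:

# Fetch one batch and inspect its shapes
sample_inputs, sample_masks, sample_labels = next(iter(train_dataloader))
print(sample_inputs.shape, sample_masks.shape, sample_labels.shape)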

5. Build the model: BERT for a multi-class classification task

class BertForMultiClassClassification(nn.Module):
    def __init__(self):
        super(BertForMultiClassClassification, self).__init__()
        self.bert = BertModel.from_pretrained('./bert_pretrain')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 3)  # 3 classes

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)
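
A quick smoke test with dummy inputs verifies that the forward pass returns logits of shape (batch, 3):

# Dummy batch: random token ids in the vocab range, all-ones attention mask
_model = BertForMultiClassClassification()
dummy_ids = torch.randint(0, _model.bert.config.vocab_size, (2, 16))
dummy_mask = torch.ones(2, 16, dtype=torch.long)
print(_model(dummy_ids, dummy_mask).shape)  # expected: torch.Size([2, 3])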

6. Train and evaluate the model

# Initialize the model, optimizer, and loss function
model = BertForMultiClassClassification()
# Optional: wrap the model for multi-GPU training
# if MULTI_GPU:
#     model = nn.DataParallel(model)
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()
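
BERT fine-tuning often benefits from a linear warmup/decay learning-rate schedule. An optional sketch using transformers' built-in helper (if you adopt it, call scheduler.step() right after optimizer.step() in the training loop below):

from transformers import get_linear_schedule_with_warmup

# Total optimizer steps = batches per epoch * number of epochs (3 here)
total_steps = len(train_dataloader) * 3
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=total_steps)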

# Training function
def train(model, dataloader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        batch = tuple(b.to(device) for b in batch)
        inputs, masks, labels = batch

        optimizer.zero_grad()

        outputs = model(input_ids=inputs, attention_mask=masks)
        loss = loss_fn(outputs, labels)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    average_loss = total_loss / len(dataloader)
    return average_loss

# Evaluation function
def evaluate(model, dataloader, loss_fn, device):
    model.eval()
    total_loss = 0
    predictions, true_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            batch = tuple(b.to(device) for b in batch)
            inputs, masks, labels = batch

            outputs = model(input_ids=inputs, attention_mask=masks)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()

            logits = outputs.detach().cpu().numpy()
            label_ids = labels.to('cpu').numpy()
            predictions.append(logits)
            true_labels.append(label_ids)

    average_loss = total_loss / len(dataloader)
    flat_predictions = np.concatenate(predictions, axis=0)
    flat_predictions = np.argmax(flat_predictions, axis=1).flatten()
    flat_true_labels = np.concatenate(true_labels, axis=0)

    accuracy = accuracy_score(flat_true_labels, flat_predictions)
    return average_loss, accuracy
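
Accuracy alone can hide per-class behavior (especially here, where classes 0 and 1 were assigned randomly). A small sketch of a per-class report using sklearn; the helper name evaluate_per_class and the class names are illustrative:

from sklearn.metrics import classification_report

def evaluate_per_class(model, dataloader, device):
    # Collect argmax predictions and gold labels, then print per-class precision/recall/F1
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for inputs, masks, labels in dataloader:
            logits = model(inputs.to(device), attention_mask=masks.to(device))
            preds.extend(logits.argmax(dim=1).cpu().tolist())
            trues.extend(labels.tolist())
    print(classification_report(trues, preds, target_names=["low", "medium", "high"]))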

# Training and evaluation loop
for epoch in range(3):  # train for 3 epochs
    train_loss = train(model, train_dataloader, optimizer, loss_fn, device)
    val_loss, val_accuracy = evaluate(model, val_dataloader, loss_fn, device)

    print(f"Epoch {
     epoch+1}")
    print(f"Train Loss: {
     train_loss:.3f}")
    print(f"Validation Loss: {
     val_loss:.3f}, Accuracy: {
     val_accuracy:.3f}")

# Evaluate performance on the test set
test_loss, test_accuracy = evaluate(model, test_dataloader, loss_fn, device)
print(f"Test Loss: {
     test_loss:.3f}, Accuracy: {
     test_accuracy:.3f}")
# Save the model
os.makedirs("./model", exist_ok=True)
torch.save(model.state_dict(), "./model/bert_multiclass_imdb_model.pt")
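
If the checkpoint will be reloaded elsewhere, it is convenient to save the tokenizer next to the weights (and note that with nn.DataParallel you would save model.module.state_dict() instead):

# Keep the tokenizer with the checkpoint so inference code loads both from one place
tokenizer.save_pretrained("./model/")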

7. Model prediction

import numpy as np
import torch


def predict(texts, model, tokenizer, device, max_len=128):
    # Encode the texts into BERT's input format
    def encode_texts(tokenizer, texts, max_len):
        input_ids = []
        attention_masks = []

        for text in texts:
            encoded = tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=max_len,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt',
            )
            input_ids.append(encoded['input_ids'])
            attention_masks.append(encoded['attention_mask'])
        
        return torch.cat(input_ids, dim=0), torch.cat(attention_masks, dim=0)

    model.eval()  # set the model to evaluation mode

    input_ids, attention_masks = encode_texts(tokenizer, texts, max_len)
    input_ids = input_ids.to(device)
    attention_masks = attention_masks.to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_masks)
        logits = outputs.detach().cpu().numpy()
        predictions = np.argmax(logits, axis=1)

    return predictions

# Example texts
texts = ["I really like the movie", "the movie is so bad"]

# Call the prediction function

# Initialize the model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = BertForMultiClassClassification()
model.to(device)

# Load the saved weights (map_location handles loading a GPU-trained checkpoint on CPU)
model.load_state_dict(torch.load('./model/bert_multiclass_imdb_model.pt', map_location=device))

# Set the model to evaluation mode
model.eval()

# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained('./bert_pretrain')

predictions = predict(texts, model, tokenizer, device)

# Print the prediction results
for text, pred in zip(texts, predictions):
    print(f"Text: {
     text}, Predicted category: {
     pred}")
