将 BERT-NER 模型转换为 ONNX 模型

保存模型

加载模型

from transformers import AutoTokenizer, AutoModel, AutoModelForTokenClassification, AutoConfig

# Directory holding the fine-tuned NER checkpoint (tokenizer, config and weights).
NER_MODEL_PATH = './save_model'
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)
# Token-classification head on top of the encoder; the class must be imported
# explicitly (importing only AutoModel left this name undefined).
ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
ner_model.eval()  # inference mode (disables dropout)

测试ner效果

（原文此处为运行结果截图，图片未保留）

测试速度

（原文此处为速度测试截图，图片未保留）

导出到onnx

# !pip install onnx onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple/

# Export the fine-tuned NER (token-classification) model to ONNX.
import onnxruntime
import torch
from itertools import chain
from transformers.onnx.features import FeaturesManager

config = ner_config
tokenizer = ner_tokenizer
model = ner_model
output_onnx_path = "bert-ner.onnx"

# NER is *token* classification: this ONNX config marks both the batch and the
# sequence axis of "logits" as dynamic, whereas the "sequence-classification"
# config only declares the batch axis. get_config() is the public accessor for
# the same table as FeaturesManager._SUPPORTED_MODEL_TYPE; FeaturesManager keys
# use "-" where config.model_type may use "_".
onnx_config = FeaturesManager.get_config(
    config.model_type.replace("_", "-"), "token-classification"
)(config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')

# The trailing dict in `args` is passed to model.forward() as keyword arguments.
# `use_external_data_format` / `enable_onnx_checker` were deprecated in PyTorch 1.11
# and are rejected by newer releases, so they are no longer passed.
torch.onnx.export(
    model,
    (dummy_inputs,),
    f=output_onnx_path,
    input_names=list(onnx_config.inputs.keys()),
    output_names=list(onnx_config.outputs.keys()),
    dynamic_axes={
        name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())
    },
    do_constant_folding=True,
    opset_version=onnx_config.default_onnx_opset,
)

加载ONNX模型

自定义pipeline

import numpy as np
from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession

class PipeLineOnnx:
    """NER inference pipeline that runs an exported ONNX model with ONNX Runtime.

    Flow: tokenize -> run the graph -> per-token softmax/argmax -> drop "O" and
    special tokens -> group tags into entities with ``post_process``.
    """

    def __init__(self, tokenizer, onnx_path, config):
        """
        Args:
            tokenizer: the tokenizer the model was trained/exported with. It must be
                a *fast* tokenizer, because character offsets are requested.
            onnx_path: path of the exported ONNX model file.
            config: model config; ``id2label`` maps class ids to tag strings.
        """
        self.tokenizer = tokenizer
        self.config = config  # label2id, id2label
        options = SessionOptions()
        options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        # Thread count can be capped if needed, e.g. options.intra_op_num_threads = 4
        self.session = InferenceSession(
            onnx_path, sess_options=options, providers=["CPUExecutionProvider"]
        )
        # Disable the session.run() fallback mechanism so a failing execution
        # provider is not silently reset.
        self.session.disable_fallback()

    def __call__(self, text):
        """Run NER on a single string.

        Args:
            text: one input string (batched input is not supported).

        Returns:
            The merged entity dicts produced by ``post_process``.
        """
        # Tokenize once so the offsets and special-token flags line up exactly with
        # the (truncated) ids fed to the model.
        encoded = self.tokenizer(
            text,
            truncation=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_tensors="np",
        )
        offsets = encoded.pop("offset_mapping")[0]
        special_mask = encoded.pop("special_tokens_mask")[0]
        ids = encoded["input_ids"][0]

        # Feed only the inputs the exported graph declares.
        input_names = {node.name for node in self.session.get_inputs()}
        input_feed = {name: value for name, value in encoded.items() if name in input_names}

        # "logits" must match the output_names used at export time.
        logits = self.session.run(["logits"], input_feed)[0][0]  # first (only) sequence

        # Numerically stable softmax over the label axis.
        probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs /= probs.sum(axis=-1, keepdims=True)
        scores = probs.max(axis=-1)
        predictions = logits.argmax(axis=-1)

        tokens = self.tokenizer.convert_ids_to_tokens(ids.tolist())
        labels = [self.config.id2label[int(p)] for p in predictions]

        ner_result = [
            {
                "index": idx,
                "word": token,
                "entity": label,
                "start": int(span[0]),
                "end": int(span[1]),
                "score": score,
            }
            for idx, (token, label, span, score, special) in enumerate(
                zip(tokens, labels, offsets, scores, special_mask)
            )
            # Special tokens ([CLS], [SEP], ...) never belong to an entity.
            if not special and label != 'O'
        ]
        return post_process(ner_result)
        

# (previous flag, next flag) pairs that continue the same entity.
_MERGEABLE_FLAG_PAIRS = frozenset({
    ("B", "I"),
    ("B", "E"),
    ("I", "I"),
    ("I", "E"),
})

def allow_merge(a, b):
    """Return True if tag ``b`` may extend the entity whose last tag is ``a``.

    Tags are "<flag>-<type>" strings such as "B-PER", "I-PER" or "E-PER". A merge
    is allowed only when both tags have the same type and the flag pair is one of
    B->I, B->E, I->I or I->E. A new "B", a closed "E", a type change, or a tag
    without "-" starts a new entity.
    """
    # partition() splits only on the first "-", so types containing "-" stay
    # intact, and a tag without "-" simply gets an empty type instead of raising.
    a_flag, _, a_type = a.partition('-')
    b_flag, _, b_type = b.partition('-')
    return a_type == b_type and (a_flag, b_flag) in _MERGEABLE_FLAG_PAIRS

def divide_entities(ner_results):
    """Split token-level predictions into groups, one group per entity.

    Args:
        ner_results: list of dicts with at least ``index`` (token position) and
            ``entity`` (tag string); order does not matter.

    Returns:
        List of non-empty groups in token order; ``[]`` for empty input.
        A token joins the previous group only if it sits at the very next token
        position and ``allow_merge`` accepts the tag transition. A gap (a token
        that was filtered out, e.g. tagged "O") always ends the entity.
    """
    groups = []
    for item in sorted(ner_results, key=lambda x: x['index']):
        if groups:
            prev = groups[-1][-1]
            if item['index'] == prev['index'] + 1 and allow_merge(prev['entity'], item['entity']):
                groups[-1].append(item)
                continue
        groups.append([item])
    return groups

def merge_entities(same_entities):
    """Collapse one group of token predictions into a single entity dict.

    Args:
        same_entities: non-empty list of token dicts (keys ``entity``, ``score``,
            ``word``, ``start``, ``end``) in token order.

    Returns:
        dict with:
        - ``entity``: the tag with its flag prefix removed;
        - ``score``: the mean token score;
        - ``word``: the tokens concatenated with "##" removed;
        - ``start`` / ``end``: the first token's start and the last token's end.
    """
    first, last = same_entities[0], same_entities[-1]
    scores = [e['score'] for e in same_entities]
    return {
        # Split only on the first "-" so types containing "-" survive intact;
        # a tag without any prefix is returned unchanged.
        'entity': first['entity'].split('-', 1)[-1],
        'score': sum(scores) / len(scores),
        # Tokens are concatenated with no separator.
        'word': ''.join(e['word'].replace('##', '') for e in same_entities),
        'start': first['start'],
        'end': last['end'],
    }

def post_process(ner_results):
    """Turn token-level NER predictions into a list of merged entity dicts.

    Groups come from ``divide_entities`` and each is collapsed by
    ``merge_entities``. Empty groups are skipped, so an input with no entity
    tokens yields ``[]`` instead of failing inside ``merge_entities``.
    """
    return [merge_entities(group) for group in divide_entities(ner_results) if group]

加载模型

# Inference now goes through ONNX Runtime, so only the tokenizer and the config
# (for the label mapping) are reloaded; the PyTorch weights are not needed.
from transformers import AutoConfig, AutoTokenizer

NER_MODEL_PATH = './save_model'
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)

pipe2 = PipeLineOnnx(
    ner_tokenizer,
    "bert-ner.onnx",
    config=ner_config,
)

测试效果

（原文此处为运行结果截图，图片未保留）

测试速度

（原文此处为速度测试截图，图片未保留）

相关推荐

  1. BERT 微调中文 NER 模型

    2024-05-10 11:30:05       14 阅读
  2. 【深度学习】Pytorch模型Onnx

    2024-05-10 11:30:05       16 阅读
  3. bert pytorch模型onnx,并改变输入输出

    2024-05-10 11:30:05       30 阅读
  4. 使用 bert-base-chinese-ner 模型实现中文NER

    2024-05-10 11:30:05       13 阅读
  5. onnx模型转换到rknn脚本

    2024-05-10 11:30:05       7 阅读

最近更新

  1. 获取和设置Spring Cookie

    2024-05-10 11:30:05       0 阅读
  2. Spring——配置说明

    2024-05-10 11:30:05       0 阅读
  3. springboot中在filter中用threadlocal存放用户身份信息

    2024-05-10 11:30:05       0 阅读
  4. LDAP技术解析:打造安全、高效的企业数据架构

    2024-05-10 11:30:05       1 阅读
  5. android 替换设置-安全里面的指纹背景图片

    2024-05-10 11:30:05       1 阅读
  6. Node.js的应用场景

    2024-05-10 11:30:05       1 阅读

热门阅读

  1. 【设计模式学习笔记】设计模式的分类

    2024-05-10 11:30:05       21 阅读
  2. CSS怎样命名才能更好的理解

    2024-05-10 11:30:05       10 阅读
  3. 机器视觉在锂电芯生产中的全方位检测应用

    2024-05-10 11:30:05       11 阅读
  4. redis+SQL server等保测评命令

    2024-05-10 11:30:05       10 阅读
  5. 情感分类学习笔记(2)

    2024-05-10 11:30:05       14 阅读
  6. 列存储数据库之MonetDB

    2024-05-10 11:30:05       13 阅读
  7. 冒泡排序(Bubble Sort)

    2024-05-10 11:30:05       10 阅读