Converting a BERT PyTorch model to ONNX and changing its input/output shapes

# Example code:
import onnx
import onnx.helper           # onnx.helper.make_opsetid / make_model
import onnx.checker          # onnx.checker.check_model
import onnx.shape_inference  # onnx.shape_inference.infer_shapes


def createNewGraph(inFile):
    """
    Rebuild the ModelProto so that both the default ONNX opset and the
    com.microsoft opset are declared; the fused operators produced by the
    onnxruntime BERT optimizer (EmbedLayerNormalization, Attention, ...)
    live in the com.microsoft domain.
    https://github.com/microsoft/onnxruntime/issues/11783
    """
    model = onnx.load(inFile)
    graph_def = model.graph
    opset_imports = [
        onnx.helper.make_opsetid(domain="", version=17),
        onnx.helper.make_opsetid("com.microsoft", 1),
    ]

    model_def = onnx.helper.make_model(
        graph_def,
        producer_name='test',
        opset_imports=opset_imports)
    return model_def
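
The rebuild matters because onnx.checker may reject a graph whose nodes use an operator domain that is not declared in the model's opset imports. The small helper below is only a sketch (listOpDomains is not part of the original code); it prints the domains actually present in a graph, e.g. the optimized file produced further down:

def listOpDomains(inFile):
    """Sketch (hypothetical helper): print the operator domains used in a graph."""
    m = onnx.load(inFile)
    domains = {node.domain or "ai.onnx (default)" for node in m.graph.node}
    print(domains)  # after BERT fusion this typically includes 'com.microsoft'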

    

def bert():
  """
  https://huggingface.co/bert-base-uncased
  git lfs install
  git clone https://huggingface.co/bert-base-uncased
  pip install onnxruntime==1.13.1   # 3 inputs --> EmbedLayerNormalization
  pip install onnxruntime==1.10.0   # 2 inputs --> EmbedLayerNormalization, 1 input -->
  pip install transformers==4.20.1
  """
  import torch
  # pip install transformers
  from transformers import BertTokenizer, BertModel
  from onnxruntime.transformers.fusion_options import FusionOptions
  from onnxruntime.transformers.optimizer import optimize_model

  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  model = BertModel.from_pretrained("bert-base-uncased")

  # Dummy inputs for tracing; the actual shapes are made dynamic via dynamic_axes below.
  input_ids = torch.randint(0, 30522, (1, 512))             # 30522 = bert-base-uncased vocab size
  token_type_ids = torch.zeros((1, 512), dtype=torch.long)
  attention_mask = torch.ones((1, 512), dtype=torch.long)   # all-ones mask: attend to every token
  axes = {0: "batch_size", 1: "seq_len"}
  axes_o1 = {0: "batch_size"}
  
  torch.onnx.export(
    model,
    (input_ids, attention_mask, token_type_ids),
    # (input_ids, None, token_type_ids),              # 2-input variant without attention_mask
    'bert-base-uncased.onnx',
    do_constant_folding=True,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    # input_names=["input_ids", "token_type_ids"],    # names for the 2-input variant
    # BertModel returns (last_hidden_state, pooler_output); the original output names are kept.
    output_names=["output_start_logits", "output_end_logits"],
    dynamic_axes={
      "input_ids": axes,
      "attention_mask": axes,
      "token_type_ids": axes,
      "output_start_logits": axes,      # [batch_size, seq_len, 768]
      "output_end_logits": axes_o1,     # [batch_size, 768]
    },
    verbose=True,
    opset_version=17,
  )

  in_path = 'bert-base-uncased.onnx'

  # Disable a few fusions so BiasGelu / SkipLayerNormalization are not created.
  fusion_options = FusionOptions('bert')
  fusion_options.enable_bias_gelu = False
  fusion_options.enable_skip_layer_norm = False
  fusion_options.enable_bias_skip_layer_norm = False

  m = optimize_model(
    in_path,
    model_type='bert',
    num_heads=12,
    hidden_size=768,
    opt_level=0,
    optimization_options=fusion_options,
    use_gpu=False
  )

  print(m.get_fused_operator_statistics())
  m.save_model_to_file('bert-base-uncased_optimized.onnx', use_external_data_format=False)
  # Re-declare the opsets (default + com.microsoft) so the fused ops pass onnx.checker.
  model_final = createNewGraph('bert-base-uncased_optimized.onnx')
  onnx.save(model_final, 'bert-base-uncased_final.onnx')
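
After bert() has produced bert-base-uncased_final.onnx, a quick sanity check with onnxruntime confirms that the dynamic axes survive the optimization. This is only a sketch (check_dynamic_model is not in the original code, and it assumes the installed onnxruntime version supports the fused com.microsoft operators):

def check_dynamic_model():
  """Sketch (hypothetical helper): run the fused model on an arbitrary sequence length."""
  import numpy as np
  from onnxruntime import InferenceSession

  sess = InferenceSession('bert-base-uncased_final.onnx', providers=['CPUExecutionProvider'])
  feed = {
    "input_ids": np.random.randint(0, 30522, (1, 128), dtype=np.int64),
    "attention_mask": np.ones((1, 128), dtype=np.int64),
    "token_type_ids": np.zeros((1, 128), dtype=np.int64),
  }
  outputs = sess.run(None, feed)
  print([o.shape for o in outputs])  # expected: (1, 128, 768) and (1, 768)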

#---------------------------------- overwrite the input/output shapes with fixed values:
def changeInDim(model, shape):
  """Overwrite every graph input dimension with the fixed values given in `shape`."""
  inputs = model.graph.input
  for i, inp in enumerate(inputs):
    for j, dim in enumerate(inp.type.tensor_type.shape.dim):
      print(f"dim[{i}][{j}]: {dim}--{dim.dim_value}-->{shape[i][j]}")
      dim.dim_value = shape[i][j]

def changeOtDim(model, shape):
  """Overwrite every graph output dimension with the fixed values given in `shape`."""
  outputs = model.graph.output
  for i, out in enumerate(outputs):
    for j, dim in enumerate(out.type.tensor_type.shape.dim):
      print(f"dim[{i}][{j}]: {dim}--{dim.dim_value}-->{shape[i][j]}")
      dim.dim_value = shape[i][j]

def run_changeInOut():
  inFile = "bert-base-uncased_final.onnx"
  outFile = "bert-base-uncased_shaped.onnx"
  model = onnx.load(inFile)
  # Fix the three inputs to [1, 384] and the two outputs to [1, 384, 768] / [1, 768].
  changeInDim(model, ([1, 384], [1, 384], [1, 384]))
  changeOtDim(model, ([1, 384, 768], [1, 768]))
  inferred_model = onnx.shape_inference.infer_shapes(model)
  onnx.checker.check_model(inferred_model)
  onnx.save(inferred_model, outFile)
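
Finally, a small driver sketch that ties the steps together and prints the now-fixed input/output shapes (file names are the ones used above):

if __name__ == '__main__':
  bert()               # export, fuse, and rebuild with the com.microsoft opset declared
  run_changeInOut()    # freeze inputs to [1, 384] and outputs to [1, 384, 768] / [1, 768]

  shaped = onnx.load("bert-base-uncased_shaped.onnx")
  for vi in list(shaped.graph.input) + list(shaped.graph.output):
    dims = [d.dim_value for d in vi.type.tensor_type.shape.dim]
    print(vi.name, dims)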
