3、TensorRT学习笔记之ONNX转engine

        摘要:主要学习记录了ONNX转TensorRT流程、代码。末尾有完整代码。

3.1 创建TensorRT的日志记录器

log = trt.Logger()

3.2 创建bulider对象 

        使用日志记录器创建 TensorRT Builder 对象,并通过Builder创建network并从该网络生成engine

        其中:trt.OnnxParser(network, log)需要传入两个参数。一个是已创建network,一个是日志记录器

builder = trt.Builder(log)                # 使用日志记录器创建 TensorRT Builder 对象
parser = trt.OnnxParser(network, log)     # 从network生成engine

3.3 设置engine参数

# 创建 Builder Config 对象
config = builder.create_builder_config()            
# 设置 TensorRT 推理引擎使用的最大工作空间大小,单位为字节。指定最大可用显存
config.max_workspace_size = workspace * 1 << 30     

3.4 定义network并加载ONNX解析器

         通过builder创建一个空网络,什么都没有,需要将ONNX的模型结构信息写入创建的空network。

flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network()        # 通过Builder创建network,此时的network还是一个空网络
parser = trt.OnnxParser(network, log)     

         查看ONNX是否解析成功并将ONNX中的模型结构等信息写入network

# 查看是否解析成功,同时将模型结构写进了network
if not parser.parse_from_file(str(onnx)):       
    raise RuntimeError(f'failed to load ONNX file: {onnx}')

3.5 获取网络的输入输出

# 可能不是num_inputs,根据实际情况来。
inputs = [network.get_input(i) for i in range(network.num_inputs)]            
outputs = [network.get_output(i) for i in range(network.num_outputs)]

3.6 动态输入

if dynamic:
    im = torch.zeros(1, 3, *imgsz).to(device)    # 我这儿输入是im = torch.zeros(1,3,640,640)
    if im.shape[0] <= 1:
        # log.warning(f"{trt} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
        print('x')
    profile = builder.create_optimization_profile()
    for inp in inputs:
        profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
    config.add_optimization_profile(profile)

3.7 检查设备是否支持FP16(半精度)推理

其中:

  • builder.platform_has_fast_fp16:用于检查当前设备是否可以进行半精度计算。
  • half:自定义bool参数,用于决定是否半径都推理
  • config.set_flag(trt.BuilderFlag.FP16):set_flag方法来设置config对象的标志,将FP16标志添加到flags中
if builder.platform_has_fast_fp16 and half:
    config.set_flag(trt.BuilderFlag.FP16)

3.8 写入engine,并序列化model

with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
    t.write(engine.serialize())

        如果希望在trt模型中加入classes(其余信息类似)。

with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
    classes = (['person', 'car'])
    add_meta_to_model(t, classes, type='trt')
    t.write(engine.serialize())

3.9 完整代码

import numpy as np
import tensorrt as trt
import torch
import logging

# logger to capture errors, warnings, and other information during the build and inference phases
TRT_LOGGER = trt.Logger()


def build_engine(onnx, dynamic=True, half=True):
    # f = onnx.with_suffix('.engine')
    f = 'trt.engine'
    # 1、创建日志记录器
    log = trt.Logger()
    # 2、创建builder对象
    builder = trt.Builder(log)
    # 3、创建 Builder Config 对象
    config = builder.create_builder_config()
    # 4、将workspace*1 二进制左移30位后的10进制
    workspace = 1
    config.max_workspace_size = workspace * 1 << 30         # 设置 TensorRT 推理引擎使用的最大工作空间大小,单位为字节
    # 5、定义networko并加载ONNX解析器
    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, log)

    if not parser.parse_from_file(str(onnx)):       # 查看是否解析成功
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    # 6、获得网络的输入输出
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]

    # 7.判断是否动态输入
    if dynamic:
        im = torch.zeros(1,3,640,640)
        if im.shape[0] <= 1:
            # log.warning(f"{trt} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
            print('x')
        profile = builder.create_optimization_profile()
        for inp in inputs:
            profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
        config.add_optimization_profile(profile)
    # 判断是否支持FP16推理

    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    # build engine 文件的写入  这里的f是前面定义的engine文件
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        # 序列化model
        t.write(engine.serialize())
    return f, None

if __name__ == '__main__':
    engine, context = build_engine(r'D:\zy\Yolo\yolov8-ZY\yolov8n.onnx')

3.10 遇见的问题

1、AttributeError: 'tensorrt.tensorrt.Builder' object has no attribute 'max_workspace_size'

原因是:tensorrt8.0以上删除了max_workspace_size属性。

  • 降低tensorRT版本到7.x版本
  • 或者如下
config = builder.create_builder_config()            # 创建 Builder Config 对象
config.max_workspace_size = workspace * 1 << 30     # 设置 TensorRT 推理引擎使用的最大工作空间大小,单位为字节

上一篇:2、TensorRT学习笔记之PT转ONNX、可视化ONNX

下一篇:正在学习、持续更新(实战,瑞芯微RK3588部署yolov8检测模型)

参考文章:利用python版tensorRT导出engine【以yolov5为例】_yolov5 export得到的engine和tensorrt的engine-CSDN博客

相关推荐

  1. TensorRT加速推理入门-1:PytorchONNX

    2024-04-26 10:58:02       33 阅读
  2. 模型部署——ONNX模型RKNN

    2024-04-26 10:58:02       37 阅读
  3. Ubuntu下安装ONNXONNX-TensorRT、Protobuf和TensorRT

    2024-04-26 10:58:02       40 阅读
  4. 【深度学习】Pytorch模型Onnx

    2024-04-26 10:58:02       14 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-26 10:58:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-26 10:58:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-26 10:58:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-26 10:58:02       18 阅读

热门阅读

  1. npm详解

    2024-04-26 10:58:02       13 阅读
  2. npm/yarm常用命令

    2024-04-26 10:58:02       11 阅读
  3. 企业网络安全的全方位解决方案

    2024-04-26 10:58:02       11 阅读
  4. 大数据任务运维方案

    2024-04-26 10:58:02       12 阅读
  5. 【13】编写shell-备份mysql数据

    2024-04-26 10:58:02       11 阅读
  6. Vue中嵌套路由(子路由)的使用

    2024-04-26 10:58:02       10 阅读
  7. 前端如何优化工程

    2024-04-26 10:58:02       15 阅读
  8. 基于python的NBA球员数据可视化分析的设计与实现

    2024-04-26 10:58:02       11 阅读