tensflow模型转onnx实践

一、基础知识介绍

1、TensorFlow介绍

TensorFlow™是一个基于数据流编程(dataflow programming)的符号数学系统,被广泛应用于各类机器学习(machine learning)算法的编程实现,其前身是谷歌的神经网络算法库DistBelief [1]。Tensorflow拥有多层级结构,可部署于各类服务器、PC终端和网页并支持GPU和TPU高性能数值计算,被广泛应用于谷歌内部的产品开发和各领域的科学研究 [1-2]。TensorFlow由谷歌人工智能团队谷歌大脑(Google Brain)开发和维护,拥有包括TensorFlow Hub、TensorFlow Lite、TensorFlow Research Cloud在内的多个项目以及各类应用程序接口(Application Programming Interface, API)。自2015年11月9日起,TensorFlow依据阿帕奇授权协议(Apache 2.0 open source license)开放源代码 。

2、keras介绍

Keras是一个由Python编写的开源人工神经网络库,可以作为Tensorflow、Microsoft-CNTK和Theano的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化 。
Keras在代码结构上由面向对象方法编写,完全模块化并具有可扩展性,其运行机制和说明文档有将用户体验和使用难度纳入考虑,并试图简化复杂算法的实现难度 [1]。Keras支持现代人工智能领域的主流算法,包括前馈结构和递归结构的神经网络,也可以通过封装参与构建统计学习模型 。在硬件和开发环境方面,Keras支持多操作系统下的多GPU并行计算,可以根据后台设置转化为Tensorflow、Microsoft-CNTK等系统下的组件 。

3、onnx

ONNX是一种开放格式,专门用于表示机器学习模型。它定义了一套通用的运算符,这些运算符是构建机器学习和深度学习模型的基础单元,同时ONNX还定义了一种通用的文件格式。这些特性使得AI开发者能够跨多种框架、工具、运行时和编译器使用模型。

二、基础环境介绍

实际工作中,模型使用的版本和框架可能各不相同,在做模型转换或者模型迁移工作的过程中,一般先讲各个框架的模型格式转换为通用的中间格式,比如onnx,正如我们所做的,將 TensorFlow框架编写训练的模型转换为onnx格式。

1、框架版本

tensflow版本为1.14,keras版本为2.2.4

三、环境搭建

1、在 anconda当中新建一个虚拟环境,指定python版本为3.7

conda create -n py37_tf14 python=3.7

在这里插入图片描述
检查python和pip版本:

python
pip -V

在这里插入图片描述

2、安装依赖

1)安装tensflow和 keras

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow==1.14.0
Requirement already satisfied: setuptools>=41.0.0 in d:\programs\anaconda3\envs\py37_tf14\lib\site-packages (from tensorboard<1.15.0,>=1.14.0->tensorflow==1.14.0) (65.6.3)
Collecting importlib-metadata>=4.4
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ff/94/64287b38c7de4c90683630338cf28f129decbba0a44f0c6db35a873c73c4/importlib_metadata-6.7.0-py3-none-any.whl (22 kB)
Collecting MarkupSafe>=2.1.1
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9b/c1/9f44da5ca74f95116c644892152ca6514ecdc34c8297a3f40d886147863d/MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl (17 kB)
Collecting typing-extensions>=3.6.4
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ec/6b/63cc3df74987c36fe26157ee12e09e8f9db4de771e0f3404263117e75b95/typing_extensions-4.7.1-py3-none-any.whl (33 kB)
Collecting zipp>=0.5
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/5b/fa/c9e82bbe1af6266adf08afb563905eb87cab83fde00a0a08963510621047/zipp-3.15.0-py3-none-any.whl (6.8 kB)
Installing collected packages: tensorflow-estimator, zipp, wrapt, typing-extensions, termcolor, six, protobuf, numpy, MarkupSafe, grpcio, gast, astor, absl-py, werkzeug, keras-preprocessing, importlib-metadata, h5py, google-pasta, markdown, keras-applications, tensorboard, tensorflow
Successfully installed MarkupSafe-2.1.3 absl-py-1.4.0 astor-0.8.1 gast-0.5.4 google-pasta-0.2.0 grpcio-1.57.0 h5py-3.8.0 importlib-metadata-6.7.0 keras-applications-1.0.8 keras-preprocessing-1.1.2 markdown-3.4.4 numpy-1.21.6 protobuf-4.24.2 six-1.16.0 tensorboard-1.14.0 tensorflow-1.14.0 tensorflow-estimator-1.14.0 termcolor-2.3.0 typing-extensions-4.7.1 werkzeug-2.2.3 wrapt-1.15.0 zipp-3.15.0

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple keras==2.2.4

Requirement already satisfied: keras-applications>=1.0.6 in d:\programs\anaconda3\envs\py37_tf14\lib\site-packages (from keras==2.2.4) (1.0.8)
Requirement already satisfied: h5py in d:\programs\anaconda3\envs\py37_tf14\lib\site-packages (from keras==2.2.4) (3.8.0)
Installing collected packages: scipy, pyyaml, keras
Successfully installed keras-2.2.4 pyyaml-6.0.1 scipy-1.7.3

测试tensflow是否安装完成

import tensorflow as tf
hello = tf.constant('hello Tensorflow')
sess = tf.Session()
print(sess.run(hello))

报错:

TypeError: Descriptors cannot not be created directly

需要降低 protobuf版本到3.19.0

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple protobuf==3.19.0

成功运行测试代码!

2)安装转换时需要的依赖

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple h5py
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnx
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple keras2onnx
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tf2onnx

3)检查 安装包

(py37_tf14) PS C:\Users\lenovo> pip list
Package              Version
-------------------- ---------
absl-py              1.4.0
astor                0.8.1
certifi              2022.12.7
charset-normalizer   3.2.0
coloredlogs          15.0.1
fire                 0.5.0
flatbuffers          23.5.26
gast                 0.5.4
google-pasta         0.2.0
grpcio               1.57.0
h5py                 2.10.0
humanfriendly        10.0
idna                 3.4
importlib-metadata   6.7.0
Keras                2.2.4
Keras-Applications   1.0.8
Keras-Preprocessing  1.1.2
keras2onnx           1.7.0
Markdown             3.4.4
MarkupSafe           2.1.3
mpmath               1.3.0
numpy                1.21.6
onnx                 1.8.0
onnxconverter-common 1.13.0
onnxruntime          1.14.1
onnxtk               0.0.1
packaging            23.1
pip                  22.3.1
protobuf             3.20.3
pyreadline           2.1
PyYAML               6.0.1
requests             2.31.0
scipy                1.7.3
setuptools           65.6.3
six                  1.16.0
sympy                1.10.1
tensorboard          1.14.0
tensorflow           1.14.0
tensorflow-estimator 1.14.0
termcolor            2.3.0
tf2onnx              1.15.1
typing_extensions    4.7.1
urllib3              2.0.4
Werkzeug             2.2.3
wheel                0.38.4
wincertstore         0.2
wrapt                1.15.0
zipp                 3.15.0
(py37_tf14) PS C:

三、转换方法问题及解决方案总结

1、将keras的h5模型转化为onnx

import tensorflow as tf



from keras.models import load_model
import onnx
import os
os.environ["TF_KERAS"]='1' 

import keras2onnx
keras_model = tf.keras.models.load_model('model/ar_crnn.h5',compile=False)

onnx_model = keras2onnx.convert_keras(keras_model, keras_model.name)

# tf2onnx.save_model(onnx_model, "ar_crnn.onnx")


'''
报错信息
File "D:\Programs\anaconda3\envs\py37_tf14\lib\site-packages\onnxconverter_common\onnx_ops.py", line 815, in apply_reshape
raise ValueError('There can only be one -1 in the targeted shape of a Reshape but got %s' % desired_shape)
ValueError: There can only be one -1 in the targeted shape of a Reshape but got [-1, -1, 1152]
'''

暂时没有找到解决方案,github官方说已经解决,但提升到最新版本还是报错!

参考官方文档

https://pypi.org/project/keras2onnx/

2、使用tf2onnx 转换为onnx

import tensorflow as tf
import tf2onnx

# Load the Keras model
keras_model = tf.keras.models.load_model('model/ar_crnn.h5')

# Convert the model to ONNX format
onnx_model = tf2onnx.convert.from_keras(keras_model)

# Save the ONNX model to a file
# tf2onnx.save_model(onnx_model, 'my_model.onnx')



'''
raise ValueError("Tensor name '{0}' is invalid.".format(node.input[0]))
ValueError: Tensor name 'batch_normalization_1/cond/ReadVariableOp/Switch:1' is invalid.
'''

没有batch_normalization算子,可能因为tf2onnx不支持keras的h5模型

3、尝试通过 tensorflow模型中介的方式转换:h5→tf→onnx

import tensorflow as tf

model_path = './model/ar_crnn.h5'                    # 模型文件
model = tf.keras.models.load_model(model_path)
model.save('tfmodel', save_format='tf')

'''
 'Saving the model as SavedModel is not supported in TensorFlow 1.X'
'''

报错,表示tensorflow1.x不支持保存tf文件,为避免升级tensorflow2.x后带来其他版本不兼容的问题,未再对这种方式做尝试

4、使用h5py文件直接转成onnx的方式,暂时没有探索成功

import h5py
import onnx
from onnx import helper, shape_inference
import json

model_file = 'model/ar_crnn.h5'
# 打开h5文件
with h5py.File(model_file, 'r') as f:
    # 获取所有子集
    model_config = json.loads(f.attrs['model_config'].decode('utf-8'))
    
    print(model_config)
    weights = []
    f.visit(lambda name: weights.append(name) if isinstance(f[name],h5py.Dataset) else None)


onnx_model = helper.make_model(model_config)
'''
TypeError: Parameter to CopyFrom() must be instance of same class: expected onnx.GraphProto got dict.
'''

报错,在网上的例子中,直接使用f.attrs[‘model_config’]的方法获取的模型结构就可以直接作为helper.make_model()的参数,初始化模型; 但是实际运行时,helper.make_model()期待的是onnx.GraphProto这种类型的入参,在做onnx.GraphProto实例化的时候,仍然缺少必要参数,仍然失败。

5、使用pb中介转换的方式: h5 → pb → onnx, 成功!

1)、下载工具类:https://link.zhihu.com/?target=https%3A//github.com/amir-abdi/keras_to_tensorflow

因为没有官方工具,所以这里下载的是在git开源的工具

2)、使用工具类进行模型转换, 将h5模型转换为pb模型

python keras_to_tensorflow.py  --input_model="model\ar_crnn.h5"  --output_model="crnn.pb"

3)、使用tf2onnx,将pb转换为onnx

python -m tf2onnx.convert --graphdef crnn.pb --output ar_crnn_1.onnx --inputs the_input:-1 --outputs output_new/truediv:0

四、转换完成后,对比精度,检查onnx

编写精度测试代码: 随机生成一张图片, 分别用keras模型和 onnx模型进行推理,计算差值, 求最大值

# 读取h5模型和 onnx模型进行推理,对比结果
import onnxruntime
from keras.models import load_model
import numpy as np
def test_h5_onnx_precision(h5_path, onnx_path, batch_size):
    # 读取h5模型
    keras_model = load_model(h5_path,compile=False)
    input_data = np.random.random(size=(batch_size, 48, 16, 3)).astype(np.float32)
    h5_res = keras_model.predict([input_data])
    # onnx推理
    ort_session = onnxruntime.InferenceSession(onnx_path)
    model_inputs = ort_session.get_inputs()
    ort_inputs = {model_inputs[0].name: input_data}
    onnx_output = ort_session.run(['output_new/truediv:0'], ort_inputs)[0]

    res = h5_res - onnx_output
    print(res)
    print(res.max())

test_h5_onnx_precision('model/ar_crnn.h5', 'ar_crnn.onnx', 1000)

相关推荐

  1. tensorflow | onnx模型pb

    2024-04-04 19:34:02       20 阅读
  2. 模型部署之——ONNX模型RKNN

    2024-04-04 19:34:02       37 阅读
  3. 【深度学习】Pytorch模型Onnx

    2024-04-04 19:34:02       14 阅读
  4. bert pytorch模型onnx,并改变输入输出

    2024-04-04 19:34:02       28 阅读
  5. ONNX模型

    2024-04-04 19:34:02       6 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-04 19:34:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-04 19:34:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-04 19:34:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-04 19:34:02       18 阅读

热门阅读

  1. 【设计模式】-单例模式

    2024-04-04 19:34:02       17 阅读
  2. Qt中实现域(Unix)套接字通信

    2024-04-04 19:34:02       20 阅读
  3. LeetCode-热题100:121. 买卖股票的最佳时机

    2024-04-04 19:34:02       16 阅读
  4. 2024年最新github之PHP语言开源项目top50排行榜

    2024-04-04 19:34:02       19 阅读
  5. 久菜盒子|留学|推荐信|international trade(国际贸易)

    2024-04-04 19:34:02       16 阅读
  6. 穿透 雪崩 击穿

    2024-04-04 19:34:02       20 阅读
  7. FastGpt流程

    2024-04-04 19:34:02       16 阅读