转换 pytorch 格式模型为 caffe格式模型 pth2caffemodel

基于 GitHub xxradon/PytorchToCaffe 源码,修改 example\resnet_pytorch_2_caffe.py 如下

import os
import sys
sys.path.insert(0, '.')

import torch
from torch.autograd import Variable
from torchvision.models import resnet
import pytorch_to_caffe


"""
    resnet models in pytorch format can be downloaded from
        ‘resnet18’: ‘https://download.pytorch.org/models/resnet18-5c106cde.pth’,
        ‘resnet34’: ‘https://download.pytorch.org/models/resnet34-333f7ec4.pth’,
        ‘resnet50’: ‘https://download.pytorch.org/models/resnet50-19c8e357.pth’,
        ‘resnet101’: ‘https://download.pytorch.org/models/resnet101-5d3b4d8f.pth’,
        ‘resnet152’: ‘https://download.pytorch.org/models/resnet152-b121ed2d.pth’,

"""

def show_usage(cmd):
    print( "Usage:" )
    print(   "    ", cmd, " <pytorch-model-name>  <pytorch-model-filename.pth>" )
    
def main(cmd, argv):
    if( len(argv) < 2 ):
        print( "Error! Parameter is not enough." )
        show_usage( cmd )
        exit( 1 )

    model_name = argv[0]
    input_file = argv[1]

    pure_path = os.path.splitext( input_file )
    file_name = pure_path[0]
    
    print( " model  : ",  model_name )
    print( " input  : ",  input_file )
    print( " output : ",  '{}.prototxt'.format(file_name) )
    print( "          ",  '{}.caffemodel'.format(file_name) )
    
    
    input=torch.ones([1,3,224,224])
    match model_name:
        case "resnet18":
            resnet_x = resnet.resnet18()
        case "resnet34":
            resnet_x = resnet.resnet34()
        case "resnet50":
            resnet_x = resnet.resnet50()
        case "resnet101":
            resnet_x = resnet.resnet101()
        case "resnet152":
            resnet_x = resnet.resnet152()
        case _:
            print( "Error! Unknown model name : ",  model_name )
            show_usage( cmd )
            exit( 2 )

    if( False == os.path.isfile(input_file) ):
        print( "Error! Cannot find input file : ", input_file )
        show_usage( cmd )
        exit( 3 )

    checkpoint = torch.load(input_file)
    
    resnet_x.load_state_dict(checkpoint)
    resnet_x.eval()
    pytorch_to_caffe.trans_net(resnet_x,input,model_name)
    pytorch_to_caffe.save_prototxt('{}.prototxt'.format(file_name))
    pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(file_name))
    
    
if __name__ == "__main__":
   main(sys.argv[0], sys.argv[1:])

脚本依赖pytorch,安装之。

pip install torch

运行中遇到 protobuf 版本过高问题,降级处理

pip install -U protobuf==3.20 

下载 resnet model文件后,执行脚本

python example\resnet_pytorch_2_caffe.py  resnet152  resnet152-b121ed2d.pth

相关推荐

  1. 转换 pytorch 格式模型 caffe格式模型 pth2caffemodel

    2023-12-12 04:16:05       38 阅读
  2. pytorch模型caffe模型

    2023-12-12 04:16:05       33 阅读
  3. pdf格式转换txt格式

    2023-12-12 04:16:05       20 阅读
  4. 深度学习的模型转换(.pt转换.engine)

    2023-12-12 04:16:05       14 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-12 04:16:05       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-12 04:16:05       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-12 04:16:05       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-12 04:16:05       18 阅读

热门阅读

  1. Linux 常见面试题 Day8

    2023-12-12 04:16:05       36 阅读
  2. redis实际应用实现合集

    2023-12-12 04:16:05       34 阅读
  3. 【场景测试用例】网站

    2023-12-12 04:16:05       23 阅读
  4. Mysql高频面试题11道

    2023-12-12 04:16:05       41 阅读
  5. Mysql8和Oracle实际项目中递归查询树形结构

    2023-12-12 04:16:05       39 阅读
  6. C#深入.NET平台的软件系统分层开发

    2023-12-12 04:16:05       29 阅读