学习pytorch17 pytorch模型保存及加载

pytorch模型保存及加载

代码

import torch
import torchvision


vgg16 = torchvision.models.vgg16(pretrained=False)

# 1. save model 1   保存模型结构及模型参数
torch.save(vgg16, './vgg16_save1.model')

# 2. save model 2   只保存模型参数 比第一种保存方法保存的文件要小
torch.save(vgg16.state_dict(), './vgg16_save2.model')

# 3. load model 1
vgg16_load1 = torch.load('./vgg16_save1.model')
print(vgg16_load1)  # 打印的是模型网络结构

# 3. load model 2
vgg16_load2 = torch.load('./vgg16_save2.model')
print(vgg16_load2)  # 打印的是模型参数
# 将参数导入到网络
vgg16.load_state_dict(vgg16_load2)
print(vgg16)

# 5. 保存模型方式1的陷阱
# 当用方法1导入模型的时候,模型结构是要已知的
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
# class MySeq2(nn.Module):
#     def __init__(self):
#         super(MySeq2, self).__init__()
#         self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
#                                  MaxPool2d(2),
#                                  Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
#                                  MaxPool2d(2),
#                                  Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
#                                  MaxPool2d(2),
#                                  Flatten(),
#                                  Linear(1024, 64),
#                                  Linear(64, 10)
#                                  )
#
#     def forward(self, x):
#         x = self.model1(x)
#         return x
# myseq2 = MySeq2()
# torch.save(myseq2, 'myseq_self.model')
# 当用方法1导入模型的时候,模型结构是要已知的 否则就会报下面的错误 可以在代码里重新定义 但一般都是写在另一个单独的文件里面 比如上面注释的模型结构是前面已经写在p19_nn_seq 文件里面的,执行了模型保存
# AttributeError: Can't get attribute 'MySeq2' on <module '__main__' from 'C:/工作文档/learn_pytorch/p23_save_load_model.py'>
from p19_nn_seq import *
myseq2 = torch.load('myseq_self.model')
print(myseq2)

执行结果

只打印模型参数
在这里插入图片描述
打印模型结构,在调试模式下 可以在feature–保护属性–models–0–weight下看到模型参数
在这里插入图片描述
自己写过的模型文件保存后加载
在这里插入图片描述

相关推荐

  1. Pytorch学习 day12模型保存

    2023-12-05 16:14:10       20 阅读
  2. 深度学习-Pytorch如何保存模型

    2023-12-05 16:14:10       38 阅读
  3. pytorch保存模型以及如何load部分参数

    2023-12-05 16:14:10       21 阅读
  4. PyTorch模型方法详解

    2023-12-05 16:14:10       35 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-05 16:14:10       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-05 16:14:10       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-05 16:14:10       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-05 16:14:10       18 阅读

热门阅读

  1. mybatis中<association> 和 <collection>

    2023-12-05 16:14:10       33 阅读
  2. 嵌入式硬件基础知识——1

    2023-12-05 16:14:10       35 阅读
  3. mySQL踩坑记录

    2023-12-05 16:14:10       43 阅读
  4. SpringDocConfiguration

    2023-12-05 16:14:10       30 阅读
  5. Linux CenTOS命令备忘

    2023-12-05 16:14:10       36 阅读
  6. 【android开发-12】android中ListView的详细用法介绍

    2023-12-05 16:14:10       27 阅读
  7. openssl生成ssl证书

    2023-12-05 16:14:10       40 阅读