小土堆-pytorch-神经网络-网络模型的保存和读取13_笔记

保存与读取方式一:

创建2个python空文件,模拟保存和读取
保存:

import torch
import torchvision
vgg16=torchvision.models.vgg16(weights=False)
# 保存方式一: 模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")

保存完成后,会出现这样的文件
在这里插入图片描述
读取:

import torch
# 方式一 保存方式一,加载模型
model=torch.load("vgg16_method1.pth")
print(model)

运行结果截图:
在这里插入图片描述

保存与读取方式二:

保存:

# 保存方式二: 以字典类型保存,保存它的参数(官方推荐) 模型参数
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

在这里插入图片描述
读取:

# 方式二:加载模型
vgg16=torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

运行结果截图:
在这里插入图片描述
方式一的陷阱:

from torch import nn


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1=nn.Conv2d(3,64,3)
        
    def forward(self,x):
        x=self.conv1(x)
        return x
    
tudui=Tudui()
torch.save(tudui,"tuidui_method1.pth")

错误的加载

# 用原方式一的方式加载
model=torch.load('torch_method.pth')
print(model)

正确的加载

# 导入包
form model_save import *
# 然后正常加载

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2023-12-06 12:18:01       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-06 12:18:01       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-06 12:18:01       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-06 12:18:01       20 阅读

热门阅读

  1. 【蓝桥杯】马的遍历

    2023-12-06 12:18:01       17 阅读
  2. Spring Boot项目Service类单元测试自动生成

    2023-12-06 12:18:01       34 阅读
  3. 网络编程HTTP协议进化史

    2023-12-06 12:18:01       46 阅读
  4. 计算机设计大赛 选题推荐

    2023-12-06 12:18:01       44 阅读
  5. 【Rust】结构体与枚举

    2023-12-06 12:18:01       31 阅读
  6. 嵌入式C语言中的关键字volatile

    2023-12-06 12:18:01       32 阅读
  7. Quartus II 13.1入门使用方法

    2023-12-06 12:18:01       32 阅读