神经网络---网络模型的保存、加载

方式1:结构+参数

保存

import torch
import torchvision
from torch import nn
from torchvision.models import vgg16, VGG16_Weights

vgg16 = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)

# 保存方式1, 模型解构+模型参数
torch.save(vgg16, 'vgg16_1.pth')

加载

from p26_model_svae import *

# 方式1 -》 保存方式1,加载模型
model = torch.load('vgg16.pth')
print(model)

方式1的陷阱
自定义网络结构如下:

import torch
import torchvision
from torch import nn

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        return x
torch.save(Tudui(), 'tudui_1.pth')

在另一个文件加载该模型,会报错
正确的调用格式需要复制原模型的类定义

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        return x
model = torch.load('tudui_1.pth')
print(model)

或者用import

from p26_model_svae import *

model = torch.load('tudui_1.pth')
print(model)

方式2 模型参数(官方推荐)

import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式2 模型参数(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_2.pth')

模型加载(在另一个文件加载)

# 方式2 ,加载模型
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load('vgg16_2.pth'))
# model2 = torch.load('vgg16_2.pth')  #字典形式
print(vgg16)

相关推荐

  1. 神经网络----网络模型保存

    2024-06-06 22:48:08       31 阅读
  2. 神经网络---网络模型保存

    2024-06-06 22:48:08       10 阅读
  3. 神经网络保存-导入

    2024-06-06 22:48:08       8 阅读
  4. Pytorch学习 day12(模型保存

    2024-06-06 22:48:08       22 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-06 22:48:08       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-06 22:48:08       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-06 22:48:08       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-06 22:48:08       20 阅读

热门阅读

  1. 汽车线束搭铁与接地设计原则

    2024-06-06 22:48:08       8 阅读
  2. 双亲委派模型

    2024-06-06 22:48:08       8 阅读
  3. C++构造器设计模式

    2024-06-06 22:48:08       8 阅读
  4. 运维开发详解

    2024-06-06 22:48:08       7 阅读
  5. C++学习笔记

    2024-06-06 22:48:08       7 阅读
  6. 常微分方程 (ODE) 和 随机微分方程 (SDE)

    2024-06-06 22:48:08       12 阅读
  7. 【面试宝藏】Go并发编程面试题

    2024-06-06 22:48:08       6 阅读
  8. Linux学习—Linux环境下的网络设置

    2024-06-06 22:48:08       9 阅读
  9. 【力扣】不同的子序列

    2024-06-06 22:48:08       7 阅读
  10. c time(NULL) time(time_t *p) 区别

    2024-06-06 22:48:08       9 阅读