保存与读取方式一:
创建2个python空文件,模拟保存和读取
保存:
import torch
import torchvision
vgg16=torchvision.models.vgg16(weights=False)
# 保存方式一: 模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")
保存完成后,会出现这样的文件
读取:
import torch
# 方式一 保存方式一,加载模型
model=torch.load("vgg16_method1.pth")
print(model)
运行结果截图:
保存与读取方式二:
保存:
# 保存方式二: 以字典类型保存,保存它的参数(官方推荐) 模型参数
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
读取:
# 方式二:加载模型
vgg16=torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)
运行结果截图:
方式一的陷阱:
from torch import nn
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1=nn.Conv2d(3,64,3)
def forward(self,x):
x=self.conv1(x)
return x
tudui=Tudui()
torch.save(tudui,"tuidui_method1.pth")
错误的加载
# 用原方式一的方式加载
model=torch.load('torch_method.pth')
print(model)
正确的加载
# 导入包
form model_save import *
# 然后正常加载