PyTorch中保存模型的两种方式


一、状态字典(State Dictionary)

这种保存形式将模型的参数保存为一个字典,其中包含了所有模型的权重和偏置等参数。状态字典保存了模型在训练过程中学到的参数值,而不包含模型的结构。可以使用这个字典来加载模型的参数,并将其应用于相同结构的模型。
在 PyTorch 中,您可以使用 torch.save() 函数将模型的状态字典保存到文件中,例如:

torch.save(model.state_dict(), 'model.pth')

然后,可以使用 torch.load() 函数加载状态字典并将其应用于相同结构的模型:

model = MyModel()  # 创建模型对象
model.load_state_dict(torch.load('model.pth'))

这种保存形式非常适用于仅保存和加载模型的参数,而不需要保存和加载模型的结构。

二、序列化模型(Serialized Model)

这种保存形式将整个模型(包括模型的结构、参数等)保存为一个文件。序列化模型保存了模型的完整信息,可以完全恢复模型的状态,包括模型的结构、权重、偏置以及其他相关参数。
在 PyTorch 中,您可以使用 torch.save() 函数直接保存整个模型对象,例如:

torch.save(model, 'model.pth')

然后,您可以使用 torch.load() 函数加载整个序列化模型:

model = torch.load('model.pth')

这种保存形式适用于需要保存和加载完整模型信息的情况,包括模型的结构和参数。

三、示例代码

import torch

class LinearNet(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features= 5, out_features=5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=output_size, bias=True)
        )

    def forward(self,x):
        return self.net(x)

square_net = LinearNet(1,1)

# square_net.load_state_dict(torch.load('weight.pth'))  #直接加载已经训练好的权重

if __name__ == '__main__':

    # print(square_net(torch.tensor([3.16],dtype=torch.float32)))
    # save 方式1
    torch.save(square_net.state_dict(), "./w1.pth")
    my_state_dict = torch.load("./w1.pth")
    print("纯state_dict:\n", my_state_dict)
    print("type:", type(my_state_dict))

    # save 方式2
    torch.save(square_net, "./w2.pth")
    my_state_dict = torch.load("./w2.pth")
    print("\n\n模型结构:\n", my_state_dict)
    print("type:", type(my_state_dict))


    # 执行结果
    '''
    纯state_dict:
    OrderedDict([('net.0.weight', tensor([[ 0.0820],
            [-0.6923],
            [ 0.5066],
            [-0.8931],
            [ 0.0460]])), ('net.0.bias', tensor([ 0.1455,  0.5106,  0.2347,  0.4903, -0.6838])), ('net.2.weight', tensor([[-0.4055, -0.2721,  0.3770, -0.2285,  0.3025],
            [-0.0416,  0.0133, -0.3834, -0.2151,  0.1454],
            [ 0.0749, -0.3664, -0.1901, -0.2829,  0.3957],
            [-0.3567,  0.2668,  0.3343, -0.3351, -0.3808],
            [ 0.4375,  0.1000,  0.1185,  0.2295, -0.3997]])), ('net.2.bias', tensor([-0.2405, -0.2751,  0.1928,  0.3970, -0.0005])), ('net.4.weight', tensor([[-0.4388, -0.2654,  0.3038,  0.2008,  0.0381]])), ('net.4.bias', tensor([0.1847]))])


    模型结构:
    LinearNet(
    (net): Sequential(
        (0): Linear(in_features=1, out_features=5, bias=True)
        (1): Sigmoid()
        (2): Linear(in_features=5, out_features=5, bias=True)
        (3): Sigmoid()
        (4): Linear(in_features=5, out_features=1, bias=True)
    )
    )
    '''

相关推荐

  1. PyTorch保存模型方式

    2024-02-23 16:50:03       54 阅读
  2. Pytorch保存模型方法

    2024-02-23 16:50:03       30 阅读
  3. MongoDB——模糊查询方法

    2024-02-23 16:50:03       58 阅读
  4. Lumeical Script------Script Prompt 输出方式

    2024-02-23 16:50:03       57 阅读
  5. Unity打印信息方式

    2024-02-23 16:50:03       50 阅读
  6. Docker部署flink集群方式

    2024-02-23 16:50:03       50 阅读
  7. XML常用模式定义方式

    2024-02-23 16:50:03       18 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-02-23 16:50:03       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-02-23 16:50:03       100 阅读
  3. 在Django里面运行非项目文件

    2024-02-23 16:50:03       82 阅读
  4. Python语言-面向对象

    2024-02-23 16:50:03       91 阅读

热门阅读

  1. Cpython和Jpython区别

    2024-02-23 16:50:03       52 阅读
  2. 中国工业废水处理行业报告

    2024-02-23 16:50:03       42 阅读
  3. Asp.Net web 文件服務快速搭建

    2024-02-23 16:50:03       48 阅读
  4. 【机器学习】机器学习是什么?

    2024-02-23 16:50:03       51 阅读
  5. SQL中为什么不要使用1=1

    2024-02-23 16:50:03       53 阅读
  6. HW面试常见知识点(新手认识版)

    2024-02-23 16:50:03       49 阅读
  7. android recyclerview 中的animation滚动中动画停止了?

    2024-02-23 16:50:03       50 阅读
  8. Android自编译Pixel3内核加入KernelSU

    2024-02-23 16:50:03       53 阅读
  9. 配置docker 支持GPU方法(Nvidia GPU)

    2024-02-23 16:50:03       56 阅读
  10. Cookies

    2024-02-23 16:50:03       42 阅读