神经网络----网络模型的加载及保存

一、加载网络模型

torch.load()函数:

PyTorch 中用于加载保存的模型或张量的函数

例如:

torch.load(checkpoint_path, map_location=device)

其中, checkpoint_path 是保存模型参数的文件路径,

            map_location=device 用于将模型加载到指定的设备上。如果你在训练时使用了 GPU,并

            且想在 CPU 上进行推断或继续训练,这就很有用。map_location 参数告诉 PyTorch 将模 

            型参数加载到指定的设备上。

这句代码的作用:将路径在checkpoint_path的模型参数文件加载到设备device上面

model.load_state_dict( )函数:

将状态字典加载到模型的方法

例如:

model.load_state_dict(torch.load(checkpoint_path, map_location=device))

这行代码的目的是将从文件加载的预训练模型的状态字典应用到指定的模型 model 中。加载后,model 就包含了预训练模型的参数,可以在之后用于推断或继续训练。

其中,预训练模型的状态字典是由torch.load()函数加载的,model.load_state_dict()函数将状态字典应用到模型model上面。

二、保存网络模型

torch.save() 函数用于将对象保存到文件中,以便之后可以使用 torch.load() 函数加载它

一般用于保存模型、张量、字典等 PyTorch 对象

torch.save( )函数的基本用法:

torch.save(obj, file_path)
  • obj: 要保存的Pytorch对象,可以是模型、张量、字典等
  • file_path: 要保存到的文件路径,可以是相对路径或绝对路径

例如,保存模型参数到文件的代码可能如下所示:

import torch

# 假设 model 是 PyTorch 模型
model = ...

# 假设 file_path 是保存文件的路径
file_path = 'model.pth'

# 使用 torch.save() 保存模型参数到文件
torch.save(model.state_dict(), file_path)

在上述示例中,model.state_dict() 返回模型的参数状态字典,它包含了模型的所有可学习参数。这个字典可以通过 torch.load() 函数加载,用于初始化模型或进行模型的迁移学习等任务。 

例如,

torch.save(model.module.state_dict() if hasattr(model, "module") else model.state_dict(), os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch)))

model.module.state_dict() if hasattr(model, "module") else model.state_dict()

其中,

model.state_dict() 返回模型的当前参数状态字典

有些模型在 多 GPU  或分布式训练中可能使用了  nn.DataParallel  封装,导致模型的顶层包装是 nn.DataParallel 对象。如果是这样,那么需要使用 model.module.state_dict() 来获取实际的模型参数状态字典。这里通过 hasattr(model, "module") 来检查模型是否有 module 属性,如果有,则使用 model.module.state_dict(),否则使用 model.state_dict()

os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch))

其中,

os.path.join() 用于拼接文件路径,将模型保存在指定的目录下

output_dir_epoch 是保存模型的目录

{0}_model.pth'.format(epoch) 生成保存文件的名称,其中 {0} 会被替换为当前的 epoch 数字,确保每个模型文件都有唯一的名称,与训练的时期相关

这行代码的作用是将当前模型的参数状态字典保存到一个以 epoch 号命名的文件中

相关推荐

  1. 神经网络----网络模型保存

    2023-12-08 07:10:06       46 阅读
  2. 神经网络---网络模型保存

    2023-12-08 07:10:06       29 阅读
  3. 神经网络保存-导入

    2023-12-08 07:10:06       26 阅读
  4. Pytorch学习 day12(模型保存

    2023-12-08 07:10:06       41 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2023-12-08 07:10:06       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2023-12-08 07:10:06       100 阅读
  3. 在Django里面运行非项目文件

    2023-12-08 07:10:06       82 阅读
  4. Python语言-面向对象

    2023-12-08 07:10:06       91 阅读

热门阅读

  1. ffmpeg学习日记619-指令-透明通道视频相关指令

    2023-12-08 07:10:06       59 阅读
  2. 低代码:美味膳食或垃圾食品?

    2023-12-08 07:10:06       57 阅读
  3. python对py文件加密

    2023-12-08 07:10:06       60 阅读
  4. python中的配置config模块

    2023-12-08 07:10:06       59 阅读
  5. C# 异步

    2023-12-08 07:10:06       52 阅读
  6. 2023-简单点-python的多路复用小例子

    2023-12-08 07:10:06       66 阅读
  7. 在 CentOS 或 Red Hat 系统上安装 Citus 组件

    2023-12-08 07:10:06       60 阅读