解决pytorch训练的过程中内存一直增加的问题

来自:解决pytorch训练的过程中内存一直增加的问题 - 知乎

pytorch训练中内存一直增加的原因(部分)

  • 代码中存在累加loss,但每步的loss没加item()
import torch
import torch.nn as nn
from collections import defaultdict

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = nn.Linear(100, 400).to(device)
criterion = nn.L1Loss(reduction='mean').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_loss = defaultdict(float)
eval_loss = defaultdict(float)

for i in range(10000):
    model.train()
    x = torch.rand(50, 100, device=device)
    y_pred = model(x) # 50 * 400
    y_tgt = torch.rand(50, 400, device=device)

    loss = criterion(y_pred, y_tgt)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 会导致内存一直增加,需改为train_loss['loss'] += loss.item()
    train_loss['loss'] += loss

    if i % 100 == 0:
        train_loss = defaultdict(float)
        model.eval()
        x = torch.rand(50, 100, device=device)
        y_pred = model(x) # 50 * 400

        y_tgt = torch.rand(50, 400, device=device)
        loss = criterion(y_pred, y_tgt)

        # 会导致内存一直增加,需改为eval_loss['loss'] += loss.item()
        eval_loss['loss'] += loss
以上代码会导致内存占用越来越大,解决的方法是:train_l oss['loss'] += loss.item() 以及 eval_loss['loss'] += loss.item()。值得注意的是,要复现内存越来越大的问题,模型中需要切换model.train() 和 model.eval(),train_loss以及eval_loss的作用是保存模型的平均误差(这里是累积误差),保存到tensorboard中。

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2023-12-17 14:50:02       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2023-12-17 14:50:02       101 阅读
  3. 在Django里面运行非项目文件

    2023-12-17 14:50:02       82 阅读
  4. Python语言-面向对象

    2023-12-17 14:50:02       91 阅读

热门阅读

  1. Python学习笔记第七十七天(OpenCV绘画功能)

    2023-12-17 14:50:02       57 阅读
  2. QEMU源码全解析 —— virtio(12)

    2023-12-17 14:50:02       49 阅读
  3. 深度学习常用数学知识

    2023-12-17 14:50:02       60 阅读
  4. 509.斐波那契数

    2023-12-17 14:50:02       63 阅读
  5. 自建私有git进行项目发布

    2023-12-17 14:50:02       52 阅读
  6. 八股文打卡day1——计算机网络(1)

    2023-12-17 14:50:02       54 阅读
  7. 452. Minimum Number of Arrows to Burst Balloons

    2023-12-17 14:50:02       64 阅读