Pytorch-06 使用GPU加速计算

要在PyTorch中使用GPU加速计算,需要将模型和数据移动到GPU上进行处理。以下是上一节演示修改后的示例代码,展示了如何在训练过程中利用GPU加速计算:

import torch
import torch.nn as nn
import torch.optim as optim
import time

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 用于比较GPU对性能的提升程度,去掉注释时使用CPU
# device = torch.device("cpu")

# 定义一个简单的神经网络模型,并将其移动到GPU
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(60, 30)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(30, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


model = SimpleModel().to(device)

# 定义损失函数
criterion = nn.MSELoss()

# 创建Adam优化器
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 准备训练数据,并将其移动到GPU
input_data = torch.randn(100000, 60).to(device)
target_data = torch.randn(100000, 1).to(device)

# 训练模型
time_start = time.time()
for epoch in range(1000):
    optimizer.zero_grad()
    output = model(input_data)
    loss = criterion(output, target_data)
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')
time_stop = time.time()
print(f"time_spend = {time_stop - time_start} s")

在这个修改后的示例中,我们首先检查GPU是否可用,并将模型和训练数据移动到GPU设备上。通过调用.to(device)方法,模型和数据都会被转移到GPU上进行计算。接着,训练过程中的计算将在GPU上加速进行,提高训练效率。

而在我的机器上面,GPU训练时间输出为time_spend = 1.8002188205718994 s,CPU训练时间输出为time_spend = 11.982393026351929 s

相关推荐

  1. Pytorch-06 使用GPU加速计算

    2024-05-25 21:14:35       11 阅读
  2. PyTorch中常用的工具(5)使用GPU加速:CUDA

    2024-05-25 21:14:35       38 阅读
  3. 在MATLAB中进行并行计算GPU加速

    2024-05-25 21:14:35       21 阅读
  4. ubuntu22.04使用conda安装pytorch(cpu及gpu版本)

    2024-05-25 21:14:35       31 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-05-25 21:14:35       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-25 21:14:35       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-25 21:14:35       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-25 21:14:35       18 阅读

热门阅读

  1. Pytorch-07 完整训练测试过程

    2024-05-25 21:14:35       10 阅读
  2. c++翻转一个无符号数的二进制位

    2024-05-25 21:14:35       12 阅读
  3. C++11std::bind的简单使用

    2024-05-25 21:14:35       9 阅读
  4. el-select 组件获取整个对象

    2024-05-25 21:14:35       12 阅读
  5. K8S Secret管理之SealedSecrets

    2024-05-25 21:14:35       9 阅读
  6. c++入门

    c++入门

    2024-05-25 21:14:35      12 阅读
  7. 分布式和集群区别

    2024-05-25 21:14:35       7 阅读
  8. 华为校招机试 - 最久最少使用缓存(20240508)

    2024-05-25 21:14:35       10 阅读