pytorch-RNN实战-正弦曲线预测

1. 正弦数据生成

曲线如下图:
在这里插入图片描述
代码如下图:

  • 50个点构成一个正弦曲线
  • 随机生成一个0~3之间的一个值(随机的原因是防止每次都从相同的点开始,50个点的正弦曲线一样,被模型记住),值的范围区间是[start, start+10]
  • 输入x范围[0,48],预测值y范围是[1,49]

在这里插入图片描述

2. 构建网络

下图是构建的网络,注意out维度扩展出一个维度,是为了和y维度一致
在这里插入图片描述

3. 训练

loss计算采用均方差MSE,优化器采用Adam
注意:hidden_prev的自更新
在这里插入图片描述

4. 预测

预测是循环一个点一个点的预测,每次预测的点的结果作为下次点的输入,直到预测出全部点,放到predictions中。
input = x[:,0,:] 去掉了x[1,seq,1]中的seq维度,变成[1,1]
在这里插入图片描述

5. 完整代码

import  numpy as np
import  torch
import  torch.nn as nn
import  torch.optim as optim
from    matplotlib import pyplot as plt


num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr=0.01



class Net(nn.Module):

    def __init__(self, ):
        super(Net, self).__init__()

        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        for p in self.rnn.parameters():
          nn.init.normal_(p, mean=0.0, std=0.001)

        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev):

       out, hidden_prev = self.rnn(x, hidden_prev)
       # [b, seq, h]
       out = out.view(-1, hidden_size)
       out = self.linear(out)
       out = out.unsqueeze(dim=0)
       return out, hidden_prev




model = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)

hidden_prev = torch.zeros(1, 1, hidden_size)

for iter in range(6000):
    start = np.random.randint(3, size=1)[0]
    time_steps = np.linspace(start, start + 10, num_time_steps)
    data = np.sin(time_steps)
    data = data.reshape(num_time_steps, 1)
    x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
    y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)

    output, hidden_prev = model(x, hidden_prev)
    hidden_prev = hidden_prev.detach()

    loss = criterion(output, y)
    model.zero_grad()
    loss.backward()
    # for p in model.parameters():
    #     print(p.grad.norm())
    # torch.nn.utils.clip_grad_norm_(p, 10)
    optimizer.step()

    if iter % 100 == 0:
        print("Iteration: {} loss {}".format(iter, loss.item()))

start = np.random.randint(3, size=1)[0]
time_steps = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)

predictions = []
input = x[:, 0, :]
for _ in range(x.shape[1]):
  input = input.view(1, 1, 1)
  (pred, hidden_prev) = model(input, hidden_prev)
  input = pred
  predictions.append(pred.detach().numpy().ravel()[0])

x = x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())

plt.scatter(time_steps[1:], predictions)
plt.show()

6. 结果展示

图中黄色点是预测点,蓝色为实际点,前面的曲线是start不随机预测的效果,说明曲线已经被模型记住了;后面的曲线是start随机预测的效果,基本趋势和真实点是一致的。
在这里插入图片描述

相关推荐

  1. (PyTorch)TCN和RNN/LSTM/GRU结合实现时间序列预测

    2024-07-11 15:20:05       54 阅读
  2. pytorch RNN

    2024-07-11 15:20:05       38 阅读
  3. 基于pytorchRNN实现文本分类

    2024-07-11 15:20:05       56 阅读
  4. RNN时序预测

    2024-07-11 15:20:05       28 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-11 15:20:05       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-11 15:20:05       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-11 15:20:05       58 阅读
  4. Python语言-面向对象

    2024-07-11 15:20:05       69 阅读

热门阅读

  1. react获取访问过的路由历史记录

    2024-07-11 15:20:05       24 阅读
  2. 编程范式实现思路介绍

    2024-07-11 15:20:05       19 阅读
  3. 表单验证的艺术:WebKit 支持 HTML 表单的全面解析

    2024-07-11 15:20:05       19 阅读
  4. Android --- Kotlin学习之路:基础语法学习笔记

    2024-07-11 15:20:05       24 阅读
  5. 智能制造热点词汇科普篇——工业微服务

    2024-07-11 15:20:05       22 阅读
  6. C++中的模板(二)

    2024-07-11 15:20:05       21 阅读
  7. slf4j日志框架和logback详解

    2024-07-11 15:20:05       22 阅读
  8. Redis的配置和优化

    2024-07-11 15:20:05       22 阅读
  9. springboot 抽出多个接口中都有相同的代码的方法

    2024-07-11 15:20:05       23 阅读
  10. OpenJudge | 最高的分数

    2024-07-11 15:20:05       21 阅读
  11. springmvc 如何对接接口

    2024-07-11 15:20:05       23 阅读