李沐深度学习-多项式函数拟合试验

d2lzh_pytorch模块跳转连接

import torch
import numpy as np
import sys

sys.path.append("路径")
import d2lzh_pytorch as d2l

'''
-----------------------------生成人工数据集
样本数n=200
特征数=3
三阶多项式y=1.2x-3.4x^2+5.6x^3+5+ε
'''
n_train, n_test, true_w, true_b = 100, 100, [1.2, -3.4, 5.6], 5
sample_features = torch.randn(n_train + n_test, 1)  # 200x1   单算的一个特征
poly_sample_features = torch.cat((sample_features, torch.pow(sample_features, 2), torch.pow(sample_features, 3)),
                                 dim=1)  # 组合成3个特征
# 因为poly_features取列的元素时,没有对列加[]限制,所以取出来的值不保有原来的维度,而是成为了一维张量,所以labels相应的也是一维张量
labels = true_w[0] * poly_sample_features[:, 0] + true_w[1] * poly_sample_features[:, 1] + true_w[
    2] * poly_sample_features[:, 2] + true_b
labels += torch.tensor(np.random.normal(0, 0.01, (labels.size())), dtype=torch.float)  # 以上是为了得到真实labels

# print(poly_sample_features, '\n', poly_sample_features[:2], '\n', poly_sample_features[:, 2],
# '\n', poly_sample_features[:, :2])

'''
-----------------------------------------------------定义,训练和测试模型
'''
# 尝试使用不同复杂度的模型来拟合生成的数据集
num_epochs, loss = 100, torch.nn.MSELoss()

'''
以下函数思路设计:
    1. 参数传入训练数据集样本,测试数据集样本,训练标签,测试标签
    2. 设计网络,网络输入特征计算格式
    3. 设计数据读取,读取训练数据集
    4. 循环更新迭代步骤,目的是为了优化w,b
    5. 循环外使用全批量训练数据集和测试数据集,在已经更新好的w,b的基础上进行损失计算
    6. 画图
'''


def fit_and_plot(train_features, test_features, train_labels, test_labels, label):
    net = torch.nn.Linear(train_features.shape[-1], 1)
    # Linear 自动初始化了模型参数
    batch_size = min(10, train_labels.shape[0])
    dataset = torch.utils.data.TensorDataset(train_features, train_labels)
    train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)

    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:  # 里面的循环作用只是为了更迭模型参数
            y_hat = net(X)
            l = loss(y_hat, y.view(y_hat.size()))  # 计算每一批的平均损失
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
        train_labels = train_labels.view(-1, 1)
        test_labels = test_labels.view(-1, 1)  # 做这一步形状改变,是为了接下来的循环外的损失计算,损失计算时要求形状相同
        train_ls.append(loss(net(train_features), train_labels).item())  # 一个循环周期后记录一次更新后的参数的损失表现
        test_ls.append(loss(net(test_features), test_labels).item())  # 直接一整个没有分批量就进行了损失计算,使用的是最新的w,b
    print(f'final epoch: train loss', train_ls[-1], 'test loss', test_ls[-1])
    d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss', label,
                 range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('weight:', net.weight.data,
          '\nbias:', net.bias.data)


'''
-----------------------------------------------------------------三阶多项式函数拟合(正常)
'''
fit_and_plot(poly_sample_features[:n_train, :], poly_sample_features[n_train:, :], labels[:n_train], labels[n_train:],
             '正常')

'''
----------------------------------------------------------------线性函数拟合(欠拟合)
'''
# 使用的三阶多项式生成的数据标签,但是训练用的数据集是单特征的数据集,而不是三特征的数据集,这样模型就是线性模型,而不是非线性模型
# labels 还是一维张量形状
fit_and_plot(sample_features[:n_train, :], sample_features[n_train:, :], labels[:n_train], labels[n_train:], '欠拟合')

'''
----------------------------------------------------------------训练样本不足(过拟合)
'''
fit_and_plot(poly_sample_features[0:2, :], poly_sample_features[:n_train, :], labels[0:2], labels[:n_train], '过拟合')

相关推荐

  1. 深度学习-多项式函数试验

    2024-01-21 14:10:02       39 阅读
  2. 动手学习深度学习()

    2024-01-21 14:10:02       11 阅读
  3. 深度学习-激活函数/多层感知机文档

    2024-01-21 14:10:02       38 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-01-21 14:10:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-01-21 14:10:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-01-21 14:10:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-01-21 14:10:02       18 阅读

热门阅读

  1. Word的使用

    2024-01-21 14:10:02       29 阅读
  2. AndroidStudio

    2024-01-21 14:10:02       41 阅读
  3. SQL的五大约束作用、语法、应用场景及举例

    2024-01-21 14:10:02       26 阅读
  4. c# 释放所有嵌入资源, 到某个本地文件夹

    2024-01-21 14:10:02       37 阅读
  5. RNN神经网络 python

    2024-01-21 14:10:02       39 阅读
  6. 将Matlab图窗中的可视化保存为背景透明的矢量图

    2024-01-21 14:10:02       35 阅读
  7. GPT属于AI,是LLM的一种实现

    2024-01-21 14:10:02       31 阅读
  8. Kotlin的数据类

    2024-01-21 14:10:02       38 阅读
  9. leetcode-2788按分隔符拆分字符串

    2024-01-21 14:10:02       37 阅读