pytorch 中 drop_last与 nn.Parameter

1. drop_last

在使用深度学习,pytorch 的DataLoader 中,

from torch.utils.data import DataLoader

# Define your dataset and other necessary configurations
# Create DataLoader
train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

drop_last=True :DataLoader 中的此设置会删除不完整的最后一批(如果它小于指定的批量大小)。这确保了训练期间处理的每个批次包含相同数量的样本。

1.1 drop_last = True

dataset_size = 100
batch_size = 32
train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

使用 drop_last=True ,DataLoader 确保每个批次包含 32 个样本,删除不完整的最终批次。例如,在这种情况下,训练期间将处理 3 个批次(32、32、32),其余 4 个样本将不会用于训练。

适用情况:

当网络模型的初始化中,需要用到batch size 时, 这种情况下, 需要注意的是此时, drop_last = False , 会影响网络模型结构, 由于模型的初始化过程中,使用了batch size 参数, 所有此时应该设置为 True;

1.2 drop_last = False

而当 drop_last = False, 当最后一个批次中, 剩余的样本个数不足 batch 样本数目时, 会保留这剩余的样本,使用剩余的样本进行训练。

当数据不均衡, 并且某一类中样本数量很少时, 此时 drop_last = True 会严重影响到模型的精度,此时应该使用 False;

原因是,本身的某个类别中训练集和测试集的数量就已经小于batch size 时, 此时使用 drop last, 会严重该类别的训练和测试效果。

如下面的情况:

遇到了这样的问题。一共16类,第15 16类的训练集数量是15、15,测试集分别为14、5。其他1-14类训练集分别有50个,测试集均为200左右。

当我在pytorch的dataloader中设置了drop_last=True时,无论怎么训练,使用怎么样的数据增强,第15 16类才测试集上的准确率永远为0.

原因分析:
当dataloader设置了drop_last=True时,在训练时如果数据总量无法整除batch_size,那么这个dataloader就会丢掉最后一个batch,也就是说训练的时候有部分数据是被丢掉的。而我遇到的情况可能是正好把第15 16类的测试数据给丢掉了部分,导致模型很好的学习到这两类的特征。

解决方案:
将drop_last改为False,即可解决该问题。

2. nn.Parameter

在深度学习训练过程中, 通常需要自己创建出一个初始化的张量, 并且希望通过模型训练过程中, 更新该张量。

torch.randn(bt, 3, 256)

而普通的使用torch 随机初始化的方式,如上面的这种方式,
在大多数情况下,随机初始化张量不会使其参数变得可学习。在没有任何相关学习过程或梯度更新的情况下随机初始化的张量在网络训练期间不会适应或改变。

2.1 可学习参数

为了使得创建的张量,在网络训练过程中,可以得到更新。

在 PyTorch 中, nn.Parameter 是一个继承自 torch.Tensor 的类。它允许您向框架指示该张量应被视为模型参数的一部分。当您将其分配为 nn.Module 中的属性时,它在优化过程中变得可训练。

import torch
import torch.nn as nn

# Creating a tensor as a learnable parameter
param_tensor = nn.Parameter(torch.randn(1, 3))

param_tensor 将在训练过程中进行优化因为它们被视为模型可学习参数的一部分。

放到 cuda 设备上

 self.cuda_param = nn.Parameter(torch.randn(1, 2).cuda())

2.2 nn.ParameterList

同样, 当想创建一个列表都是可学习的参数时, 使用如下的方式;

self.parameters = nn.ParameterList([nn.Parameter(torch.randn(256)) for _ in range(5)])

相关推荐

  1. pytorch drop_last nn.Parameter

    2023-12-12 09:16:04       34 阅读
  2. pytorch用tensorboard

    2023-12-12 09:16:04       15 阅读
  3. pytorch深度学习

    2023-12-12 09:16:04       14 阅读
  4. pytorch深度学习

    2023-12-12 09:16:04       10 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-12 09:16:04       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-12 09:16:04       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-12 09:16:04       19 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-12 09:16:04       20 阅读

热门阅读

  1. 读excel文件,借助openpyxl工具

    2023-12-12 09:16:04       30 阅读
  2. 数据结构 | Floyd

    2023-12-12 09:16:04       42 阅读
  3. MySQL全文索引布尔模式详解

    2023-12-12 09:16:04       32 阅读
  4. 【Go自学版】03-即时通信系统3

    2023-12-12 09:16:04       41 阅读
  5. Hadoop 完全分布式搭建 详细流程

    2023-12-12 09:16:04       41 阅读
  6. Kubernetes实战(十一)-重装Kubernetes集群

    2023-12-12 09:16:04       37 阅读