1.6.丢弃法

丢弃法

动机:一个好的模型需要对输入数据的扰动足够健壮,丢弃法就是在层之间加入噪音。也可以在数据中使用噪音,等价与Tikhonov正则

无偏差的加入噪音

​ 对于数据 x x x,加入噪音后的 x ′ x' x的期望值是不变的, E [ x ′ ] = x E[x']=x E[x]=x

​ 则我们可以构造出一个简单的期望运算 E [ x ′ ] = p ⋅ 0 + ( 1 − p ) ⋅ x i 1 − p = x i E[x']=p\cdot 0+(1-p)\cdot\frac{x_i}{1-p} =x_i E[x]=p0+(1p)1pxi=xi

​ 那么可以这样处理元素:

在这里插入图片描述

​ 其中丢弃概率是超参数。常用在多层感知机的隐藏层输出上。

通常将丢弃法作用在隐藏全连接层的输出上:
h = σ ( W 1 x + b 1 ) h ′ = d r o p o u t ( h ) o = W 2 h ′ + b 2 y = s o f t m a x ( o ) h=\sigma(W_1x+b_1)\\ h' = dropout(h)\\ o = W_2h' +b_2\\ y=softmax(o) h=σ(W1x+b1)h=dropout(h)o=W2h+b2y=softmax(o)
在这里插入图片描述

​ 如图本来有5个隐藏层,但丢弃函数可能取到0,那么可能会直接消失,剩下的3个隐藏层变大。

​ 丢弃项其实是正则项,只在训练中使用,他们影响模型参数的更新。

​ 在推理过程中,丢弃法直接返回输入 h = d r o p o u t ( h ) h = dropout(h) h=dropout(h),也可以保证确定性的输出

​ 实际上丢弃法的实质是每次训练中使用一个神经网络的子集来做训练, 则多次训练后得到的是多个神经网络的平均,效果自然要好一些。

​ 现在普遍将丢弃项认为是正则项,效果和正则项基本相同。

​ 在输入数据比较简单,但神经网络比较大时,dropout可能会比较有用。

​ dropout1=0.2,dropout2=0.5:

在这里插入图片描述

​ dropout1=0.dropout2=0"

在这里插入图片描述

​ 效果出乎意料的好,说明这个模型本身就没过拟合,这时候使用dropout可能效果不好。一般的小技巧是模型设大一点,然后使用dropout来进行调整。

代码实现

import torch
from torch import nn
from d2l import torch as d2l


def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1  # 丢弃概率必须在0到1之间
    if dropout == 1:
        return torch.zeros_like(X)  # 全0则全部丢弃
    if dropout == 0:
        return X  # 0则不丢弃
    mask = (torch.rand(X.shape) > dropout).float()  # rand生成0到1之间的随机数
    return mask * X / (1.0 - dropout)


num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

# dropout1, dropout2 = 0.2, 0.5
dropout1, dropout2 = 0., 0.


# 定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元,有三个线性层,最后一个是输出层
class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training=True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
        out = self.lin3(H2)
        return out


net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

'''简洁实现'''

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

相关推荐

  1. 练习5-丢弃(包含部分丢弃理解)

    2024-07-16 04:48:04       36 阅读
  2. modbus CRC16校验计算查表

    2024-07-16 04:48:04       43 阅读
  3. 干好工作18

    2024-07-16 04:48:04       28 阅读
  4. Android 13 默认讯飞输入

    2024-07-16 04:48:04       68 阅读
  5. 10 快速排序-左右指针

    2024-07-16 04:48:04       53 阅读
  6. stm8l151,c语言混编汇编,实现16位乘除

    2024-07-16 04:48:04       50 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-16 04:48:04       66 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-16 04:48:04       70 阅读
  3. 在Django里面运行非项目文件

    2024-07-16 04:48:04       57 阅读
  4. Python语言-面向对象

    2024-07-16 04:48:04       68 阅读

热门阅读

  1. 目标检测算法:原理、挑战与应用

    2024-07-16 04:48:04       25 阅读
  2. Deep Layer Aggregation【方法部分解读】

    2024-07-16 04:48:04       25 阅读
  3. Chrome调试工具

    2024-07-16 04:48:04       22 阅读
  4. 探索Mojo编程语言:AI开发者的新宠儿

    2024-07-16 04:48:04       26 阅读
  5. C++:++和--运算符的前置后置如何实现

    2024-07-16 04:48:04       21 阅读
  6. - vuex路由:

    2024-07-16 04:48:04       22 阅读
  7. 数据流通环节如何规避安全风险

    2024-07-16 04:48:04       20 阅读
  8. Linux0715

    Linux0715

    2024-07-16 04:48:04      21 阅读
  9. SQL日期函数

    2024-07-16 04:48:04       24 阅读
  10. Integrated Gradients (Pytorch)refs

    2024-07-16 04:48:04       24 阅读