pytorch中模型训练的学习率动态调整

背景

  在神经网络模型的训练过程中,一般采取梯度下降法来对模型的参数进行更新,其中,学习率 α \alpha α控制着梯度更新的步长(step), α \alpha α越大,意味着下降的越快,到达最优点的速度也越快。学习率较大时,会加速学习,使得模型更容易接近局部或全局最优解。但是在后期会有较大波动,始终难以达到最优。
  因此,我们引入学习率衰减的概念,就是在模型训练初期,使用较大的学习率进行优化,随着迭代次数增加,学习率会逐渐进行减小,保证模型在训练后期不会有太大的波动,从而更加接近最优解,那么,在pytorch中,学习率衰减应该如何实现?

手动设置自动衰减的学习率

  根据进行的epoch的数量,在每一轮对优化器的学习率进行更新。

def adjust_learning_rate(optimizer, epoch, start_lr):
    #每三个epoch衰减一次
    lr = start_lr * (0.1 ** (epoch // 3))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

  这种方法根据自己的逻辑和epoch的数量对学习率进行调整,使用举例:

optimizer = torch.optim.SGD(net.parameters(),lr = start_lr)
for epoch in range(100):
	#手动调整学习率
    adjust_learning_rate(optimizer,epoch,start_lr)
    #查看每一轮的学习率情况
    print("Epoch:{}  Lr:{:.2E}".format(epoch,optimizer.state_dict()['param_groups'][0]['lr']))
    for data,label in traindataloader :
        output = net(data)
        loss = myloss(output,label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

pytorch中的torch.optim.lr_scheduler

  torch.optim.lr_scheduler是pytorch提供的自动调整学习率的方法,基于当前epoch的数值,封装了几种相应的动态学习率调整方法,官方文档optim.lr_scheduler。需要注意的是这种方法对学习率的调整需要应用在优化器参数更新之后,应用方法示例:

optimizer = torch.optim.XXXXXXX()#具体optimizer的初始化
scheduler = torch.optim.lr_scheduler.XXXXXXXXXX()#具体学习率变更策略的初始化
for i in range(epoch):
    for data,label in dataloader:
        out = net(data)
        output_loss = loss(out,label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()

  下面我们介绍其中几种常用的学习率更新策略。

torch.optim.lr_scheduler.ExponentialLR

  torch.optim.lr_scheduler.ExponentialLR是最简单学习率调整方法,即每一次epoch,lr都乘gamma:

torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False)

  其中,optimizer(optimizer):之前定义好的需要优化的优化器的实例名;gamma(float):学习率衰减的乘法因子,默认为0.1,即每次将学习率乘以0.1;last_epoch(int):默认为-1,为-1时表示将人为设置的学习率设定为调整学习率的基础值lr;verbose:如果为True,每一次更新都会打印一个标准的输出信息,默认为False。

torch.optim.lr_scheduler.StepLR

  torch.optim.lr_scheduler.StepLR是比较常用的等间隔动态调整方法,每经过step_size个epoch,做一次学习率衰减,以gamma值为缩小倍数:

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

  相比于ExponentialLR方法,多了一个step_size(int)参数,即学习率衰减的周期,每经过step_size 个epoch,做一次学习率衰减。

torch.optim.lr_scheduler.MultiStepLR

  torch.optim.lr_scheduler.StepLR根据自己设定的训练阶段调整学习率的方法,一旦达到某一阶段(milestones)时,就可以通过gamma系数降低每个参数组的学习率。可以按照milestones列表中给定的值,进行分阶段式调整学习率:

torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False)

  相比于ExponentialLR方法,多了一个milestones(list)参数,这是一个关于epoch数值的list,表示在达到哪个epoch范围内开始变化,必须是升序排列,使用例子:

optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
#在第2,6,15个epoch调整学习率
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,6,15], gamma=0.1)

torch.optim.lr_scheduler.ReduceLROnPlateau

  与上述几种基于epoch数目调整学习率的方法不同,该方法根据验证指标的变化的调整学习率。它的原理是:当指标停止改善时,降低学习率。当模型的学习停滞时,训练过程通常会受益于将学习率降低2~10倍。该种调整方法读取一个度量指标,如果在“耐心”期间内没有发现它有所改善,那么就会降低学习率:

torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode= 'rel', cooldown=0, min_1r=0, eps=1e-08)

  其中,optimizer(Optimizer):之前定义好的需要优化的优化器的实例名;mode(str):设置为min或max。当选择min时,代表当度量指标停止下降时,开始减小学习率;当选择max时,代表当度量指标停止上升时,开始减小学习率;factor(float):学习率调整的乘法因子,默认值为0.1;patience(int):可容忍的度量指标没有提升的epoch数目,默认为10。举例说明,当其设置为10时,我们可以容忍10个epoch内没有提升,如果在第11个epoch依然没有提升,那么就开始降低学习率;verbose(bool):如果设置为True,输出每一次更新的信息,默认为False;threshold(float):float类型数据,衡量新的最佳阈值,仅关注重大变化,默认为0.0001;threshold_mode(str):可选str字符串数据,为rel或abs,默认为rel。在rel模式下,如果mode参数为max,则动态阈值(dynamic_threshold)为best*(1+threshold),如果mode参数为min,则动态阈值为best+threshold,如果mode参数为min,则动态阈值为best-threshold;cooldown(int):减少lr后恢复正常操作之前要等待的epoch数,默认为0;min_lr(float):学习率的下界,默认为0;eps(float):学习率的最小变化值。如果调整后的学习率和调整前的差距小于eps的话,那么就不做任何调整,默认为1e-8。

最近更新

  1. TCP协议是安全的吗?

    2024-04-21 03:58:03       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-21 03:58:03       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-21 03:58:03       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-21 03:58:03       18 阅读

热门阅读

  1. web应用使用spring

    2024-04-21 03:58:03       15 阅读
  2. 2024.4.20力扣每日一题——组合总和

    2024-04-21 03:58:03       11 阅读
  3. 游戏中的伤害类型

    2024-04-21 03:58:03       11 阅读
  4. 正则表达式大全,30个正则表达式详细案例

    2024-04-21 03:58:03       17 阅读
  5. 上海计算机学会2023年12月月赛C++丙组T2移动复位

    2024-04-21 03:58:03       13 阅读
  6. 搭建vue3组件库(一):Monorepo项目搭建

    2024-04-21 03:58:03       16 阅读
  7. Docker常见命令学习

    2024-04-21 03:58:03       17 阅读
  8. mac修改/etc/profile导致终端所有命令不可使用

    2024-04-21 03:58:03       15 阅读
  9. CentOS系统上经常使用的一些基本命令

    2024-04-21 03:58:03       13 阅读
  10. android11启动服务

    2024-04-21 03:58:03       13 阅读