在结合 warm-up 和余弦退火调度器时,warm-up 阶段的初始学习率和余弦退火调度器的最大学习率不一定需要相同。通常情况下,这两者的学习率可以不同。
在实际应用中,你可以根据具体情况合理设置这两个阶段的学习率,使得模型训练能够更好地收敛和达到较高的性能。一般来说,warm-up 阶段的学习率可以设置相对较低,以帮助模型在初始阶段更稳定地学习参数;而余弦退火阶段的最大学习率可以设置较高,以在训练后期更好地优化模型。示例展示在 PyTorch 中将 warm-up 阶段的学习率与余弦退火阶段的最大学习率设置为不同的值:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# --- schedule hyper-parameters ---------------------------------------------
warmup_lr = 0.01       # learning rate at the very start of linear warm-up
cosine_max_lr = 0.2    # peak learning rate that cosine annealing decays from
warmup_epochs = 5      # number of linear warm-up epochs
cosine_epochs = 50     # number of cosine-annealing epochs
total_epochs = warmup_epochs + cosine_epochs

# Model and optimizer.
# BUG FIX: the original built the optimizer with lr=0.1 and never used
# cosine_max_lr at all.  CosineAnnealingLR snapshots the optimizer's current
# lr as its base_lr at construction time, so the optimizer must already be
# at the intended peak (cosine_max_lr) when the scheduler is created.
model = YourModel()  # placeholder: replace with your actual model
optimizer = optim.SGD(model.parameters(), lr=cosine_max_lr)

# Cosine annealing covers only the epochs after warm-up (eta_min=0 means
# the lr decays all the way to zero at the end of the cosine phase).
scheduler = CosineAnnealingLR(optimizer, T_max=cosine_epochs, eta_min=0)

# --- training loop ----------------------------------------------------------
for epoch in range(total_epochs):
    if epoch < warmup_epochs:
        # Linear warm-up: ramp from warmup_lr up to the cosine peak.
        # BUG FIX: the original ramped toward the hard-coded 0.1 instead of
        # cosine_max_lr, creating a jump in lr at the phase hand-over.
        new_lr = warmup_lr + (cosine_max_lr - warmup_lr) * (epoch / warmup_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr

    # One training epoch (train_loader / loss_function are placeholders).
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()

    # BUG FIX: step the cosine schedule AFTER the epoch's optimizer steps
    # (PyTorch >= 1.1 calling convention), and only once warm-up is over.
    if epoch >= warmup_epochs:
        scheduler.step()
各种学习率曲线:
import torch
from torch.optim.lr_scheduler import *
import torch.nn as nn
from torchvision.models import resnet50
import matplotlib.pyplot as plt
# from lr_scheduler.scheduler import GradualWarmupScheduler

# A dummy model/optimizer used purely to drive the schedulers so their
# learning-rate curves can be plotted.
model = resnet50(False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)

# A selection of built-in schedulers; swap which one is stepped in the loop
# below to visualise a different curve.
# NOTE(review): constructing several schedulers on the SAME optimizer lets
# each __init__ adjust the current lr once — fine for a demo, but only the
# scheduler actually stepped below determines the plotted curve.
scheduler1 = LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
scheduler2 = StepLR(optimizer, step_size=10, gamma=0.1)
scheduler3 = MultiStepLR(optimizer, milestones=[5, 10, 15, 20, 25], gamma=0.1)
scheduler4 = ExponentialLR(optimizer, gamma=0.8)
scheduler5 = CosineAnnealingLR(optimizer, T_max=5, eta_min=0.05)
scheduler6 = CyclicLR(optimizer, base_lr=0.01, max_lr=0.2, step_size_up=10, step_size_down=5)
scheduler7 = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=0.01)
# scheduler8 = GradualWarmupScheduler(optimizer, 1, 5, scheduler2)

plt.figure()
max_epoch = 30
cur_lr_list = []
for epoch in range(max_epoch):
    # optimizer.step() before scheduler.step() silences the PyTorch >= 1.1
    # "scheduler called before optimizer" warning in this plot-only loop.
    optimizer.step()
    scheduler5.step()
    cur_lr = optimizer.param_groups[-1]['lr']
    cur_lr_list.append(cur_lr)
    print('Current lr:', cur_lr)

x_list = list(range(len(cur_lr_list)))
plt.plot(x_list, cur_lr_list)
# BUG FIX: savefig must come BEFORE show() — show() blocks and clears the
# current figure when the window closes, so the original saved a blank image.
# NOTE(review): the filename refers to the commented-out GradualWarmupScheduler
# even though scheduler5 (CosineAnnealingLR) is what is actually plotted.
plt.savefig('gradualwarmupscheduler.png')
plt.show()