从PyTorch官方的一篇教程说开去(3.2 - Loss函数或叫Cost函数)

通过刚才的一篇例子中的代码,相信已经非常直观的说明了求极值的GD梯度下降算法。

在 DQN 中,我们使用它来训练一个神经网络。训练最常见的做法,就是表现的好奖励。所以我们关心的是,在一连串的决策中,如何能使累积的奖励最大。

如果我们在状态 𝑠下采取动作 𝑎,预期回报 𝑄(𝑠,𝑎) 可以通过以下方式计算: 𝑄(𝑠,𝑎)=𝑟+𝛾max⁡𝑎′𝑄(𝑠′,𝑎′),其中:

  • 𝑟是采取动作 𝑎a 后立即获得的奖励。
  • 𝛾是折扣因子,它决定了未来奖励的当前价值。
  • max⁡𝑎′𝑄(𝑠′,𝑎′)是在下一个状态 𝑠′下所有可能动作的最大预期回报。

我们见招拆招,一边预测下一步怎么做,一边又根据反馈来的参数,对预测做出调整。预测和反馈就会有一个误差/偏差/成本Cost/损失Loss。

对于我们的 Q 网络,损失函数 𝐿可以定义为实际的 Q 值和预测的 Q 值之间的差异: 𝐿=1/2*∑(𝑠,𝑎,𝑠′,𝑟)∈𝐵 (𝑄(𝑠,𝑎)−(𝑟+𝛾max⁡𝑎′𝑄(𝑠′,𝑎′)))2

这里的𝐵集合,是从重放记忆中随机抽取的一批转换。

损失函数 𝐿衡量的是 Q 网络的预测与实际观察到的预期回报之间的差距。梯度下降法的目标是通过调整网络的权重 𝜃来最小化这个损失: 𝜃←𝜃−𝛼∇𝜃𝐿其中:

  • ∇𝜃𝐿是损失函数 𝐿关于权重 𝜃的梯度。
  • 𝛼是学习率,它控制着权重更新的步长。

这里的2个参数,此处详解一下如何取值:

  • 初始学习率𝛼:

    • 通常从一个较小的值开始,比如 0.01 或 0.001,然后根据训练过程中的观察逐步调整。
    • 使用自适应学习率的优化器(如 Adam 或 RMSprop)可以减少手动调整学习率的需求。
  • 𝛾的取值范围

    • 𝛾通常在 0 到 1 之间。值越接近 1,越重视未来的奖励;值越接近 0,越重视即时的奖励。
    • 决定了未来奖励的当前价值。如果任务需要代理关注长期回报,选择一个较大的𝛾值是合适的。如果任务更注重短期回报,选择一个较小的 𝛾值可能更有效。

相关代码如下,请注意其中注释 - 

// 1 - hyperparameters and utilities
//...
//首先定义了一个损失函数 criterion,这里使用的是 Huber 损失(也称为平滑 L1 损失),
//它在误差较小时表现为均方误差,在误差较大时表现为绝对误差。这使得损失函数对异常值更加鲁棒。
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
//...

// 2 - training loop
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # 转置批次,将 Transition 的批次数组转换为批次的 Transition 数组
    batch = Transition(*zip(*transitions))

    # 计算非最终状态的掩码,并连接批次元素
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # 计算当前状态和动作的 Q 值
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # 计算下一个状态的 V 值
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

    # 计算预期的 Q 值
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # 计算 Huber 损失
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # 优化模型
    optimizer.zero_grad()  # 清除之前的梯度
    loss.backward()  # 反向传播,计算当前梯度
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)  # 梯度裁剪,防止梯度爆炸
    optimizer.step()  # 更新网络参数

//主训练循环中,每次执行环境步骤后,调用 optimize_model 函数来执行一次梯度下降步骤。
for i_episode in range(num_episodes):
    # 初始化环境并获取其状态
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # 将转换存储在内存中
        memory.push(state, action, next_state, reward)

        # 转移到下一个状态
        state = next_state

        # 执行模型的一步优化(在策略网络上)
        optimize_model()

        # 目标网络权重的软更新
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

这个代码为了突出重点做了删减,不能直接运行,直接运行的请参考前面”代码“文章。

您的进步和反馈是我最大的动力,小伙伴来个三连呗!共勉。

相关推荐

  1. PyTorch -- 最常见损失函数 LOSS 选择

    2024-07-20 17:04:01       24 阅读
  2. “Linux VS Laxcus谁更强”

    2024-07-20 17:04:01       45 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-20 17:04:01       52 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-20 17:04:01       54 阅读
  3. 在Django里面运行非项目文件

    2024-07-20 17:04:01       45 阅读
  4. Python语言-面向对象

    2024-07-20 17:04:01       55 阅读

热门阅读

  1. Node.js 路由

    2024-07-20 17:04:01       18 阅读
  2. JDK版本详解

    2024-07-20 17:04:01       18 阅读
  3. Zookeeper是什么,为什么要用,怎么用?

    2024-07-20 17:04:01       23 阅读
  4. 【c++】用c++类做一个猜数字游戏

    2024-07-20 17:04:01       18 阅读
  5. execjs._exceptions.ProgramError: SyntaxError: 语法错误

    2024-07-20 17:04:01       18 阅读
  6. MySQL自增主键出现不连续的原因?

    2024-07-20 17:04:01       20 阅读
  7. C++案例四:简易记事本程序

    2024-07-20 17:04:01       18 阅读
  8. DNS解析过程

    2024-07-20 17:04:01       17 阅读