pytorch强化学习(2)——重写DQN

思路

在q-learning当中,Q函数的输入是状态state和action,输出是q-value。

而DQN就是使用神经网络来拟合Q函数,所以从直观上来说,我觉得神经网络的输入应该是状态state和action,输出应该是q-value。

但是,网上绝大多数DQN的代码实现都把state作为网络输入,把所有action的q-value的组合作为网络输出。我觉得这是不直观的、令人费解的,于是我按照自己的想法写了一份DQN代码。

在下面的代码中,神经网络的输入是state和action的连接,若干个浮点数表示state,一个整数表示action。神经网络的输出只有一个元素,代表q-value的值。

代码

env.py

import gym
from DQN_brain import DQN
import matplotlib.pyplot as plt
import numpy

lr = 1e-3  # 学习率
gamma = 0.9  # 折扣因子
epsilon = 0.9  # 贪心系数
n_hidden = 50  # 隐含层神经元个数

env = gym.make("CartPole-v1")
n_states = env.observation_space.shape[0]  # 4
n_actions = env.action_space.n  # 2 动作的个数

dqn = DQN(n_states, n_hidden, n_actions, lr, gamma, epsilon)

if __name__ == '__main__':
    reward_list = []
    for i in range(100):
        # 获取初始环境
        state = env.reset()[0]  # len=4
        total_reward = 0
        done = False
        while True:

            # 获取最优动作
            action = dqn.optimal_action(state)

            # 有一定概率不采取最优动作,而是随机选择一个动作执行,这一点很重要
            if numpy.random.random() > epsilon:
                action = numpy.random.randint(n_actions)

            # 更新环境
            next_state, reward, done, _, _ = env.step(action)
            dqn.learning(state, next_state, action, reward, done)

            # 更新一些变量
            state = next_state
            total_reward += reward

            if done:
                break

        print("第%d回合,total_reward=%f" % (i, total_reward))
        reward_list.append(total_reward)

    # 绘图
    episodes_list = list(range(len(reward_list)))
    plt.plot(episodes_list, reward_list)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('DQN Returns')
    plt.show()

DQN_brain.py

import torch
from torch import nn, Tensor

class Net(nn.Module):
    # 构造有2个隐含层的网络
    def __init__(self, input_dim: int, n_hidden: int, output_dim: int):
        super().__init__()
        self.network = nn.Sequential(
            torch.nn.Linear(input_dim, n_hidden, dtype=torch.float),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_hidden, dtype=torch.float),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_hidden, dtype=torch.float),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, output_dim, dtype=torch.float),
        )

    # 前传,直接调用Net对象,其实就是调用forward函数
    def forward(self, x):  # [b,n_states]
        return self.network(x)


class DQN:
    def __init__(self, n_states: int, n_hidden: int, n_actions: int, lr: float, gamma: float, epsilon: float):
        # 属性分配
        self.n_states = n_states  # 状态的特征数
        self.n_hidden = n_hidden  # 隐含层个数
        self.n_actions = n_actions  # 动作数
        self.lr = lr  # 训练时的学习率
        self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放
        self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索

        # 实例化训练网络,网络的输入是state+action,
        # 网络的输出是只有一个元素的一维向量,代表该动作在该状态下的q-value
        self.q_net = Net(self.n_states + 1, self.n_hidden, 1)

        # 优化器,更新训练网络的参数
        self.q_optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.criterion = torch.nn.MSELoss()  # 损失函数

    # 把状态和动作转化为tensor并连接起来
    def _concat_input(self, state: list[float], action: int):
        state_tensor = torch.tensor(state, dtype=torch.float)
        action_tensor = torch.tensor([action], dtype=torch.float)
        return torch.concat([state_tensor, action_tensor])

    # 获取q-value值最大的action
    def optimal_action(self, state: list[float]):
        q_values = torch.tensor([], dtype=torch.float)
        # 获取所有action的q-value
        for action in range(self.n_actions):
            q_values = torch.concat([q_values, self.get_q_value(state, action)])
        # 返回值最大的那个下标,item()函数只能对只有单个元素的tensor使用
        return torch.argmax(q_values).item()

    # 更新网络
    def learning(
            self,
            state: list[float],
            next_state: list[float],
            action: int,
            reward: float,
            done: bool
    ) -> None:
        # 下一状态的最优动作
        next_optimal_action = self.optimal_action(next_state)

        # 当前状态q_value
        q_value = self.get_q_value(state, action)
        # 下一状态q_value
        next_q_value = self.get_q_value(next_state, next_optimal_action)
        # q_target计算
        q_target = reward + self.gamma * next_q_value * (1. - float(done))

        # 计算loss,然后反向传播,然后梯度下降
        loss: Tensor = self.criterion(q_value, q_target)
        self.q_optimizer.zero_grad()
        loss.backward()
        self.q_optimizer.step()

    # 根据状态和动作获取q_value
    def get_q_value(self, state: list[float], action: int) -> Tensor:
        return self.q_net(self._concat_input(state, action))
        # tensor([5.5241], grad_fn=<ViewBackward0>)

相关推荐

  1. pytorch强化学习2)——DQN

    2024-03-10 18:16:03       42 阅读
  2. 强化学习 - Deep Q Network (DQN)

    2024-03-10 18:16:03       68 阅读
  3. 强化学习原理python篇06——DQN

    2024-03-10 18:16:03       56 阅读
  4. 探索Python中的强化学习DQN

    2024-03-10 18:16:03       37 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-10 18:16:03       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-10 18:16:03       101 阅读
  3. 在Django里面运行非项目文件

    2024-03-10 18:16:03       82 阅读
  4. Python语言-面向对象

    2024-03-10 18:16:03       91 阅读

热门阅读

  1. 概率论与数理统计 P6 条件概率

    2024-03-10 18:16:03       40 阅读
  2. VUE2升级#总结1

    2024-03-10 18:16:03       46 阅读
  3. Pytho爬取音乐

    2024-03-10 18:16:03       38 阅读
  4. 计算机等级考试:信息安全技术 知识的四

    2024-03-10 18:16:03       44 阅读
  5. 非插件方式为wordpress添加一个额外的编辑器

    2024-03-10 18:16:03       38 阅读
  6. 算法练习第十二天|二叉树的递归遍历和迭代遍历

    2024-03-10 18:16:03       42 阅读
  7. 大数据架构

    2024-03-10 18:16:03       38 阅读
  8. typedef 别名的定义和使用

    2024-03-10 18:16:03       48 阅读
  9. springboot 下载 Excel 文件的 Controller 层案例

    2024-03-10 18:16:03       44 阅读
  10. AI辅助研发,引领科技新潮流

    2024-03-10 18:16:03       45 阅读