强化学习实战3:Sarsa 与 Q-Learning 算法求解迷宫问题

前置知识

首先实验环境依然是我们之前说的迷宫环境,然后是一些基本术语,应该都是比较熟悉的:

在这里插入图片描述
在这里插入图片描述

强化学习的算法大概有两类,一类是策略迭代(讲究的是策略 Π ),还有一类是价值迭代,也就是本节要说的内容。

在价值迭代算法的类型中有两个非常重要的算法,即 Sarsa 和 Q-learning。

本节介绍 Sarsa。

Sarsa 名称的由来:取的是一个小 trajectory 的首字母缩写:state action reward state action。

同时这也很形象的展示了 Sarsa 算法的原理,就是根据不断的重复 state action reward state action 这一系列行为从而进行迭代。

Sarsa 维护了一个 Q-table,而 Q-table 一开始是我们进行初始化的,然后通过价值迭代的方式我们跟环境不断的交互然后不断的更新和学习这个 Q-table,也就是说 Q-table 实际上是一个待学习的参数。因此显然对于 Q-table 而言,其行索引表示状态,列索引则表示行为。最后要注意的一点是 Q-table 的值不是概率分布,就是单纯的值(应该就是 reward 值),也就是最大的 value 显然就是我们最应该采取的 action。

基于 Q-table 我们可以评估 action value 和 state value 。

在使用 Sarsa 算法去更新 Q-table 的时候,我们基于的是贝尔曼方程。

而 Q-learning 和 Sarsa 仅仅只是公式上有一个区别,在下一节中将会进行介绍。

Sarsa 算法各部分实现

Q-table 的实现

# 刻画环境:边界 border 和 障碍 barrier
theta_0 = np.array([
    [np.nan, 1, 1, np.nan],  # 表示S0时的策略,即agent不能往上、不能往左走,但可以往右和下走
    [np.nan, 1, np.nan, 1],
    [np.nan, np.nan, 1, 1],
    [1, np.nan, np.nan, np.nan],
    [np.nan, 1, 1, np.nan],
    [1, np.nan, np.nan, 1],
    [np.nan, 1, np.nan, np.nan],
    [1, 1, np.nan, 1],
    # S8 已经是终点了,因此不再需要上下左右到处走了
])

# ---------------------------Q-table---------------------------------
n_states, n_actions = theta_0.shape

# Q-table,状态是离散的(S0 到 S7),动作也是离散的(上下左右)
# 下面这是元素级别的乘法,也就是对位元素相乘
Q = np.random.rand(n_states, n_actions) * theta_0
print(Q)

输出结果如下:

在这里插入图片描述

这个部分就是单纯的随机初始化一下我们的 Q-table,以便于后面进行迭代更新。

ε-greedy 的实现

# -------------------------ε-greedy--------------------------------
# 对于 ε-greedy,其是一个 探索 和 利用 策略
# 将 theta_0 转换为 策略 Π,而 Π 其实就是概率值嘛
def cvt_theta_0_to_pi(theta):
    m, n = theta.shape
    pi = np.zeros((m, n))
    for r in range(m):
        pi[r, :] = theta[r, :] / np.nansum(theta[r, :])
    return np.nan_to_num(pi)


pi_0 = cvt_theta_0_to_pi(theta_0)


# epsilon-ε
# ε-greedy 这样一个策略是用来选取 action 的
# s 表示当前状态,Q 表示 Q-table, eps 是一个超参数, pi_0 是策略
def get_action(s, Q, eps, pi_0):
    # 动作空间是 0 1 2 3
    action_space = list(range(4))
    # eps, explore
    if np.random.rand() < eps:
        action = np.random.choice(action_space, p=pi_0[s, :])
    else:
        # 1-eps, exploit
        action = np.nanargmax(Q[s, :])
    return action

ε-greedy 策略的作用是用来帮助选取 action 的,对于 get_action 函数其具体的解释如下:

在这里插入图片描述
在这里插入图片描述

Sarsa 算法的实现

接下来介绍 Sarsa 算法,也就是进行 QΠ(s,a) 的算法。

在这里插入图片描述

# --------------------------- Sarsa ---------------------------
# gamma 就是 折扣率 参数, eta 是我们需要给定的超参数
def sarsa(s, a, r, s_next, a_next, Q, eta, gamma):
    if s_next == 8:
        Q[s, a] = Q[s, a] + eta * (r - Q[s, a])
    else:
        Q[s, a] = Q[s, a] + eta * (r + gamma * Q[s_next, a_next] - Q[s, a])

解决迷宫问题

代码封装

封装上述代码如下:

# --------------------------- 环境创建 --------------------------
# MazeEnv 类维护着状态,以及 step 函数的返回
class MazeEnv(gym.Env):
    def __init__(self):
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        if action == 0:
            self.state -= 3
        elif action == 1:
            self.state += 1
        elif action == 2:
            self.state += 3
        elif action == 3:
            self.state -= 1
        done = False
        reward = 0
        if self.state == 8:
            done = True
            reward = 1
        return self.state, reward, done, {}


# Agent 类基于当前环境中的状态选择动作形成策略
class Agent:
    def __init__(self):
        # action space
        self.actions = list(range(4))
        # 刻画环境:边界 border 和 障碍 barrier
        self.theta_0 = np.array([
            [np.nan, 1, 1, np.nan],  # 表示S0时的策略,即agent不能往上、不能往左走,但可以往右和下走
            [np.nan, 1, np.nan, 1],
            [np.nan, np.nan, 1, 1],
            [1, np.nan, np.nan, np.nan],
            [np.nan, 1, 1, np.nan],
            [1, np.nan, np.nan, 1],
            [np.nan, 1, np.nan, np.nan],
            [1, 1, np.nan, 1],
            # S8 已经是终点了,因此不再需要上下左右到处走了
        ])
        # 策略 Π
        self.pi = self._cvt_theta_0_to_pi()
        # Q-table
        self.Q = np.random.rand(*self.theta_0.shape) * self.theta_0
        # 超参数
        self.eta = 0.1
        # 折扣率
        self.gamma = 0.9
        # ε-greedy 策略的超参数
        self.eps = 0.5

    # 将 theta_0 转换为 策略 Π,而 Π 其实就是概率值嘛
    def _cvt_theta_0_to_pi(self):
        m, n = self.theta_0.shape
        pi = np.zeros((m, n))
        for r in range(m):
            pi[r, :] = self.theta_0[r, :] / np.nansum(self.theta_0[r, :])
        return np.nan_to_num(pi)

    def get_action(self, s):
        # eps, explore 探索
        if np.random.rand() < self.eps:
            action = np.random.choice(self.actions, p=self.pi[s, :])
        else:
            # 1-eps, exploit 利用
            action = np.nanargmax(self.Q[s, :])
        return action

    def sarsa(self, s, a, r, s_next, a_next):
        if s_next == 8:
            self.Q[s, a] = self.Q[s, a] + self.eta * (r - self.Q[s, a])
        else:
            self.Q[s, a] = self.Q[s, a] + self.eta * (r + self.gamma * self.Q[s_next, a_next] - self.Q[s, a])

训练

训练代码如下:

# --------------------------------- 训练 ---------------------------------------
maze = MazeEnv()
agent = Agent()
episode = 0
while True:
    # 下面这行代码会创建一个新的一维数组old_Q,其长度与agent.Q的行数(即状态的数量)相同。
    # old_Q中的每个元素都是对应状态下行(动作)中的最大Q值(忽略NaN)。
    """
    np.nanmax会返回数组中所有非NaN元素的最大值
    axis=1:这个参数指定了np.nanmax函数应该沿着哪个轴来计算最大值。
    在NumPy中,二维数组的轴(axis)是一个维度。
    axis=0表示沿着列(垂直方向)计算,而axis=1表示沿着行(水平方向)计算。
    因此,np.nanmax(agent.Q, axis=1)的意思是对于agent.Q中的每一行(即每个状态对应的所有动作),忽略NaN值,找到最大的Q值。
    """
    old_Q = np.nanmax(agent.Q, axis=1)
    # 每一次 episode 开始时都重置状态为初始状态
    s = maze.reset()
    # 通过 ε-greedy 策略选取一个 action
    a = agent.get_action(s)
    # 记录历史 state-action 对儿
    s_a_history = [[s, np.nan]]
    # 循环跑出每一 episode 的 trajectory
    while True:
        # 将列表末尾第一个 state-action 对儿中的 action 更新一下
        s_a_history[-1][1] = a
        s_next, reward, done, _ = maze.step(a)
        s_a_history.append([s_next, np.nan])
        if done:
            a_next = np.nan
        else:
            a_next = agent.get_action(s_next)
        agent.sarsa(s, a, reward, s_next, a_next)
        if done:
            break;
        else:
            a = a_next
            s = maze.state
    # 这行代码计算了智能体Q表中每个状态下新旧最大Q值之间的绝对差异的总和,并将这个总和赋值给变量update
    """
    np.abs(np.nanmax(agent.Q, axis=1) - old_Q):
    这部分代码首先计算np.nanmax(agent.Q, axis=1)(即每个状态下的新最大Q值)与old_Q(即每个状态下的旧最大Q值)之间的差,
    然后使用np.abs函数取这些差的绝对值。
    这样做的目的是消除正负差异的方向性,只关注差异的大小。
    结果是一个一维数组,其元素表示每个状态下新旧最大Q值之间的绝对差异。
    最后,np.sum函数被用来计算上一步得到的差异数组中所有元素的总和。
    这个总和可以被视为一种“更新量”或“变化量”,它量化了从old_Q到当前agent.Q中每个状态下的最大Q值所发生的总体变化。
    """
    update = np.sum(np.abs(np.nanmax(agent.Q, axis=1) - old_Q))
    episode += 1
    agent.eps /= 2
    print(episode, update, len(s_a_history))
    if episode > 100 or update < 1e-5:
        break

# 最终 Q 表的样子
print("-------------------------- Q-table ---------------------------------")
print(agent.Q)

运行效果如下:

在这里插入图片描述
中间省略…
在这里插入图片描述

从输出结果可以看出,前面七轮是在不断震荡的,从第七轮以后就稳定了,agent 只需要七步就可以走到终点。

另外从最后的 Q-table 也可以看出来,agent 在每一个状态都可以以较大概率选取最优的 action,因此只需要七步就可以走到终点。

可视化展现

都是之前介绍过的代码,直接贴出来了:

# 可视化展现
# 创建一个新的图形对象,并设置其大小为 5x5 英寸
fig = plt.figure(figsize=(5, 5))

# 获取当前图形对象的轴对象
ax = plt.gca()

# 设置坐标轴的范围
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)

# 绘制红色的方格边界,表示迷宫的结构
plt.plot([2, 3], [1, 1], color='red', linewidth=2)
plt.plot([0, 1], [1, 1], color='red', linewidth=2)
plt.plot([1, 1], [1, 2], color='red', linewidth=2)
plt.plot([1, 2], [2, 2], color='red', linewidth=2)

# 在指定位置添加文字标签,表示每个状态(S0-S8)、起点和终点
plt.text(0.5, 2.5, 'S0', size=14, ha='center')
plt.text(1.5, 2.5, 'S1', size=14, ha='center')
plt.text(2.5, 2.5, 'S2', size=14, ha='center')
plt.text(0.5, 1.5, 'S3', size=14, ha='center')
plt.text(1.5, 1.5, 'S4', size=14, ha='center')
plt.text(2.5, 1.5, 'S5', size=14, ha='center')
plt.text(0.5, 0.5, 'S6', size=14, ha='center')
plt.text(1.5, 0.5, 'S7', size=14, ha='center')
plt.text(2.5, 0.5, 'S8', size=14, ha='center')
plt.text(0.5, 2.3, 'Start', ha='center')
plt.text(2.5, 0.3, 'Goal', ha='center')

# 设置坐标轴的显示参数,使得坐标轴不显示
plt.tick_params(axis='both', which='both',
                bottom=False, top=False,
                right=False, left=False,
                labelbottom=False, labelleft=False)

# 在起点位置绘制一个绿色的圆形表示当前位置
line, = ax.plot([0.5], [2.5], marker='o', color='g', markersize=60)


def init():
    line.set_data([], [])
    return (line,)


def animate(i):
    state = s_a_history[i][0]
    x = (state % 3) + 0.5
    y = 2.5 - int(state / 3)
    line.set_data(x, y)


anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(s_a_history), interval=200, repeat=False)
anim.save('maze_0.mp4')
# 视频观测有时候不太友好,我们还可以使用 IPython 提供的 HTML 的交互式工具
# 由于 PyCharm 不支持显示 IPython 的交互式输出,因此我们这里将 IPython 的输出转换为 HTML 文件再打开
with open('animation.html', 'w') as f:
    f.write(anim.to_jshtml())

可视化的结果是动态的,这里就不展示了,就是 agent 可以很直接快速的找到最终状态。

Q-Learning 算法

实际上,只需要修改 Sarsa 算法代码中的一行,Sarsa 就变成了 Q-Learning 了。

因此这里一起把 Q-Learning 实现了。

重点关注二者区别,主要是算法思想上的不同,二者公式分别如下:

在这里插入图片描述

从公式上可以知道,Sarsa 是策略依赖型的(on-policy),而 Q-Learning 是策略关闭型的(off-policy)。

核心代码如下:

def q_learning(self, s, a, r, s_next):
	if s_next == 8:
		self.Q[s, a] = self.Q[s, a] + self.eta * (r - self.Q[s, a])
	else:
		self.Q[s, a] = self.Q[s, a] + self.eta * (r + self.gamma * np.nanmax(self.Q[s_next, :]) - self.Q[s, a])

然后训练代码也只需要改一行:

# agent.sarsa(s, a, reward, s_next, a_next)
agent.q_learning(s, a, reward, s_next)

其余同上面的 Sarsa 算法一样,效果也是差不多的,这里不再赘述。

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-17 00:38:03       66 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-17 00:38:03       70 阅读
  3. 在Django里面运行非项目文件

    2024-07-17 00:38:03       57 阅读
  4. Python语言-面向对象

    2024-07-17 00:38:03       68 阅读

热门阅读

  1. 算法-双指针

    2024-07-17 00:38:03       22 阅读
  2. Mybatis 之批量处理

    2024-07-17 00:38:03       21 阅读
  3. Spring Boot 面试题及答案整理,最新面试题

    2024-07-17 00:38:03       21 阅读
  4. 【python基础】学习路线

    2024-07-17 00:38:03       21 阅读
  5. HTTP基本原理

    2024-07-17 00:38:03       24 阅读
  6. Git 的基本命令和使用方式

    2024-07-17 00:38:03       22 阅读
  7. 1.3Zygote

    2024-07-17 00:38:03       21 阅读
  8. 精准打击:Conda中conda remove命令的高效使用指南

    2024-07-17 00:38:03       22 阅读
  9. react项目使用EventBus实现登录拦截

    2024-07-17 00:38:03       20 阅读
  10. MySQL 关键字 IN 与 EXISTS 的使用与区别

    2024-07-17 00:38:03       22 阅读
  11. 关于ARP欺骗

    2024-07-17 00:38:03       20 阅读