在强化学习中,Sarsa和Q-Learning很类似,本次内容将会基于之前所讲的Q-Learning的内容。
目录
- 算法简介
- 更新准则
- 探险者上天堂实战
算法简介
Sarsa决策部分和Q-Learning一抹一样,都是采用Q表的方式进行决策,所以我们会在Q表中挑选values比较大的动作实施在环境中来换取奖赏。但是Sarsa的更新是不一样的
更新准则
和上次一样用小学生写作业为例子,我们会经历写作业的状态s1
,然后再挑选一个带来最大潜在奖励的动作a2
,这样我们就到达了继续写作业的状态s2
,而在这一步没如果你用的是Q-Learning,你会观察一下在s2
上选取哪一个动作会带来最大的奖赏reward
来更新,但是在真正要做决定的时候却不一定会选取到那个带来最大reward
的动作,Q-Learning这一步只是估计了接下来的value。而Sarsa在s2这一步估计的动作就是他接下来要做的动作。所以Q(s1,a2)
现实的计算值我们也会改动,去掉了maxQ
,取而代之的是在S2上我们实实在在选取的a2的Q值。最后像Q-Learning一样,求出现实和估计的差距并更新Q表里的Q(s1,a2)
。
上图就是Sarsa更新的公式。我们可以看到和Q-Learning的不同之处:
- 他在当前的
state
中已经想好了state
对应的action
,而且想好了下一个state_
和下一个action_
(Q-learning还没有想好下一个action_
) - 更新
Q(s,a)
的时候基于的是下一个Q(s_,a_)
(Q-learning基于的是maxQ(s_)
)
这种不同之处使得Sarsa相对于Q-learning显得比较的”胆小“。原因在于
- Q-learning在更新的时候始终都是选择
maxQ
最大化,因为这个maxQ
变得贪婪,不考虑其他非maxQ
的结果。我们可以理解成Q-learning是一种贪婪,大胆,勇敢的算法,对于错误,死亡并不在乎。而Sarsa是一种保守的算法,他在乎每一步的决策,对于错误和死亡比较敏感,这可以在可视化部分看出他们的不同。两种算法都有他们的好处,比如在实际中,如果你比较在乎机器的损害,那么用一种保守的算法,在训练中就可以有效地减少损坏的次数。 - 从另一个角度想,Q-learning更新使用
maxQ
,而Sarsa却要看a_
的值,而a_
的值需要看greedy
的脸色,如果greedy=1
那么a_
就是maxQ
,与Q—Learning在greedy=1
无差别。greedy
值越小,Sarsa越不坚决(选择Q表中大的那个),而是会根据np.random.choice随机选择一个方向,同时也正是因Sarsa多了一项探索的概率,所以才是的Sarsa容易偏离终点,从视觉上看Sarsa有时显得很纠结。正因如此,Sarsa其实在某些程度上显得他很勇敢,因为Sarsa比Q-Learning更有探索精神,也正是这份精神使得Sarsa对终点的渴望不那么果决,饥渴成都要看greedy
的脸色,更具多面性。
探险者上天堂实战
背景
黄色是天堂(reward=1),黑色是地狱(reward=-1)。我们的目标就是让探险者经过自己的多次入“地狱”,最终学会入“天堂”
主模块
首先我们先import两个模块,maze_env
是我们游戏虚拟环境模块,是用python自带的GUI模块tkinter
来编写,具体细节不多赘述,完整代码会放在最后。RL_brain
这个模块是RL的大脑部分,稍后会提及。
from maze_env import Maze
from RL_brain import SarsaTable
下面就是我们的更新部分代码
def update():
for episode in range(100):
# 初始化环境
observation = env.reset()
# Sarsa根据state观测选择行为
action = RL.choose_action(str(observation))
while True:
# 刷新环境
env.render()
# 在环境中采取行为,获得下一个state_(observation_),reward,和终止信号
observation_, reward, done = env.step(action)
# 根据下一个state(observation_)选取下一个action_
action_ = RL.choose_action(str(observation_))
#从(s, a, r, s, a)中学习,更新Q_table的参数
RL.learn(str(observation), action, reward, str(observation_), action_)
# 将下一个的observation_和action_当成对应下一步的参数
observation = observation_
action = action_
if done:
break
# end of game
print('game over')
env.destroy()
if __name__ == "__main__":
#定义环境enc和RL方式
env = Maze()
RL = SarsaTable(actions=list(range(env.n_actions)))
env.after(100, update)
env.mainloop()
RL_brain模块
我们定义一个父类classRL
,然后SarsaTable
作为父类的衍生。
import numpy as np
import pandas as pd
class RL:
#初始化参数
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
self.actions = actions # 行为列表
self.lr = learning_rate #学习率
self.gamma = reward_decay #奖励衰减度
self.epsilon = e_greedy #贪婪度
self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64) #初始化q_table
#选择行为
def choose_action(self, observation):
self.check_state_exist(observation) #检验state是否在q_table中出现
# 贪婪模式
if np.random.uniform() < self.epsilon:
state_action = self.q_table.loc[observation, :]
# 同一个state,可能会有多个相同的Q action value,所以我们乱序一下
action = np.random.choice(state_action[state_action == np.max(state_action)].index)
else:
# 非贪婪模式随机选择action
action = np.random.choice(self.actions)
return action
#学习更新参数
def learn(self, s, a, r, s_):
self.check_state_exist(s_)#同样先检验一下q_table中是否存在S_
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
#下个状态不是终止
q_target = r + self.gamma * self.q_table.loc[s_, :].max()
else:
q_target = r
#更新参数
self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
#检验state是否存在
def check_state_exist(self, state):
if state not in self.q_table.index:
# 如果不存在就插入一组全0数据,当做state的所有action的初始values
self.q_table = self.q_table.append(
pd.Series(
[0]*len(self.actions),
index=self.q_table.columns,
name=state,
)
)
然后我们编写SarsaTable
中learn
也就是更新功能就完成了。
class SarsaTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # q_target基于选好的a_而不是Q(s_)的最大值
else:
q_target = r
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # 更新q_table
最后探险者就可以很轻松的上天堂了!
参考:
https://github.com/MorvanZhou