来源 | MyEncyclopedia

TD Learning本质上是加了bootstrapping的蒙特卡洛(MC),也是model-free的方法,但实践中往往比蒙特卡洛收敛更快。我们选取OpenAI Gym中经典的CartPole环境来讲解TD。

CartPole OpenAI 环境

如图所示,小车上放了一根杆,杆会根据物理系统定理因重力而倒下,我们可以控制小车往左或者往右,目的是尽可能地让杆保持树立状态。

CartPole OpenAI Gym

CartPole 观察到的状态是四维的float值,分别是车位置,车速度,杆角度和杆角速度。下表为四个维度的值范围。给到小车的动作,即action space,只有两种:0,表示往左推;1,表示往右推。

离散化连续状态

从上所知,CartPole step() 函数返回了4维ndarray,类型为float32的连续状态空间。对于传统的tabular方法来说第一步必须离散化状态,目的是可以作为Q table的主键来查找。下面定义的State类型是离散化后的具体类型,另外 Action 类型已经是0和1,不需要做离散化处理。

State = Tuple[int, int, int, int]Action = int

离散化处理时需要考虑的一个问题是如何设置每个维度的分桶策略。分桶策略会决定性地影响训练的效果。原则上必须将和action以及reward强相关的维度做细粒度分桶,弱相关或者无关的维度做粗粒度分桶。举个例子,小车位置本身并不能影响Agent采取的下一动作,当给定其他三维状态的前提下,因此我们对小车位置这一维度仅设置一个桶(bucket size=1)。而杆的角度和角速度是决定下一动作的关键因素,因此我们分别设置成6个和12个。

以下是离散化相关代码,四个维度的 buckets=(1, 2, 6, 12)。self.q是action value的查找表,具体类型是shape 为 (1, 2, 6, 12, 2)  的ndarray。

class CartPoleAbstractAgent(metaclass=abc.ABCMeta):    def __init__(self, buckets=(1, 2, 6, 12), discount=0.98, lr_min=0.1, epsilon_min=0.1):        self.env = gym.make('CartPole-v0')
        env = self.env        # [position, velocity, angle, angular velocity]        self.dims_config = [(env.observation_space.low[0], env.observation_space.high[0], 1),                            (-0.5, 0.5, 1),                            (env.observation_space.low[2], env.observation_space.high[2], 6),                            (-math.radians(50) / 1., math.radians(50) / 1., 12)]        self.q = np.zeros(buckets + (self.env.action_space.n,))        self.pi = np.zeros_like(self.q)        self.pi[:] = 1.0 / env.action_space.n
    def to_bin_idx(self, val: float, lower: float, upper: float, bucket_num: int) -> int:        percent = (val + abs(lower)) / (upper - lower)        return min(bucket_num - 1, max(0, int(round((bucket_num - 1) * percent))))
    def discretize(self, obs: np.ndarray) -> State:        discrete_states = tuple([self.to_bin_idx(obs[d], *self.dims_config[d]) for d in range(len(obs))])        return discrete_states

train() 方法串联起来 agent 和 env 交互的流程,包括从 env 得到连续状态转换成离散状态,更新 Agent 的 Q table 甚至 Agent的执行policy,choose_action会根据执行 policy 选取action。

def train(self, num_episodes=2000):    for e in range(num_episodes):        print(e)        s: State = self.discretize(self.env.reset())
        self.adjust_learning_rate(e)        self.adjust_epsilon(e)        done = False
        while not done:            action: Action = self.choose_action(s)            obs, reward, done, _ = self.env.step(action)            s_next: State = self.discretize(obs)            a_next = self.choose_action(s_next)            self.update_q(s, action, reward, s_next, a_next)            s = s_next

choose_action 的默认实现为基于现有 Q table 的-greedy 策略。

def choose_action(self, state) -> Action:    if np.random.random() < self.epsilon:        return self.env.action_space.sample()    else:        return np.argmax(self.q[state])


抽象出公共的基类代码 CartPoleAbstractAgent 之后,SARSA、Q-Learning和Expected SARSA只需要复写 update_q 抽象方法即可。

class CartPoleAbstractAgent(metaclass=abc.ABCMeta):    @abc.abstractmethod    def update_q(self, s: State, a: Action, r, s_next: State, a_next: Action):        pass



TD Learning的精髓

MC需要在环境中模拟直至最终结局。若为为t步以后的最终return,则 MC online update 版本更新为:

可以认为向着目标为更新了一小步。

另外一个和Monte Carlo的区别在于一般TD方法保存更精细的Q值,并用Q值来boostrap,而MC一般用V值也可用Q值。


SARSA: On-policy TD 控制

SARSA的命名源于一次迭代产生了五元组 。SARSA利用五个值做 action-value的 online update:

对应的Q table更新实现为:

class SarsaAgent(CartPoleAbstractAgent):
    def update_q(self, s: State, a: Action, r, s_next: State, a_next: Action):        self.q[s][a] += self.lr * (r + self.discount * (self.q[s_next][a_next]) - self.q[s][a])


SARSA 在执行policy 后的Q值更新是对于针对于同一个policy的,完成了一次策略迭代(policy iteration),这个特点区分于后面的Q-learning算法,这也是SARSA 被称为 On-policy 的原因。下面是完整算法伪代码。

SARSA 训练分析

SARSA收敛较慢,1000次episode后还无法持久稳定,后面的Q-learning 和 Expected Sarsa 都可以在1000次episode学习长时间保持不倒的状态。

Q-Learning: Off-policy TD 控制

Q-Learning 是深度学习时代前强化学习领域中的著名算法,它的 online update 公式为:

对应的 update_q() 方法具体实现

class QLearningAgent(CartPoleAbstractAgent):
    def update_q(self, s: State, a: Action, r, s_next: State, a_next: Action):        self.q[s][a] += self.lr * (r + self.discount * np.max(self.q[s_next]) - self.q[s][a])


本质上用现有的Q table中最好的action来bootrap 对应的最佳Q值,推导如下:

Q-Learning 被称为 off-policy 的原因是它并没有完成一次policy iteration,而是直接用已有的 Q 来不断近似

对比下面的Q-Learning 伪代码和之前的 SARSA 版本可以发现,Q-Learning少了一次模拟后的,这也是Q-Learning 中执行policy和预估Q值(即off-policy)分离的一个特征。

Q-Learning 训练分析

Q-Learning 1000次episode就可以持久稳定住。

 

SARSA 改进版 Expected SARSA

Expected SARSA 改进了 SARSA 的地方在于考虑到了在某一状态下的现有策略动作分布,以此来减少variance,加快收敛,具体更新规则为:

注意在实现中,update_q() 不仅更新了Q table,还显示更新了执行policy

class ExpectedSarsaAgent(CartPoleAbstractAgent):
    def update_q(self, s: State, a: Action, r, s_next: State, a_next: Action):        self.q[s][a] = self.q[s][a] + self.lr * (r + self.discount * np.dot(self.pi[s_next], self.q[s_next]) - self.q[s][a])        # update pi[s]        best_a = np.random.choice(np.where(self.q[s] == max(self.q[s]))[0])        n_actions = self.env.action_space.n        self.pi[s][:] = self.epsilon / n_actions        self.pi[s][best_a] = 1 - (n_actions - 1) * (self.epsilon / n_actions)

同样的,Expected SARSA 1000次迭代也能比较好的学到最佳policy。

更多精彩推荐
☞阿里马涛:重新定义云时代的开源操作系统 | 人物志
☞代码攻击破坏设备,炸毁 27 吨发电机的背后
☞算力至上?AI芯片大对决
☞牛!发出中国第一封电子邮件,注册登记域名 CN,中国互联网之父传奇
☞长沙 · 中国1024程序员节盛况空前,500 万程序员线上线下引爆星城

☞“国产操作系统最大难题在于解决「生产关系」” | 人物志
☞一口气看完45个寄存器,CPU核心技术大揭秘
点分享点点赞点在看
Logo

20年前,《新程序员》创刊时,我们的心愿是全面关注程序员成长,中国将拥有新一代世界级的程序员。20年后的今天,我们有了新的使命:助力中国IT技术人成长,成就一亿技术人!

更多推荐