ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

【强化学习】SARSA(lambda)与SARSA区别及python代码实现

2022-01-22 17:59:19  阅读:200  来源: 互联网

标签:trace python self actions state SARSA action table lambda


一、概念介绍

单步更新:SARSA是一种单步更新法,每走一步,更新一下自己的行为准则。虽然每一步都在进行更新,但没有获得最终奖励的时候现在所处的的这一步也没获得更新,直到获得最终奖励,获得最终奖励的前一步认为和获得奖励是有关联的。

回合更新:SARSA(lambda)用来代替我们想选择的步数。获得最终奖励后才会进行更新,但是获得奖励的每一步都被认为和获得奖励是有关联的。λ是一个局部搜索的权重值-衰变值,离λ将离越近越重要。λ取0就成了单步更新,λ取1就变形成了回合更新。

二 代码实现

  1. Brain

import numpy as np

import pandas as pd

class RL(object):

    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):#定义变量

        self.actions = action_space  # a list

        self.lr = learning_rate

        self.gamma = reward_decay

        self.epsilon = e_greedy

        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):#检查状态,有没有缺失

        if state not in self.q_table.index:

            # append new state to q table

            self.q_table = self.q_table.append(

                pd.Series(

                    [0]*len(self.actions),

                    index=self.q_table.columns,

                    name=state,

                )

            )

    def choose_action(self, observation):#选择动作

        self.check_state_exist(observation)

        # action selection

        if np.random.rand() < self.epsilon:

            # 选择最优动作

            state_action = self.q_table.loc[observation, :]

            # some actions may have the same value, randomly choose on in these actions

            action = np.random.choice(state_action[state_action == np.max(state_action)].index)

        else:

            # 随机选择动作

            action = np.random.choice(self.actions)

        return action

    def learn(self, *args):

        pass

# 向后看的方式,离奖励越近越重要

class SarsaLambdaTable(RL):

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):

        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

      

        self.lambda_ = trace_decay#0-1的值

        self.eligibility_trace = self.q_table.copy()#state-action的表

    def check_state_exist(self, state):

        if state not in self.q_table.index:

            # 增加新的state

            to_be_append = pd.Series(

                    [0] * len(self.actions),

                    index=self.q_table.columns,

                    name=state,

                )

            self.q_table = self.q_table.append(to_be_append)

            # also update eligibility trace

            self.eligibility_trace = self.eligibility_trace.append(to_be_append)

    def learn(self, s, a, r, s_, a_):

        self.check_state_exist(s_)

        q_predict = self.q_table.loc[s, a]

        if s_ != 'terminal':

            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal

        else:

            q_target = r  # next state is terminal

        error = q_target - q_predict

        # increase trace amount for visited state-action pair

        # Method 1:无封顶

        # self.eligibility_trace.loc[s, a] += 1

        # Method 2:#有封顶

        self.eligibility_trace.loc[s, :] *= 0

        self.eligibility_trace.loc[s, a] = 1

        # Q update

        self.q_table += self.lr * error * self.eligibility_trace

        # decay eligibility trace after update

        self.eligibility_trace *= self.gamma*self.lambda_

2.Test

from maze_env import Maze

from RL_brain import SarsaLambdaTable

def update():

    for episode in range(100):

        # 初始观测值

        observation = env.reset()

        # 基于观测值选择动作

        action = RL.choose_action(str(observation))

        # 开始均为0

        RL.eligibility_trace *= 0

        while True:

            # 更新环境

            env.render()

            # 采取动作得到下一步观测值和奖励

            observation_, reward, done = env.step(action)

            # 基于观测进行动作的选择

            action_ = RL.choose_action(str(observation_))

            # RL learn from this transition (s, a, r, s, a) ==> Sarsa

            RL.learn(str(observation), action, reward, str(observation_), action_)

            # 更新观测和动作

            observation = observation_

            action = action_

            # 当回合结束时进行打断

            if done:

                break

    # 结束游戏

    print('game over')

    env.destroy()

if __name__ == "__main__":

    env = Maze()

    RL = SarsaLambdaTable(actions=list(range(env.n_actions)))

    env.after(100, update)

    env.mainloop()

标签:trace,python,self,actions,state,SARSA,action,table,lambda
来源: https://blog.csdn.net/m0_66111915/article/details/122640961

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有