2021-02-08 TensorFlow 2.0 MuZero



References:
[1] ColinFred. Monte Carlo Tree Search (MCTS) code explained (Python). 2019-03-23.
[2] 饼干Japson, 深度强化学习实验室. In-depth paper study report: the MuZero algorithm explained. 2021-01-19.
[3] Tangarf. MuZero algorithm study report. 2020-08-31.
[4] 带带弟弟好吗. AlphaGo version three: MuZero. 2020-08-30.
[5] Original Google paper: Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model.
[6] Reference GitHub code 1.
[7] Reference GitHub code 2.

import tensorflow as tf
import numpy as np

class MuZeroModels(object):
    def __init__(
        self,
        representation_layer_list: "list",
        dynamics_layer_list: "list",
        prediction_layer_list: "list",
    ):

        # representation function: obs 1, 2, 3, ... -> hidden state
        self.representation = tf.keras.Sequential(
            representation_layer_list,
            name="representation"
        )

        # dynamics function: hidden state(k) AND action -> hidden state(k+1) AND reward
        self.dynamics = tf.keras.Sequential(
            dynamics_layer_list,
            name="dynamics"
        )

        # prediction function: hidden state -> policy AND value function
        self.prediction = tf.keras.Sequential(
            prediction_layer_list,
            name="prediction"
        )

    @staticmethod
    def loss(
        reward_target,
        value_target,
        policy_target,
        reward_pred,
        value_pred,
        policy_pred
    ):
        # total loss = reward loss (MSE) + value loss + policy loss (both cross-entropy)
        return tf.losses.mean_squared_error(
            y_pred=reward_pred,
            y_true=reward_target
        ) + tf.losses.categorical_crossentropy(
            y_pred=value_pred,
            y_true=value_target
        ) + tf.losses.categorical_crossentropy(
            y_pred=policy_pred,
            y_true=policy_target
        )
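
A minimal usage sketch of the constructor follows; the layer sizes, observation dimension, hidden-state width and action count are illustrative assumptions, not values from the original post:

import tensorflow as tf

NUM_OBS = 16      # hypothetical observation dimension
HIDDEN_DIM = 8    # hypothetical hidden-state width
NUM_ACTIONS = 4   # hypothetical number of actions

representation_layers = [
    tf.keras.layers.Dense(32, activation="relu", input_shape=(NUM_OBS,)),
    tf.keras.layers.Dense(HIDDEN_DIM, activation="tanh"),
]
dynamics_layers = [
    tf.keras.layers.Dense(32, activation="relu", input_shape=(HIDDEN_DIM + NUM_ACTIONS,)),
    tf.keras.layers.Dense(HIDDEN_DIM, activation="tanh"),
]
prediction_layers = [
    tf.keras.layers.Dense(32, activation="relu", input_shape=(HIDDEN_DIM,)),
    tf.keras.layers.Dense(NUM_ACTIONS, activation="softmax"),
]

models = MuZeroModels(representation_layers, dynamics_layers, prediction_layers)

Note that a tf.keras.Sequential produces a single output tensor, so these stacks only illustrate how the constructor is called; the MCTS code below assumes the dynamics function also returns a reward and the prediction function also returns a value, which in a finished version would require two-headed (functional-API) models.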

class minmax(object):
    '''Tracks the minimum and maximum backed-up value seen during the search and normalizes backed-up returns into [0, 1].'''
    def __init__(self):
        self.maximum = -float("inf")
        self.minimum = float("inf")

    def update(self, value):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
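
A quick worked example of the normalization: after update(2.0) and update(6.0), normalize(4.0) returns (4.0 - 2.0) / (6.0 - 2.0) = 0.5.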

class TreeNode(object):
    def __init__(
        self,
        parent,
        prior_p,
        hidden_state,
        reward,
        is_PVP: 'bool'=False,
        gamma=0.997
    ):
        self._parent = parent
        self._children = {}
        self._num_visits = 0
        self._Q = 0
        self._U = 0
        self._P = prior_p

        self._hidden_state = hidden_state
        self.reward = reward

        self._is_PVP = is_PVP
        self._gamma = gamma

    def expand(self, action_priorP_hiddenStates_reward):
        '''
        Expand the tree by creating new child nodes.
        :param action_priorP_hiddenStates_reward: iterable of tuples
            (action, predicted prior probability of the action, hidden state, reward)
        '''
        for action, prob, hidden_state, reward in action_priorP_hiddenStates_reward:
            if action not in self._children.keys():
                self._children[action] = TreeNode(
                    parent=self,
                    prior_p=prob,
                    hidden_state=hidden_state,
                    reward=reward,
                    is_PVP=self._is_PVP,
                    gamma=self._gamma
                )

    def select(self, c_puct_1=1.25, c_puct_2=19652):
        '''
        :param c_puct_1: set to 1.25, following the paper
        :param c_puct_2: set to 19652, following the paper
        :return: the (action, child node) pair with the highest UCB value
        '''
        return max(
            self._children.items(),
            key=lambda node_tuple: node_tuple[1].get_value(c_puct_1, c_puct_2)
        )

    def _update(self, value, reward, minmax):
        '''
        :param reward: cumulative reward (with discount factor applied) backed up from the leaf node n_l to this node n_k
        :param value: value estimated by the model at the leaf node n_l, multiplied by gamma ^ (l - k)
        Note: this function is not meant to be called from outside the class.
        '''
        _G = reward + value
        minmax.update(_G)
        _G = minmax.normalize(_G)
        self._Q = (self._num_visits * self._Q + _G) / (self._num_visits + 1)
        self._num_visits += 1

    def backward_update(self, minmax, value, backward_reward=0):
        '''
        :param backward_reward: sum of all rewards backed up from the leaf node, each discounted by gamma
        :param value: value function estimated at the leaf node
        Note: call this only on the leaf node; the value function only evaluates the final state, so non-leaf nodes are updated through the recursion rather than called directly.
        '''
        self._update(value, backward_reward, minmax)
        if self._is_PVP:
            all_rewards = self.reward - self._gamma * backward_reward
        else:
            all_rewards = self.reward + self._gamma * backward_reward

        if self._parent:
            self._parent.backward_update(minmax, self._gamma * value, all_rewards)

    def get_value(self, c_puct_1=1.25, c_puct_2=19652):
        '''
        :param c_puct_1: set to 1.25, following the paper
        :param c_puct_2: set to 19652, following the paper
        :return: the computed pUCT value Q + U
        Note: the UCB calculation here differs from AlphaZero.
        '''
        self._U = self._P *\
                  (np.sqrt(self._parent._num_visits)/(1 + self._num_visits)) *\
                  (
                    c_puct_1 + np.log(
                      (self._parent._num_visits + c_puct_2 + 1)/c_puct_2)
                  )
        return self._Q + self._U

    def is_leaf(self):
        return self._children == {}

    def is_root(self):
        return self._parent is None
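
For reference, get_value implements the pUCT rule from the MuZero paper, with the parent's visit count playing the role of the sibling visit sum \sum_b N(s, b):

U(s, a) = P(s, a) \cdot \frac{\sqrt{\sum_b N(s, b)}}{1 + N(s, a)} \cdot \left( c_1 + \log \frac{\sum_b N(s, b) + c_2 + 1}{c_2} \right)

select then picks the child maximizing Q(s, a) + U(s, a), with c_1 = 1.25 and c_2 = 19652. The backup performed by _update and backward_update corresponds to accumulating, at each node n_k on the search path, the return G^k = r_{k+1} + \gamma r_{k+2} + \cdots + \gamma^{l-k-1} r_l + \gamma^{l-k} v_l (with the sign of the backed-up rewards flipped for the opponent in two-player mode), which is min-max normalized and averaged into Q.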

class MCTS(object):
    def __init__(
        self,
        model: 'MuZeroModels',
        observations,
        reward,
        is_PVP: 'bool'=False,
        gamma=0.997,
        num_playout=50,
        c_puct_1=1.25,
        c_puct_2=19652,
    ):

        self._muzero_model = model
        self._minmax = minmax()
        self._root = TreeNode(
            parent=None,
            prior_p=1.0,
            hidden_state=self._muzero_model.representation.predict(observations),
            reward=reward,
            is_PVP=is_PVP,
            gamma=gamma
        )
        self._c_puct_1 = c_puct_1
        self._c_puct_2 = c_puct_2
        self._num_playout = num_playout

    def _playout(self):
        # Select: walk down the tree following the highest pUCT value until a leaf is reached.
        node = self._root
        while True:
            if node.is_leaf():
                break
            _, node = node.select(self._c_puct_1, self._c_puct_2)

        # Evaluate: the prediction network is assumed to output a (policy, value) pair.
        action_probs, value = self._muzero_model.prediction.predict(node._hidden_state)
        action_probs = list(action_probs[0])
        value = float(np.squeeze(value))

        action_priorP_hiddenStates_reward = []

        # Expand: roll the dynamics network forward once for every possible action.
        for action_num, prob in enumerate(action_probs):
            action = 'action:' + str(action_num)

            # one-hot encoding of the action fed to the dynamics network
            action_num_one_hot = [1 if i == action_num else 0 for i in range(len(action_probs))]
            # the dynamics network is assumed to output the next hidden state and the reward
            next_hidden_state, reward = self._muzero_model.dynamics.predict([node._hidden_state, action_num_one_hot])

            action_priorP_hiddenStates_reward.append((action, prob, next_hidden_state, reward))

        node.expand(action_priorP_hiddenStates_reward)

        # Backup: propagate the leaf value and rewards up to the root.
        node.backward_update(minmax=self._minmax, value=value)

    def choice_action(self):
        # Run the configured number of playouts, then convert the root's visit
        # counts into a probability distribution over actions (softmax over visits).
        for _ in range(self._num_playout):
            self._playout()
        actions = []
        visits = []
        for action, node in self._root._children.items():
            actions.append(action)
            visits.append(node._num_visits)

        exp_visits = np.exp(visits)

        return actions, exp_visits / np.sum(exp_visits)

    def __str__(self):
        return "MuZero_MCTS"

class MuZero:
    # placeholder: the overall agent (training and self-play loop) is not implemented yet
    def __init__(self):
        pass
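
A hypothetical driver for the MCTS class, reusing the toy networks sketched earlier (the observation shape and every name beyond the classes above are assumptions for illustration):

observations = np.random.rand(1, NUM_OBS).astype(np.float32)  # dummy observation batch

mcts = MCTS(
    model=models,        # the MuZeroModels instance built in the earlier sketch
    observations=observations,
    reward=0.0,
    num_playout=50
)
actions, action_probs = mcts.choice_action()
chosen_action = np.random.choice(actions, p=action_probs)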

P.S.: The code is not fully finished; if you spot mistakes, corrections are welcome.

Source: https://blog.csdn.net/weixin_41369892/article/details/113754384
