Source code for alf.algorithms.mcts_algorithm

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Monte-Carlo Tree Search."""
from absl import logging

import numpy as np
import torch
import torch.distributions as td
import torch.nn.functional as F
from typing import Callable, Optional

import alf
from alf import TensorSpec
from alf.data_structures import AlgStep, LossInfo, namedtuple, TimeStep
from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
from alf.utils import dist_utils, summary_utils
from alf import nest
from alf.trainers.policy_trainer import Trainer
from alf.utils import common, tensor_utils, action_samplers
from .mcts_models import MCTSModel, ModelOutput

MAXIMUM_FLOAT_VALUE = float('inf')


class _MCTSTrees(object):
    def __init__(self, num_expansions, model_output, known_value_bounds,
                 num_best_children, discount, value_min_max_delta):
        batch_size, branch_factor = model_output.action_probs.shape
        action_spec = dist_utils.extract_spec(model_output.actions, from_dim=2)
        state_spec = dist_utils.extract_spec(model_output.state, from_dim=1)
        if known_value_bounds:
            self.fixed_bounds = True
            self.minimum, self.maximum = known_value_bounds
        else:
            self.fixed_bounds = False
            self.minimum, self.maximum = MAXIMUM_FLOAT_VALUE, -MAXIMUM_FLOAT_VALUE
        self.minimum = torch.full((batch_size, ),
                                  self.minimum,
                                  dtype=torch.float32)
        self.maximum = torch.full((batch_size, ),
                                  self.maximum,
                                  dtype=torch.float32)
        if known_value_bounds:
            self.normalize_scale = 1 / (self.maximum - self.minimum + 1e-30)
            self.normalize_base = self.minimum
        else:
            self.normalize_scale = torch.ones((batch_size, ))
            self.normalize_base = torch.zeros((batch_size, ))
        self.discount = discount
        self.value_min_max_delta = value_min_max_delta

        self.B = torch.arange(batch_size)
        self.root_indices = torch.zeros((batch_size, ), dtype=torch.int64)
        self.branch_factor = branch_factor

        parent_shape = (batch_size, num_expansions)
        children_shape = (batch_size, num_expansions, branch_factor)

        self.visit_count = torch.zeros(parent_shape, dtype=torch.int32)

        # the player who will take action from the current state
        self.to_play = torch.zeros(parent_shape, dtype=torch.int64)
        self.prior = torch.zeros(children_shape)

        # value for player 0
        self.value_sum = torch.zeros(parent_shape)
        self.original_value = torch.zeros_like(
            self.value_sum, dtype=model_output.value.dtype)

        # 0 for not expanded, value in range [0, num_expansions)
        self.children_index = torch.zeros(children_shape, dtype=torch.int64)
        self.model_state = common.zero_tensor_from_nested_spec(
            state_spec, parent_shape)

        # reward for player 0
        self.reward = None
        if isinstance(model_output.reward, torch.Tensor):
            # we use value.dtype because the reward.dtype from the intial_inference
            # can be different from the reward.dtype from the recurrent_inference
            self.reward = torch.zeros(
                parent_shape, dtype=model_output.value.dtype)

        self.action = None
        if isinstance(model_output.actions, torch.Tensor):
            # candidate actions for this state
            self.action = torch.zeros(
                children_shape + action_spec.shape, dtype=action_spec.dtype)

        self.game_over = None
        if isinstance(model_output.game_over, torch.Tensor):
            self.game_over = torch.zeros(parent_shape, dtype=torch.bool)

        shape = ()
        if num_best_children > 1:
            shape = (num_best_children, )
        # value in range [0, branch_factor)
        self.best_child_index = torch.zeros(
            parent_shape + shape, dtype=torch.int64)
        self.ucb_score = torch.zeros(children_shape)

    def update_value_stats(self, nodes, valid=None):
        """Update min max stats for each tree.

        Only valid values are used to update the stats.
        We maintain separate stats for each tree.
        The shapes of ``nodes`` and ``valid`` are [T, B] or [T, B, psims].
        """
        if self.fixed_bounds:
            return
        values = self.discount * self.calc_value(nodes)
        if self.reward is not None:
            values = self.reward[nodes] + values
        if valid is not None:
            invalid = ~valid
            values[invalid] = -MAXIMUM_FLOAT_VALUE
        max = values.max(dim=0)[0]
        if valid is not None:
            values[invalid] = MAXIMUM_FLOAT_VALUE
        min = values.min(dim=0)[0]
        if values.ndim == 3:
            max = max.max(dim=1)[0]
            min = min.min(dim=1)[0]
        self.maximum = torch.max(self.maximum, max)
        self.minimum = torch.min(self.minimum, min)

        # We normalize only when we have set the maximum and minimum values.
        normalize = self.maximum > self.minimum
        self.normalize_scale = torch.where(
            normalize, 1 /
            (self.maximum - self.minimum).clamp(min=self.value_min_max_delta),
            self.normalize_scale)
        self.normalize_base = torch.where(normalize, self.minimum,
                                          self.normalize_base)

    def normalize_value(self, value, batch_index):
        return self.normalize_scale[batch_index] * (
            value - self.normalize_base[batch_index])

    def calc_value(self, nodes):
        return self.value_sum[nodes] / self.visit_count[nodes]

    def get_children(self, parents):
        """
        Get the index tuple of the children of parents. It has one more rank of
        branch_factor than parents
        """
        children = self.children_index[parents]
        return (parents[0].unsqueeze(1), children)

    def get_model_state(self, nodes):
        return nest.map_structure(lambda x: x[nodes], self.model_state)

    def visualize(self, b: int):
        """Visualize the b-th tree as a directed graph.

        The graph is in Graphviz format and can be rendered

        1. Directly in Jupyter Notebook, or
        2. rendered as pdf, with ``trees.visualize(0).render(<path>)``.

        """
        try:
            import graphviz
        except ImportError:
            raise RuntimeError(
                'Need "graphviz" installed if you want to visualize MCTS')

        dot = graphviz.Digraph()
        dot.attr('node', shape='record')

        def _add_node(name: str, properties: dict):
            keys = "|".join(properties.keys())
            values = '|'.join([
                f'{x:.4f}' if type(x) is float else f'{x}'
                for x in properties.values()
            ])
            dot.node(name, label=f'{{{keys}}}|{{{values}}}')

        # Use a queue to BFS traverse the tree
        q = [(b, 0)]
        while len(q) > 0:
            node = q.pop(0)
            visit_count = self.visit_count[node].item()
            reward = self.reward[node].item()
            value = self.calc_value(node).item() + reward
            _add_node(
                str(node[1]), {
                    'visit': visit_count,
                    'value': value,
                    'reward': reward
                })

            children = self.children_index[node]
            prior = self.prior[node]
            # Only the child nodes whose index is not zero and prior > 0 are valid.
            valid = torch.logical_and(children != 0, prior > 0.0)
            children = children[valid].tolist()
            prior = prior[valid].tolist()
            ucb_score = self.ucb_score[node][valid].tolist()

            for c, u, p in zip(children, ucb_score, prior):
                q.append((b, c))
                dot.edge(
                    str(node[1]),
                    str(c),
                    label=f'<ucb {u:.4f} <br/> pri {p:.4f}>')
        return dot


MCTSState = namedtuple(
    "MCTSState",
    ["steps", "pred_state", "action_sampler_state", "next_predicted_reward"],
    default_value=())
MCTSInfo = namedtuple(
    "MCTSInfo", ["candidate_actions", "value", "candidate_action_policy"])


[docs]@alf.configurable class MCTSAlgorithm(OffPolicyAlgorithm): r"""Monte-Carlo Tree Search algorithm. The code largely follows the pseudocode of `Schrittwieser et al. Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model <https://arxiv.org/abs/1911.08265>`_. The pseudocode can be downloaded from `<https://arxiv.org/src/1911.08265v2/anc/pseudocode.py>`_ There are several differences: 1. In this implementation, all values and rewards are for player 0. It seems that the values and rewards in the pseudocode can be either for player 0 or player 1 depending on who is on current turn. It makes reasoning the logic of the code more difficult and error prone. And it indeed seems there is a bug in the pseudocode related to this. More concretely, in the pseudocode, line 524 suggests that the value_sum is relative to a changing player; line 528 suggests that all the rewards along a path are relative to a same player; while line 499 combines the reward and value without considering the player. 2. When calculating UCB score, the pseudocode normalizes value before adding with reward. We normalize after summing reward and value. 3. When calculating UCB score, if the visit count of the node is 0, the value component of the score is 0 in the pseudocode. We use 0.5 instead so that it is not always the lowest score (or highest for player 1) no matter what the outcome of its siblings are. 4. The pseudocode initializes the visit count of root to 0. We initialize it to 1 instead so that prior is not neglected in the first select_child(). This is consistent with how the visit_count of other nodes are initialized. When other nodes are expanded, the immediately subsequenct backup() will make their initial visit_count to be 1. 5. We add a game_over field to ModelOutput to indicate the game is over so that we won't keep expanding over that branch. 6. We add support for using a stochastic policy instead of using UCB to do the search/learn/act. This can be enabled by setting ``act_with_exploration_policy`` ``search_with_exploration_policy``, ``learn_with_exploration_policy`` to True. See `Grill et al. Monte-Carlo tree search as regularized policy optimization <https://arxiv.org/abs/2007.12509>`_ for reference. In addition to the original MuZero paper, we also implemented the method described in the following two paper: 1. `Grill et al. Monte-Carlo tree search as regularized policy optimization <https://arxiv.org/abs/2007.12509>`_ It can be enabled by setting (act/learn/search)_with_exploration_policy 2. `Hubert et. al. Learning and Planning in Complex Action Spaces <https://arxiv.org/abs/2104.06303>`_ It is enabled when SimpleMCTSModel.num_sampled_actions is set. The time spent on tree search is directly related to how many times the tree is expanded. To make it faster, we also support expanding multiple leaves simulaneously. In order to do this, we maintain num_parallel_sims best children for each node in the tree and use them to contruct k=num_paralilel_sims paths. Note the k best children may have duplicates, which is desired because we want to expand the most promising path more often. Depending the value of ``search_with_exploration_policy``, this process is slightly different: * search_with_exploration_policy=True. The k best_children of each node are simply chosen by independently sampling the exploration policy k times. When contructing the search paths, the i-th search path is based on the i-th best child of each node. * search_with_exploration_policy=False. The best child is same the case k=1. The second best child is found by assuming the visit count of the best child and the parent are increased by 1 and applying the UCB criterion again. This is repeated k times to get k best children. Note that this is different from directly selecting the best k childrens based on the original UCB scores. The reason of not doing that is that if the highest score is much bigger than the second highest score, we want to both paths to select the same child. During the process of traversing from the root to contruct k search paths, if several (let's say k') paths are exactly same so far, we will use best k' children of the last node of these k' paths to extend the paths so that The k' children (may contains duplicates) being selected to extend these k' paths are most promising according to the UCB scores. """ def __init__( self, observation_spec, action_spec: alf.BoundedTensorSpec, num_simulations: int, root_dirichlet_alpha: float, root_exploration_fraction: float, pb_c_init: float, # c1 in Appendix B (2) pb_c_base: float, # c2 in Appendix B (2) discount: float, is_two_player_game: bool, visit_softmax_temperature_fn: Callable, model: Optional[MCTSModel] = None, keep_model_pred_state: bool = False, predict_action_sampler=action_samplers.MultinomialSampler(), rollout_action_sampler=action_samplers.MultinomialSampler(), learn_policy_temperature=1.0, reward_spec=TensorSpec(()), expand_all_children: bool = False, expand_all_root_children: bool = False, known_value_bounds=None, value_min_max_delta: float = 1e-30, ucb_break_tie_eps: float = 0., ucb_parent_visit_count_minus_one: bool = False, unexpanded_value_score=0.5, act_with_exploration_policy=False, search_with_exploration_policy=False, learn_with_exploration_policy=False, exploration_policy_type: str = 'rkl', max_unroll_length=1000000, num_parallel_sims=1, checkpoint=None, debug_summaries=False, name="MCTSAlgorithm", ): r""" Args: observation_spec (nested TensorSpec): if the observation is a dictionary, ``MCTSAlgorithm`` will use the following three fields if they are contained in the dictionary: 1. valid_action_mask: a bool Tensor to indicate which actions are allowed. It will be used to mask out invalid actions. If not provided, all possible actions are considered. 2. steps: int32 Tensor to indicate the number of steps since the beginning of the game. If not provided, an internal counter will be used. However, this internal count will not be correct if the algorithm is used to play against human because it is not used to generate all the moves of both players. 3. to_play: int8 Tensor whose elements are 0 or 1 to indicate who is the player to take the action. If not provided, steps % 2 will be used as to_play. action_spec (nested BoundedTensorSpec): representing the actions. num_simulations: the number of simulations per search (calls to model) root_dirichlet_alpha: alpha of dirichlet prior for exploration root_exploration_fraction: noise generated by the dirichlet distribution is combined with the action distribution from the model to be used as the action prior for the children of the root. pb_c_init: c1 of the pUCT rule in Appendix B, equation (2) pb_c_base: c2 of the pUCT rule in Appendix B, equation (2) discount: reward discount factor is_two_player_game (bool): whether this is a two player (zero-sum) game model: the model used by the algorithm. If not provided in the constructor. It should be specified using `set_model` before `predict_step` or `rollout_step` is used. keep_model_pred_state: whether to keep ModelOutput.state.pred_state returned from model.initial_predict as part of the state of this algorithm. If so previous pred_state will be used to call initial_predict. visit_softmax_temperature_fn (Callable): function for calculating the softmax temperature for sampling action based on the visit_counts of the children of the root. :math:`P(a) \propto \exp(visit\_count/t)`. This function is called as ``visit_softmax_temperature_fn(steps)``, where ``steps`` is a vector representing the number of steps in the games. And it is expected to return a float vector of the same shape as ``steps``. predict_action_sampler: available choices include ``CategoricalSeedSampler``, ``EpsilonGreedySampler``, ``MultinomialSampler`` rollout_action_sampler: available choices include ``CategoricalSeedSampler``, ``EpsilonGreedySampler``, ``MultinomialSampler`` learn_policy_temperature (float): transform the policy p found by MCTS by :math:`p^{1/learn_policy_temperature} / Z` as policy target for model learning, where Z is a normalization factor so that the resulting probabilities sum to one. reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec representing the reward(s). expand_all_children (bool): If True, when a new leaf is selected, immediately expand all its children. With this option, the visit count does not truly reflect the quality of a node. Hence it should be used with (act/learn)_with_exploration_policy=True expand_all_root_children (bool): whether to expand all root children before search. This is described in Appendix A of "Learning and Planning in Complex Action Spaces". However, our implementation is different from the paper's. The paper initializes Q(s, a) for root s for all the action being sampled. We expand all sampled action for root s. With this option, the visit count does not truly refect the quality of a node. Hence it should be used with (act/learn)_with_exploration_policy=True. known_value_bounds (tuple|None): known bound of the values. value_min_max_delta: when normalizing the value using the the min and max values, ``(max-min).clamp(min=value_min_max_delta)`` is used as the denominator. ucb_break_tie_eps: add a random number in the range of [0, ucb_break_tie_eps) to the UCB score to choose actions with close UCB score randomly. It is used only if at least one of ``act/search/learn_with_exploration_policy`` is False. ucb_parent_visit_count_minus_one: This option effectively chooses the first child of a parent uniformly, which can increase exploration. unexpanded_value_score (float|str): The value score for an unexpanded child. If 'max'/'min'/'mean', will use the maximum/minimum/mean of the value scores of the expanded siblings. If 'mean_with_parent', will use the mean of the value scores of the expanded siblings and its parent (this is used in ELF OpenGo and EfficientZero). If 'none', when exploration policy is used, will keep the policy for the unexpanded children same as prior; when exporation is not used, 'none' behaves same as 'min'. act_with_exploration_policy (bool): If True, a policy calculated using reverse KL divergence will be used for generate action. search_with_exploration_policy (bool): If True, a policy calculated using reverse KL divergence will be used for tree search. learn_with_exploration_policy (bool): If True, a policy calculated using reverse KL divergence will be used for learning. exploration_policy_type: Type of exploration policy. Must be one of ('rkl', 'kl') max_unroll_length (int): maximal allowed unroll steps when building the search tree. If ``expand_all_children`` is False, the maximal allowed tree depth will be ``max_unroll_length``. Otherwise, the maximal allowed tree depth will be ``max_unroll_length-1`` num_parallel_sims (int): expanding so many leaves at a time for one tree. ``num_simulations`` must be divisable by ``num_parallel_sims``. checkpoint (None|str): a string in the format of "prefix@path", where the "prefix" is the multi-step path to the contents in the checkpoint to be loaded. "path" is the full path to the checkpoint file saved by ALF. Refer to ``Algorithm`` for more details. name (str): the name of the algorithm. """ assert not nest.is_nested( action_spec), "nested action_spec is not supported" assert reward_spec.shape == (), "Only scalar reward is supported" self._num_simulations = num_simulations self._known_value_bounds = known_value_bounds self._value_min_max_delta = value_min_max_delta self._ucb_break_tie_eps = ucb_break_tie_eps self._ucb_parent_visit_count_minus_one = ucb_parent_visit_count_minus_one self._model = None self._pb_c_base = pb_c_base # c2 in Appendix B (2) self._pb_c_init = pb_c_init # c1 in Appendix B (2) self._discount = discount self._is_two_player_game = int(is_two_player_game) self._root_dirichlet_alpha = root_dirichlet_alpha self._root_exploration_fraction = root_exploration_fraction self._visit_softmax_temperature_fn = visit_softmax_temperature_fn self._learn_policy_temperature = learn_policy_temperature assert exploration_policy_type in ('rkl', 'kl') if exploration_policy_type == 'rkl': self._calc_exploration_policy = calculate_exploration_policy else: self._calc_exploration_policy = calculate_kl_exploration_policy self._expand_all_children = expand_all_children self._expand_all_root_children = expand_all_root_children self._num_parallel_sims = num_parallel_sims self._parallel = num_parallel_sims > 1 if expand_all_children: self._max_allowed_depth = max_unroll_length - 1 else: self._max_allowed_depth = max_unroll_length predict_action_sampler_state_spec = () if isinstance(predict_action_sampler, alf.nn.Network): predict_action_sampler_state_spec = predict_action_sampler.state_spec rollout_action_sampler_state_spec = () if isinstance(rollout_action_sampler, alf.nn.Network): rollout_action_sampler_state_spec = rollout_action_sampler.state_spec assert not expand_all_children or num_parallel_sims == 1, ( "num_parallel_sim > 1 is not supported when expand_all_children=True" ) assert num_simulations % num_parallel_sims == 0, ( "num_simulations must " "be divided by num_parallel_sims. Got %s and %s" % (self._num_simulations, num_parallel_sims)) if (expand_all_children or expand_all_root_children) and not ( learn_with_exploration_policy and act_with_exploration_policy): logging.warning( "Consider using (act/learn)_with_exploration_policy" "=True for expand_all_children or expand_all_root_children.") if isinstance(unexpanded_value_score, str): assert unexpanded_value_score in ( 'max', 'min', 'mean', 'mean_with_parent', 'none'), ("Unsupported unexpanded_value_score=%s" % unexpanded_value_score) self._unexpanded_value_score = unexpanded_value_score self._act_with_exploration_policy = act_with_exploration_policy self._search_with_exploration_policy = search_with_exploration_policy self._learn_with_exploration_policy = learn_with_exploration_policy pred_state_spec = () if keep_model_pred_state: pred_state_spec = model.pred_state_spec self._keep_model_pred_state = keep_model_pred_state state_spec = MCTSState( steps=alf.TensorSpec((), dtype=torch.int64), pred_state=pred_state_spec) super().__init__( observation_spec, action_spec, reward_spec=reward_spec, predict_state_spec=state_spec._replace( action_sampler_state=predict_action_sampler_state_spec), rollout_state_spec=state_spec._replace( next_predicted_reward=alf.TensorSpec(()), action_sampler_state=rollout_action_sampler_state_spec), train_state_spec=state_spec._replace( action_sampler_state=rollout_action_sampler_state_spec), checkpoint=checkpoint, debug_summaries=debug_summaries, name=name) self._model = model self._rollout_action_sampler = rollout_action_sampler self._predict_action_sampler = predict_action_sampler
[docs] def set_model(self, model: MCTSModel): """Set the model used by the algorithm.""" self._model = model
@property def discount(self): return self._discount @torch.no_grad() def _step(self, time_step: TimeStep, state: MCTSState, action_sampler): """Predict the action. Args: time_step (TimeStep): the time step structure state (MCTSState): state of the algorithm Returns: AlgStep: - output (Tensor): the selected action - state (MCTSState): the state of the algorithm - info (MCTSInfo): info from the search. It might be used for training the model. """ assert self._model is not None, "Need to call `set_model` before `predict_step`" if isinstance(time_step.observation, dict): valid_action_mask = time_step.observation.get( 'valid_action_mask', None) to_plays = time_step.observation.get('to_play', None) steps = time_step.observation.get('steps', state.steps) else: valid_action_mask = None to_plays = None steps = state.steps model_output = self._model.initial_predict(time_step.observation, state.pred_state) if valid_action_mask is not None: # mask out invalid actions assert model_output.actions == () model_output = model_output._replace( action_probs=model_output.action_probs * valid_action_mask.to(torch.float32)) branch_factor = model_output.action_probs.shape[1] if self._expand_all_children: tree_size = (self._num_simulations + 1) * branch_factor + 1 elif self._expand_all_root_children: tree_size = self._num_simulations + 1 + branch_factor else: tree_size = self._num_simulations + 1 trees = _MCTSTrees( tree_size, model_output, self._known_value_bounds, self._num_parallel_sims, discount=self._discount, value_min_max_delta=self._value_min_max_delta) if self._is_two_player_game and to_plays is None: # We may need the environment to pass to_play and pass to_play to # model because players may not always alternate in some game. # And for many games, it is not always obvious to determine which # player is in the current turn by looking at the board. to_plays = steps % 2 self._max_policy_iterations = 0 roots = (trees.B, trees.root_indices) self._expand_node( trees, 0, to_plays=to_plays, model_output=model_output, dirichlet_alpha=self._root_dirichlet_alpha, exploration_fraction=self._root_exploration_fraction) # paper pseudocode starts visit_count from 0 # we start the root visit_count from 1 so that the first update_best_child # will be based on none-zero ucb_scores trees.visit_count[roots] = 1 trees.value_sum[roots] = model_output.value.to(trees.value_sum.dtype) if not self._expand_all_children: self._build_tree1(trees, to_plays) else: self._build_tree2(trees, to_plays) action_probs, info = self._select_action(trees, steps) pred_state = () if self._keep_model_pred_state: pred_state = model_output.state.pred_state if isinstance(action_sampler, alf.nn.Network): action_id, action_sampler_state = action_sampler( action_probs, state.action_sampler_state) else: action_id, action_sampler_state = action_sampler(action_probs), () if info.candidate_actions != (): action = info.candidate_actions[trees.B, action_id] else: action = action_id if common.is_rollout(): curr_reward = time_step.reward curr_predicted_reward = trees.reward[roots] with alf.summary.scope(self._name): alf.summary.scalar( 'reward', time_step.reward.mean(), average_over_summary_interval=True) alf.summary.scalar( 'reward_next_predicted/value', state.next_predicted_reward.mean(), average_over_summary_interval=True) alf.summary.scalar( 'reward_predicted/value', curr_predicted_reward.mean(), average_over_summary_interval=True) def _summarize_error(name, tgt, pred): error = (pred - tgt).abs() error_z = error[tgt == 0] error_nz = error[tgt != 0] if error_z.numel() > 0: alf.summary.scalar( name + "/error_z", error_z.mean(), average_over_summary_interval=True) else: alf.summary.scalar( name + "/error_z", None, average_over_summary_interval=True) if error_nz.numel() > 0: alf.summary.scalar( name + "/error_nz", error_nz.mean(), average_over_summary_interval=True) else: alf.summary.scalar( name + "/error_nz", None, average_over_summary_interval=True) _summarize_error("reward_next_predicted", curr_reward, state.next_predicted_reward) _summarize_error("reward_predicted", curr_reward, curr_predicted_reward) next_predicted_reward = trees.reward[ trees.B, trees.children_index[roots][trees.B, action_id]] else: next_predicted_reward = () return AlgStep( output=action, state=MCTSState( steps=state.steps + 1, pred_state=pred_state, next_predicted_reward=next_predicted_reward, action_sampler_state=action_sampler_state), info=info)
[docs] def rollout_step(self, time_step: TimeStep, state: MCTSState): return self._step(time_step, state, self._rollout_action_sampler)
[docs] def predict_step(self, time_step: TimeStep, state: MCTSState): return self._step(time_step, state, self._predict_action_sampler)
def _get_best_child_index(self, trees, B, nodes, i): if self._parallel and not self._search_with_exploration_policy: # If several nodes are same, we should select actions from different # entries in best_child_index. So we need to find out which entry # of best_child_index a node should use. For search_with_exploration_policy, # the entries in best_child_index are sampled independently from # the calculated policy, so it is ok to use i'th best_child_index # for i-th node. But for UCB based selection, we want to use # the top entries in best_child_index when ever they are not used. # # For x>y, same_node[b, x, y] indicates where nodes[b, x] and nodes[b, y] are same # For x<=y, same_node[b, x, y] is 0. # # For example, suppose nodes[0] = [2,3,2,3,3], it means that the # current 5 parallel searches on tree 0 are at tree 0's node 2, # node 3, node 2, node 3 and node 3 respectively. According to # our algorithm, under UCB-based selection strategy, we want the # 3 searches at node 3 to go to different "next nodes" (i.e. the # top 3 best children of node 3) respectively. The same goes for the 2 # searches at node 2 (they should go to the top 2 best children # of node 2 respectively). Therefore we want to use i to map nodes # index [2, 3, 2, 3, 3] to children index [0, 0, 1, 1, 2]. # # With this example, same_node[0] will be: # [[0, 0, 0, 0, 0], # [0, 0, 0, 0, 0], # [1, 0, 0, 0, 0], # [0, 1, 0, 0, 0], # [0, 1, 0, 1, 0]] # [B, psims, psims]. `None` stands for unsqueezing at that dim. same_node = (nodes[:, :, None] == nodes[:, None, :]).tril( diagonal=-1) # With the above example, i[0] is [0, 0, 1, 1, 2] # [B, psims] i = (same_node.sum(-1), ) return trees.best_child_index[(B, nodes) + i] def _build_tree1(self, trees, to_plays): """Build the tree by evaluating one node a time.""" roots = (trees.B, trees.root_indices) sim0 = 1 if self._expand_all_root_children: v = self._expand_children(trees, roots, 1, to_plays) branch_factor = trees.branch_factor # Backup the result from the children trees.visit_count[roots] += branch_factor trees.value_sum[roots] += branch_factor * v children = trees.get_children(roots) children = torch.broadcast_tensors(*children) b = torch.cat([roots[0], children[0].reshape(-1)]) i = torch.cat([roots[1], children[1].reshape(-1)]) self._update_best_child(trees, (b, i)) sim0 += trees.branch_factor else: self._update_best_child(trees, roots) psims = self._num_parallel_sims parallel = self._parallel B = trees.B if parallel: B = B.unsqueeze(-1) i = (torch.arange(psims), ) else: i = () max_depth = torch.zeros(B.shape[0], dtype=torch.int64) avg_depth = 0. new_nodes = [] for sim in range(sim0, self._num_simulations + sim0, psims): search_paths, path_lengths, last_to_plays = self._search( trees, to_plays) depth = path_lengths if parallel: depth = path_lengths.max(dim=1)[0] max_depth = torch.maximum(max_depth, depth) avg_depth = avg_depth + path_lengths.to( torch.float32).mean() * psims # [B] or [B, psims] prev_nodes = search_paths[(path_lengths - 2, B) + i] model_state = trees.get_model_state((B, prev_nodes)) # [B] or [B, psims] best_child_index = self._get_best_child_index( trees, B, prev_nodes, i) if trees.action is None: action = best_child_index else: action = trees.action[B, prev_nodes, best_child_index] model_output = self._recurrent_inference(model_state, action) if parallel: sim = torch.arange(sim, sim + psims) self._expand_node( trees, sim, to_plays=last_to_plays, model_output=model_output) # If a child is expanded before, use its existing node # [B] or [B, psims] child_index = trees.children_index[B, prev_nodes, best_child_index] child_index = torch.where((child_index == 0) & (path_lengths > 1), sim, child_index) trees.children_index[B, prev_nodes, best_child_index] = child_index # The reason of not using child_index at the rhs of the following # assignment is that different search_paths may expand the same node # with same action. We want those duplicated expansions point to the # same node id. child_index = trees.children_index[B, prev_nodes, best_child_index] search_paths[(path_lengths - 1, B) + i] = child_index new_nodes.append(child_index) self._backup( trees, search_paths=search_paths, path_lengths=path_lengths, values=model_output.value) if self._debug_summaries and alf.summary.should_record_summaries(): if not self._parallel: new_nodes = [x.unsqueeze(-1) for x in new_nodes] B = B.unsqueeze(-1) new_nodes = torch.cat(new_nodes, dim=1) node_i = B * trees.visit_count.shape[1] + new_nodes unique_node_i = torch.unique(node_i[:]) with alf.summary.scope(self._name): alf.summary.scalar("avg_tree_size", unique_node_i.numel() / B.shape[0]) if self._num_simulations > 0: avg_depth = avg_depth / self._num_simulations self._summarize(max_depth - 1, avg_depth - 1) def _summarize(self, max_unroll_length, avg_unroll_length): if self._debug_summaries and alf.summary.should_record_summaries(): with alf.summary.scope(self._name): alf.summary.scalar("max_policy_iterations", self._max_policy_iterations) alf.summary.scalar("max_unroll_length", max_unroll_length.max()) alf.summary.scalar("avg_max_unroll_length", max_unroll_length.to(torch.float32).mean()) alf.summary.scalar("avg_unroll_length", avg_unroll_length) def _recurrent_inference(self, model_state, action): if not self._parallel: return self._model.recurrent_inference(model_state, action) batch_size, psims = action.shape[:2] def _flatten(x): return x.reshape(-1, *x.shape[2:]) def _reshape(x): return x.reshape(batch_size, psims, *x.shape[1:]) model_state = nest.map_structure(_flatten, model_state) action = _flatten(action) model_output = self._model.recurrent_inference(model_state, action) # reshape to [batch_size, self._num_parallel_sims, ...] model_output = model_output._replace( action_distribution=(), game_over_logit=()) model_output = nest.map_structure(_reshape, model_output) return model_output def _build_tree2(self, trees, to_plays): """Build the tree by evaluating all the children of one node a time.""" # expand the root node and backup the values of the children to root. branch_factor = trees.branch_factor roots = (trees.B, trees.root_indices) v = self._expand_children(trees, roots, 1, to_plays) trees.visit_count[roots] += branch_factor trees.value_sum[roots] += branch_factor * v self._update_best_child(trees, roots) max_depth = torch.zeros(trees.B.shape[0], dtype=torch.int64) avg_depth = 0. for sim in range(1, self._num_simulations + 1): search_paths, path_lengths, last_to_plays = self._search( trees, to_plays) max_depth = torch.maximum(max_depth, path_lengths) avg_depth = avg_depth + path_lengths.to(torch.float32).mean() prev_nodes = search_paths[path_lengths - 1, trees.B] if trees.game_over is None: values = self._expand_children(trees, (trees.B, prev_nodes), 1 + sim * branch_factor, last_to_plays) else: # Do not expand game_over nodes. values = torch.zeros(trees.B.shape[0]) game_over = trees.game_over[trees.B, prev_nodes] if not game_over.all(): not_over = ~game_over nodes = (trees.B[not_over], prev_nodes[not_over]) values[not_over] = self._expand_children( trees, nodes, 1 + sim * branch_factor, last_to_plays[not_over] if self._is_two_player_game else None) self._backup( trees, search_paths=search_paths, path_lengths=path_lengths, values=values, count=branch_factor) avg_depth = avg_depth / self._num_simulations self._summarize(max_depth, avg_depth) def _print_tree(self, trees: _MCTSTrees, b): """Helper function to visualize the b-th search tree.""" nodes = [((b, 0), 0, 0)] np.set_printoptions(precision=4) def _get_item(nested, i): return alf.nest.map_structure(lambda x: x[i], nested) def _print_policy(policy): if policy.numel() <= 4: return policy policy = policy.cpu() indices = torch.argsort(policy, descending=True) s = ' '.join( ["%s:%0.3f" % (int(i), float(policy[i])) for i in indices[:4]]) return '[' + s + ']' while len(nodes) > 0: node, action, depth = nodes.pop(0) print( '| ' * depth + "a=%s" % action, "node=%-3d" % node[1], "state=%-4d" % (hash(str(_get_item(trees.model_state, node))) % 10000), '--' * (self._max_allowed_depth - depth) + "original_value=%-7.4f" % trees.original_value[node].cpu().numpy(), "v=%-7.4f" % trees.calc_value(node).cpu().numpy(), "r=%-7.4f" % trees.reward[node].cpu().numpy(), "prior=%s" % _print_policy(trees.prior[node]), "policy=%s" % _print_policy( self._calculate_policy(trees, (torch.tensor([node[0]]), torch.tensor([node[1]])))[0]), "visit_count=%s" % trees.visit_count[node].item()) children = list(trees.children_index[node]) prior = list(trees.prior[node]) new_nodes = [((b, int(c)), a, depth + 1) for a, (c, p) in enumerate(zip(children, prior)) if c != 0 and p != 0] nodes = new_nodes + nodes def _search(self, trees, to_plays): """ Returns: tuple: - search_paths: [T, B] or [T, B, psims] int64 Tensor, where T is the max length of the search paths. For paths whose length is shorter than T, search_path is padded with the last node index of that path. - path_lengths: [B] or [B, psims] Tensor, length of each search path - last_to_plays: [B] or [B, psims], to_play for the last node of each path """ children_index = trees.children_index game_over = trees.game_over psims = self._num_parallel_sims search_paths = [] B = trees.B nodes = trees.root_indices i = () if self._parallel: # If the p-th search path of the b-th tree is at node n, it will be # extended by the child indicated by best_child_index[b, n, i[p]]. if self._search_with_exploration_policy: i = (torch.arange(psims), ) B = B.unsqueeze(-1) nodes = trees.root_indices.unsqueeze(-1).expand(-1, psims) if self._is_two_player_game: to_plays = to_plays.unsqueeze(-1).expand_as(nodes) # [B] or [B, psims] path_lengths = torch.ones_like(nodes) search_paths = [nodes] if game_over is not None: done = game_over[B, nodes] else: done = torch.zeros_like(nodes, dtype=torch.bool) depth = 0 while not torch.all(done) and depth < self._max_allowed_depth: depth += 1 best_child_index = self._get_best_child_index(trees, B, nodes, i) # [B] or [B, psims] nodes = torch.where(done, nodes, children_index[B, nodes, best_child_index]) path_lengths[~done] += 1 if self._is_two_player_game: to_plays = torch.where(done, to_plays, self._is_two_player_game - to_plays) if self._expand_all_children: # no children means node is not expanded done = (trees.children_index[B, nodes] == 0).all(dim=-1) else: # nodes == 0 means unexpanded node done = nodes == 0 if game_over is not None: done = game_over[B, nodes] | done search_paths.append(nodes) search_paths = torch.stack(search_paths) return search_paths, path_lengths, to_plays def _select_action(self, trees, steps): roots = (trees.B, trees.root_indices) if self._act_with_exploration_policy or self._learn_with_exploration_policy: policy = self._calculate_policy(trees, roots) if not self._act_with_exploration_policy or not self._learn_with_exploration_policy: children = trees.get_children(roots) visit_counts = trees.visit_count[children] visit_counts[children[1] == 0] = 0 if self._act_with_exploration_policy: probs = policy else: probs = visit_counts t = self._visit_softmax_temperature_fn(steps).unsqueeze(-1) probs = F.softmax((probs + 1e-30).log() / t, dim=1) probs = probs * (trees.prior[roots] > 0).to(torch.float32) if trees.action is not None: candidate_actions = trees.action[roots] else: candidate_actions = () if not self._learn_with_exploration_policy: parent_visit_count = (trees.visit_count[roots] - 1.0).unsqueeze(-1) policy = visit_counts / parent_visit_count if self._learn_policy_temperature != 1.0: policy = F.softmax( (policy + 1e-30).log() / self._learn_policy_temperature, dim=1) info = MCTSInfo( candidate_actions=candidate_actions, value=trees.calc_value(roots), candidate_action_policy=policy, ) if self._debug_summaries and alf.summary.should_record_summaries(): with alf.summary.scope(self._name): prior = trees.prior[roots] prior = prior + 1e-30 policy = policy + 1e-30 log_ratio = (prior / policy).log() rkld = (prior * log_ratio).sum(-1) kld = -(policy * log_ratio).sum(-1) policy_entropy = -(policy * policy.log()).sum(-1) prior_entropy = -(prior * prior.log()).sum(-1) summary_utils.add_mean_hist_summary("rkld", rkld) summary_utils.add_mean_hist_summary("kld", kld) summary_utils.add_mean_hist_summary("policy_entropy", policy_entropy) summary_utils.add_mean_hist_summary("prior_entropy", prior_entropy) summary_utils.add_mean_hist_summary("value", info.value) for i in range(policy.shape[1]): summary_utils.add_mean_hist_summary( 'policy%s' % i, policy[:, i]) return probs, info def _update_best_child(self, trees: _MCTSTrees, parents): if self._search_with_exploration_policy: self._sample_child(trees, parents) elif not self._parallel: self._ucb_child(trees, parents) else: self._ucb_child_parallel(trees, parents) def _ucb_child(self, trees: _MCTSTrees, parents): """Get child using UCB score.""" ucb_scores = self._ucb_score(trees, parents) trees.ucb_score[parents] = ucb_scores best_child_index = ucb_scores.argmax(dim=1) trees.best_child_index[parents] = best_child_index def _ucb_child_parallel(self, trees: _MCTSTrees, parents): """Get child using UCB score.""" prior = trees.prior[parents] value_score, child_visit_count = self._value_score(trees, parents) value_score[prior == 0] = -MAXIMUM_FLOAT_VALUE parent_visit_count = trees.visit_count[parents].unsqueeze(-1) ucb_scores = self._ucb_score1(parent_visit_count, child_visit_count, prior, value_score) trees.ucb_score[parents] = ucb_scores best_child_index = ucb_scores.argmax(dim=1) trees.best_child_index[parents + (0, )] = best_child_index n = torch.arange(child_visit_count.shape[0]) for i in range(1, self._num_parallel_sims): child_visit_count[n, best_child_index] += 1 parent_visit_count += 1 ucb_scores = self._ucb_score1( parent_visit_count, child_visit_count, prior, value_score) best_child_index = ucb_scores.argmax(dim=1) trees.best_child_index[parents + (i, )] = best_child_index def _ucb_score(self, trees, parents): prior = trees.prior[parents] value_score, child_visit_count = self._value_score(trees, parents) value_score[prior == 0] = -MAXIMUM_FLOAT_VALUE parent_visit_count = trees.visit_count[parents].unsqueeze(-1) return self._ucb_score1(parent_visit_count, child_visit_count, prior, value_score) def _ucb_score1(self, parent_visit_count, child_visit_count, prior, value_score): if self._ucb_parent_visit_count_minus_one: parent_visit_count = parent_visit_count - 1 pb_c = torch.log1p( (parent_visit_count + 1) / self._pb_c_base) + self._pb_c_init pb_c = pb_c * torch.sqrt(parent_visit_count.to( torch.float32)) / (child_visit_count + 1.) prior_score = pb_c * prior if self._ucb_break_tie_eps > 0: prior_score += self._ucb_break_tie_eps * torch.rand_like( prior_score) return prior_score + value_score def _value_score(self, trees: _MCTSTrees, parents): children = trees.get_children(parents) child_visit_count = trees.visit_count[children] unexpanded = children[1] == 0 child_visit_count[unexpanded] = 0 value_score = self._discount * trees.calc_value(children) if trees.reward is not None: value_score = value_score + trees.reward[children] value_score = trees.normalize_value(value_score, children[0]) if not isinstance(self._unexpanded_value_score, str): value_score[unexpanded] = self._unexpanded_value_score if self._is_two_player_game: value_score = value_score * ( (trees.to_play[parents] == 0) * 2 - 1).unsqueeze(-1) if self._unexpanded_value_score == 'max': # Set the value of unexpanded children to maximum so that the probability # of it being selected is proportional to its prior. value_score[unexpanded] = -MAXIMUM_FLOAT_VALUE max_value_score = value_score.max(dim=1, keepdim=True)[0] # max_value_score is not defined if none of the children is expanded. # set it to 0 max_value_score[max_value_score == -MAXIMUM_FLOAT_VALUE] = 0. value_score = torch.where(unexpanded, max_value_score, value_score) elif self._unexpanded_value_score in ('min', 'none'): value_score[unexpanded] = MAXIMUM_FLOAT_VALUE min_value_score = value_score.min(dim=1, keepdim=True)[0] # min_value_score is not defined if none of the children is expanded. # set it to 0 min_value_score[min_value_score == MAXIMUM_FLOAT_VALUE] = 0. value_score = torch.where(unexpanded, min_value_score, value_score) elif self._unexpanded_value_score == 'mean': value_score[unexpanded] = 0. n = (~unexpanded).sum(dim=1, keepdim=True) + 1e-30 mean_value_score = value_score.sum(dim=1, keepdim=True) / n value_score = torch.where(unexpanded, mean_value_score, value_score) elif self._unexpanded_value_score == 'mean_with_parent': pvalue_score = trees.calc_value(parents) pvalue_score = trees.normalize_value(pvalue_score, parents[0]) value_score[unexpanded] = 0. n = (~unexpanded).sum(dim=1, keepdim=True) mean_value_score = (pvalue_score.unsqueeze(1) + value_score.sum( dim=1, keepdim=True)) / (n + 1) value_score = torch.where(unexpanded, mean_value_score, value_score) return value_score, child_visit_count def _calculate_policy(self, trees: _MCTSTrees, parents): parent_visit_count = trees.visit_count[parents].unsqueeze(-1) if self._ucb_parent_visit_count_minus_one: parent_visit_count = parent_visit_count - 1 c = torch.log1p( (parent_visit_count + 1) / self._pb_c_base) + self._pb_c_init c = c / parent_visit_count.to(torch.float32).sqrt() prior = trees.prior[parents] value_score, child_visit_count = self._value_score(trees, parents) if self._unexpanded_value_score != 'none': policy, iterations = self._calc_exploration_policy( value_score, prior, c) else: # For 'none', we keep the policy for the unexpanded children same # as prior. expanded = child_visit_count != 0 expanded_prior = prior * (expanded.to(torch.float32) + 1e-30) sum_expanded = expanded_prior.sum(dim=1, keepdim=True) expanded_prior = expanded_prior / sum_expanded policy, iterations = self._calc_exploration_policy( value_score, expanded_prior, c) policy = torch.where(expanded, policy * sum_expanded, prior) if self._ucb_parent_visit_count_minus_one: # "parent_visit_count == 0" will get c=inf and policy will be equal to # prior. But according to ucb calculation (_ucb_score1()), we should # use uniform distribution in this case. uniform = (parent_visit_count == 0).unsqueeze(1) policy[uniform] = 1 / policy.shape[1] self._max_policy_iterations = max(self._max_policy_iterations, iterations) return policy def _sample_child(self, trees: _MCTSTrees, parents): """Get child by sampling from exploration policy.""" if torch.are_deterministic_algorithms_enabled and self._parallel: # when using parallel expansions, there can be duplicated nodes in # parents and will cause the assignment at the end of the function # non-deterministic. num_expansions = trees.visit_count.shape[1] nodes_i = parents[0] * num_expansions + parents[1] nodes_i = nodes_i.unique() parents = (nodes_i // num_expansions, nodes_i % num_expansions) policy = self._calculate_policy(trees, parents) # We need replacement so that the promising children are searched more # often and value estimation is unbiased. action = torch.multinomial( policy, num_samples=self._num_parallel_sims, replacement=True) if not self._parallel: action = action.squeeze(1) trees.best_child_index[parents] = action def _backup(self, trees: _MCTSTrees, search_paths, path_lengths, values, count=1): r""" Notation: vsum: trees.value_sum vc: trees.visit_count v: value T: path_length c: count t: 0 <= t <= T-1 The following shows the computation for each search path: G_t = \sum_{s=t+1}^{T-1} \gamma^{s-t-1} r_s + gamma^{T-t-1} v vsum_t += c G_t vc_t += c Args: search_paths: [T, B] or [T, B, psims] path_length: [B] or [B, psims] values: [B] or [B, psims] """ psims = self._num_parallel_sims B = trees.B.unsqueeze(0) T = search_paths.shape[0] depth = torch.arange(T).unsqueeze(-1) parallel = self._parallel if parallel: B = B.unsqueeze(-1) depth = depth.unsqueeze(-1) i = (torch.arange(psims), ) else: i = () if trees.reward is not None: reward = trees.reward[B, search_paths] reward[depth > path_lengths] = 0. # [T+1, B] or [T+1, B, psims] reward = tensor_utils.tensor_extend_zero(reward) reward[(path_lengths, B[0]) + i] = values # [T+1, 1] discounts = (self._discount**torch.arange( T + 1, dtype=torch.float32)).unsqueeze(-1) if parallel: # [T+1, 1, 1] discounts = discounts.unsqueeze(-1) # discounted_return[t] = discount^t * reward[t] discounted_return = reward * discounts # discounted_return[t] = \sum_{s=t}^T discount^s * reward[s] discounted_return = discounted_return.flip(0).cumsum(dim=0).flip(0) # discounted_return[t] = \sum_{s=t}^T discount^(s-t) * reward[s] discounted_return = discounted_return / discounts discounted_return = discounted_return[1:] else: # [T, 1] steps = torch.arange(1, T + 1, dtype=torch.float32).unsqueeze(-1) if parallel: # [T, 1, psims] steps = steps.unsqueeze(-1) # [T, B] or [T, B, psims] discounts = self._discount**(path_lengths.unsqueeze(0) - steps) discounted_return = values.unsqueeze(0) * discounts # [T, B] or [T, B, psims] valid = depth < path_lengths value_sum_delta = count * discounted_return[valid] nodes = (B.expand_as(search_paths)[valid], search_paths[valid]) if not parallel: trees.visit_count[nodes] += count trees.value_sum[nodes] += value_sum_delta else: if not torch.are_deterministic_algorithms_enabled: # This branch is not deterministic because of the possible duplications # in nodes_i. nodes_i = nodes[0] * trees.visit_count.shape[1] + nodes[1] trees.visit_count.view(-1).scatter_add_( 0, nodes_i, torch.tensor( [count], dtype=trees.visit_count.dtype).expand_as(nodes_i)) trees.value_sum.view(-1).scatter_add_(0, nodes_i, value_sum_delta) else: depth = depth.squeeze(-1) TB = trees.B.unsqueeze(0).expand(T, -1) for i in range(psims): # [T, B] vld = valid[..., i] nd = (TB[vld], search_paths[..., i][vld]) trees.visit_count[nd] += count trees.value_sum[ nd] += count * discounted_return[..., i][vld] trees.update_value_stats((B, search_paths), valid) self._update_best_child(trees, nodes) def _expand_children(self, trees: _MCTSTrees, nodes, n, to_plays): """Expand all children of nodes. Args: nodes: tuple(trees.B, node_index). Nodes to be expanded. n (int): n-th expansion, zero-based to_plays: bool[batch_size], to_plays for nodes Returns: Tensor with shape of nodes. The weighted (prior as weight) average of the bootstrapped return from the children. """ batch_size = nodes[0].shape[0] branch_factor = trees.branch_factor model_state = trees.get_model_state(nodes) def _repeat(x): return x.repeat_interleave(trees.branch_factor, dim=0) model_state = nest.map_structure(_repeat, model_state) if trees.action is None: action = torch.arange(branch_factor)[None, :].expand( batch_size, -1) else: action = trees.action[nodes] action = action.reshape(-1, *action.shape[2:]) model_output = self._model.recurrent_inference(model_state, action) def _reshape(x): return x.reshape(batch_size, branch_factor, *x.shape[1:]) # reshape to [batch_size, branch_factor, ...] model_output = model_output._replace( action_distribution=(), game_over_logit=()) model_output = nest.map_structure(_reshape, model_output) # If a node is expanded before, use its existing children children_index0 = trees.children_index[nodes] children_index = torch.where(children_index0 == 0, torch.arange(n, n + branch_factor), children_index0) trees.children_index[nodes] = children_index children = (nodes[0].unsqueeze(-1), children_index) trees.value_sum[children] = model_output.value.to( trees.value_sum.dtype) trees.original_value[children] = model_output.value trees.visit_count[children] += 1 trees.update_value_stats((children[0].t(), children[1].t())) if self._is_two_player_game: assert to_plays.shape == nodes[0].shape trees.to_play[children] = 1 - to_plays.unsqueeze(-1) if trees.game_over is not None: trees.game_over[children] = model_output.game_over def _set_tree_state(ts, s): ts[children] = s nest.map_structure(_set_tree_state, trees.model_state, model_output.state) if trees.reward is not None: trees.reward[children] = model_output.reward if trees.action is not None: trees.action[children] = model_output.actions trees.prior[children] = model_output.action_probs value = self._discount * model_output.value if trees.reward is not None: value += model_output.reward value = (value * trees.prior[nodes]).sum(dim=-1) return value def _expand_node( self, trees: _MCTSTrees, n, # n-th expansion, zero-based to_plays, model_output: ModelOutput, dirichlet_alpha=None, exploration_fraction=0.): if self._is_two_player_game: trees.to_play[:, n] = to_plays if trees.game_over is not None: trees.game_over[:, n] = model_output.game_over def _set_tree_state(ts, s): ts[:, n] = s nest.map_structure(_set_tree_state, trees.model_state, model_output.state) if trees.reward is not None: # model_output.reward.dtype may be different from initial_inference trees.reward[:, n] = model_output.reward.to(trees.reward.dtype) if trees.action is not None: trees.action[:, n] = model_output.actions prior = model_output.action_probs if exploration_fraction > 0.: batch_size = model_output.action_probs.shape[0] noise_dist = td.Dirichlet( dirichlet_alpha * torch.ones(trees.branch_factor)) noise = noise_dist.sample((batch_size, )) noise = noise * (prior != 0) noise = noise / noise.sum(dim=1, keepdim=True) prior = exploration_fraction * noise + ( 1 - exploration_fraction) * prior trees.prior[:, n] = prior trees.original_value[:, n] = model_output.value
[docs]def calculate_exploration_policy(value, prior, c, tol=1e-6): r"""Calculate exploration policy. The policy is based on `Grill et al. Monte-Carlo tree search as regularized policy optimization <https://arxiv.org/abs/2007.12509>`_ Notation: q: prior policy p: sampling probability v: value The exploration policy is found by minimizing the following: .. math:: p = \arg\min_p \left[ -E_p(v) + c KL(q\|p) \right] which leads to the following solution: .. math:: p_i = c\frac{q_i}{\alpha - v_i} where :math:`\alpha \ge \max_i(v_i)` is such that :math:`\sum_i p_i = 1` To make the solving numerically more stable and efficient, we reparameterize the problem to the following: .. math:: \begin{array}{ll} & v^* = \max_i v_i \\ & \alpha = v^* + c \beta \\ & u_i = \frac{v_i - v^*}{c} \\ & p_i = \frac{q_i}{\beta - u_i} \\ \end{array} With this reparametrization, we need to find :math:`\beta>0` s.t. .. math:: \sum_i \frac{q_i}{\beta - u_i} = 1 We use Newton's method to update :math:`\beta` iteratively: .. math:: \beta \leftarrow \beta - \frac{f(\beta)}{f'(\beta)} = \beta + \frac{\sum_i \frac{q_i}{\beta - v_i} - 1}{\sum_i \frac{q_i}{(\beta - v_i)^2}} where :math:`f(\beta) = \sum_i \frac{q_i}{\beta - u_i} - 1` and :math:`f'(\beta)` is the derivative of :math:`f(\beta)`. Since :math:`f(\beta)` is convex, starting the iteration with a :math:`\beta` s.t. :math:`f(\beta) > 0` gaurantees the convergence. In practice, we find that about 10 iterations can reach tolerance of 1e-6. Newton's method is much faster than binary search. Args: value (Tensor): [N, K] Tensor prior (Tensor): [N, K] Tensor c (Tensor): [N, 1] Tensor tol (float): Desired acurracy. The result satisfy :math:`|\sum_i p_i - 1| \le tol` Returns: tuple: - Tensor: [N, K], the exploration policy - int: the number of iterations """ batch_size = value.shape[0] assert value.shape == prior.shape assert c.shape == (batch_size, 1) value[prior == 0] = -MAXIMUM_FLOAT_VALUE v_max = value.max(dim=1, keepdim=True)[0] u = (value - v_max) / c beta = (prior + u).max(dim=1, keepdim=True)[0] converged = False i = 0 while i < 100: i += 1 beta_u = beta - u p = prior / beta_u sum_p = p.sum(dim=1, keepdim=True) diff = sum_p - 1 if (diff < tol).all(): converged = True break d = (p / beta_u).sum(dim=1, keepdim=True) beta = beta + diff / d if not converged: bad = ~(diff < tol).squeeze(1) raise RuntimeError( "calculate_exploration_policy() cannot converge. " "value=%s prior=%s c=%s" % (value[bad], prior[bad], c[bad])) return p, i
[docs]def calculate_kl_exploration_policy(value, prior, c): r"""Calculate exploration policy. This is similar to ``calculate_exploration_policy``, but using :math:`KL(p\|q)` instead of :math:`KL(q\|p)` for regularization. Notation: q: prior policy p: sampling probability v: value The exploration policy is found by minimizing the following: .. math:: p = \arg\min_p \left[ -E_p(v) + c KL(p\|q) \right] which leads to the following solution: .. math:: p_i = \frac{q_i \exp(v_i/c)}{Z} where :math:`Z` is such that :math:`\sum_i p_i = 1` Args: value (Tensor): [N, K] Tensor prior (Tensor): [N, K] Tensor c (Tensor): [N, 1] Tensor Returns: tuple: - Tensor: [N, K], the exploration policy - int: always 0 (to conform with the signature of calculate_exploration_policy) """ batch_size = value.shape[0] assert value.shape == prior.shape assert c.shape == (batch_size, 1) p = prior * (value / c).exp() p = p / p.sum(dim=-1, keepdim=True) return p, 0
[docs]@alf.configurable def create_atari_mcts(observation_spec, action_spec): """Helper function for creating MCTSAlgorithm for atari games.""" def visit_softmax_temperature(num_moves): progress = Trainer.progress() if progress < 0.5: t = 1.0 elif progress < 0.75: t = 0.5 else: t = 0.25 return t * torch.ones_like(num_moves, dtype=torch.float32) return MCTSAlgorithm( observation_spec=observation_spec, action_spec=action_spec, discount=0.997, num_simulations=50, root_dirichlet_alpha=0.25, root_exploration_fraction=0.25, pb_c_init=1.25, pb_c_base=19652, is_two_player_game=False, visit_softmax_temperature_fn=visit_softmax_temperature)
[docs]@alf.configurable class VisitSoftmaxTemperatureByMoves(object): def __init__(self, move_temperature_pairs=[(29, 1.0), (10000, 0.0001)]): """Scheduling the temperature by move. Args: move_temperature_pairs (list[tuple]): each (moves, temperature) pair indicates using this temperature until so many moves have been played in the current game. The moves should be in ascending order. Note that ``num_moves`` used to calculate the temperature starts from 0. """ self._move_temperature_pairs = list( reversed(move_temperature_pairs[:-1])) self._last_temperature = move_temperature_pairs[-1][1] def __call__(self, num_moves): t = torch.full_like( num_moves, self._last_temperature, dtype=torch.float32) for move, temp in self._move_temperature_pairs: t[num_moves <= move] = temp return t
[docs]@alf.configurable def create_board_game_mcts(observation_spec, action_spec, dirichlet_alpha: float, pb_c_init=1.25, num_simulations=800, debug_summaries=False): """Helper function for creating MCTSAlgorithm for board games.""" def visit_softmax_temperature(num_moves): t = torch.ones_like(num_moves, dtype=torch.float32) # paper pseudocode uses 0.0 # Current code does not support 0.0, so use a small value, which should # not make any difference since a difference of 1e-3 for visit probability # translates to about exp(1e-3*1e10) probability ratio. t[num_moves >= 30] = 1e-10 return t return MCTSAlgorithm( observation_spec=observation_spec, action_spec=action_spec, discount=1.0, root_dirichlet_alpha=dirichlet_alpha, root_exploration_fraction=0.25, num_simulations=num_simulations, pb_c_init=pb_c_init, pb_c_base=19652, visit_softmax_temperature_fn=visit_softmax_temperature, known_value_bounds=(-1, 1), is_two_player_game=True)
[docs]@alf.configurable def create_go_mcts(observation_spec, action_spec, debug_summaries): return create_board_game_mcts( observation_spec, action_spec, dirichlet_alpha=0.03, debug_summaries=debug_summaries)
[docs]@alf.configurable def create_chess_mcts(observation_spec, action_spec, debug_summaries): return create_board_game_mcts( observation_spec, action_spec, dirichlet_alpha=0.3, debug_summaries=debug_summaries)
[docs]@alf.configurable def create_shogi_mcts(observation_spec, action_spec, debug_summaries): return create_board_game_mcts( observation_spec, action_spec, dirichlet_alpha=0.15, debug_summaries=debug_summaries)
[docs]@alf.configurable def create_control_mcts(observation_spec, action_spec, num_simulations=50, debug_summaries=False): """Helper function for creating MCTSAlgorithm for control tasks.""" def visit_softmax_temperature(num_moves): progress = Trainer.progress() if progress < 0.5: t = 1.0 elif progress < 0.75: t = 0.5 else: t = 0.25 return t * torch.ones_like(num_moves, dtype=torch.float32) return MCTSAlgorithm( observation_spec=observation_spec, action_spec=action_spec, discount=0.997, num_simulations=num_simulations, root_dirichlet_alpha=0.25, root_exploration_fraction=0.25, pb_c_init=1.25, pb_c_base=19652, is_two_player_game=False, visit_softmax_temperature_fn=visit_softmax_temperature, debug_summaries=debug_summaries)
[docs]@alf.repr_wrapper class VisitSoftmaxTemperatureByProgress(object): def __init__(self, progress_temperature_pairs=[(0.5, 1.0), (0.75, 0.5), (1, 0.25)]): """Scheduling the temperature by training progress. Args: progress_temperature_pairs (list[tuple]): each (progress, temperature) pair indicates using this temperature until this training progress. Note that progress should be in ascending order. """ self._progress_temperature_pairs = progress_temperature_pairs def __call__(self, num_moves): progress = Trainer.progress() for p, t in self._progress_temperature_pairs: if progress < p: break return torch.full_like(num_moves, t, dtype=torch.float32)