# Source code for alf.algorithms.actor_critic_algorithm

# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Actor critic algorithm."""

import torch

import alf
from alf.algorithms.on_policy_algorithm import OnPolicyAlgorithm
from alf.networks import ActorDistributionNetwork, ValueNetwork
from alf.algorithms.actor_critic_loss import ActorCriticLoss
from alf.data_structures import TimeStep, AlgStep, namedtuple
from alf.utils import common, dist_utils, tensor_utils
from alf.tensor_specs import TensorSpec
from .config import TrainerConfig

# Recurrent state threaded through ``predict_step`` / ``rollout_step``.
#   actor: state of the actor network.
#   value: state of the value network. ``rollout_step`` fills it in;
#       ``predict_step`` omits it and ``convert_train_state_to_predict_state``
#       resets it to ``()``.
# NOTE(review): ``default_value=()`` presumably makes every unspecified field
# default to ``()`` -- confirm against ``alf.data_structures.namedtuple``.
ActorCriticState = namedtuple(
    "ActorCriticState", ["actor", "value"], default_value=())

# Per-step information returned in ``AlgStep.info``.
# ``rollout_step`` populates every field (``step_type``, ``reward`` and
# ``discount`` are copied from the input ``TimeStep``; ``action`` and
# ``log_prob`` are stored detached); the resulting info is what ``calc_loss``
# hands to the loss object. ``predict_step`` only sets ``action_distribution``.
# ``reward_weights`` is ``()`` unless the algorithm has a multi-dimensional
# reward.
# NOTE(review): ``default_value=()`` presumably makes every unspecified field
# default to ``()`` -- confirm against ``alf.data_structures.namedtuple``.
ActorCriticInfo = namedtuple(
    "ActorCriticInfo", [
        "step_type", "discount", "reward", "action", "log_prob",
        "action_distribution", "value", "reward_weights"
    ],
    default_value=())


@alf.configurable
class ActorCriticAlgorithm(OnPolicyAlgorithm):
    """Actor critic algorithm.

    Owns an actor network that produces the action distribution and an
    optional value network whose output is used by the loss. The value
    network may be omitted for evaluation/deployment-only use, in which case
    only ``predict_step()`` is usable.
    """

    def __init__(self,
                 observation_spec,
                 action_spec,
                 reward_spec=TensorSpec(()),
                 reward_weights=None,
                 actor_network_ctor=ActorDistributionNetwork,
                 value_network_ctor=ValueNetwork,
                 epsilon_greedy=None,
                 env=None,
                 config: TrainerConfig = None,
                 loss=None,
                 loss_class=ActorCriticLoss,
                 optimizer=None,
                 checkpoint=None,
                 debug_summaries=False,
                 name="ActorCriticAlgorithm"):
        """
        Args:
            observation_spec (nested TensorSpec): representing the
                observations.
            action_spec (nested BoundedTensorSpec): representing the actions.
            reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec
                representing the reward(s).
            reward_weights (None|list[float]): this is only used when the
                reward is multidimensional. In that case, the weighted sum of
                the v values is used for training the actor if
                ``reward_weights`` is not None. Otherwise, the sum of the v
                values is used.
            actor_network_ctor (Callable): Function to construct the actor
                network. ``actor_network_ctor`` needs to accept
                ``input_tensor_spec`` and ``action_spec`` as its arguments and
                return an actor network. The constructed network will be
                called with ``forward(observation, state)``.
            value_network_ctor (None|Callable): Function to construct the
                value network. ``value_network_ctor`` needs to accept
                ``input_tensor_spec`` as its argument and return a value
                network. The constructed network will be called with
                ``forward(observation, state)`` and returns a value tensor for
                each observation given the observation and network state.
                Note that if the algorithm is constructed for evaluation or
                deployment only, ``value_network_ctor`` can be set to None and
                the value network will not be constructed at all.
            epsilon_greedy (float): a floating value in [0,1], representing
                the chance of action sampling instead of taking argmax. This
                can help prevent a dead loop in some deterministic
                environments like Breakout. Only used for evaluation. If None,
                its value is taken from ``config.epsilon_greedy`` and then
                ``alf.get_config_value(TrainerConfig.epsilon_greedy)``.
            env (Environment): The environment to interact with. ``env`` is a
                batched environment, which means that it runs multiple
                simulations simultaneously. ``env`` only needs to be provided
                to the root Algorithm.
            config (TrainerConfig): config for training. ``config`` only needs
                to be provided to the algorithm which performs
                ``train_iter()`` by itself.
            loss (None|ActorCriticLoss): an object for calculating loss. If
                None, a default loss of class ``loss_class`` will be used.
            loss_class (type): the class of the loss. The signature of its
                constructor: ``loss_class(debug_summaries)``
            optimizer (torch.optim.Optimizer): The optimizer for training.
            checkpoint (None|str): a string in the format of "prefix@path",
                where the "prefix" is the multi-step path to the contents in
                the checkpoint to be loaded. "path" is the full path to the
                checkpoint file saved by ALF. Refer to ``Algorithm`` for more
                details.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): Name of this algorithm.
        """
        if epsilon_greedy is None:
            epsilon_greedy = alf.utils.common.get_epsilon_greedy(config)
        self._epsilon_greedy = epsilon_greedy

        actor_network = actor_network_ctor(
            input_tensor_spec=observation_spec, action_spec=action_spec)

        value_network = None
        if value_network_ctor is not None:
            value_network = value_network_ctor(
                input_tensor_spec=observation_spec)
            if reward_spec.numel > 1:
                # One value head per reward dimension: value->[B,n]
                value_network = value_network.make_parallel(
                    reward_spec.numel)

        super(ActorCriticAlgorithm, self).__init__(
            observation_spec=observation_spec,
            action_spec=action_spec,
            reward_spec=reward_spec,
            reward_weights=reward_weights,
            predict_state_spec=ActorCriticState(
                actor=actor_network.state_spec),
            train_state_spec=ActorCriticState(
                actor=actor_network.state_spec,
                value=value_network.state_spec if value_network else ()),
            env=env,
            config=config,
            optimizer=optimizer,
            checkpoint=checkpoint,
            debug_summaries=debug_summaries,
            name=name)

        self._actor_network = actor_network
        self._value_network = value_network
        if loss is None:
            loss = loss_class(debug_summaries=debug_summaries)
        self._loss = loss

        # The following checkpoint loading hook handles the case when the
        # value network is not constructed. In this case the value network
        # parameters present in the checkpoint should be ignored.
        def _deployment_hook(state_dict, prefix: str, unused_local_metadata,
                             unused_strict, unused_missing_keys,
                             unused_unexpected_keys, unused_error_msgs):
            # Evaluated at load time, so it reflects the current attribute.
            if self._value_network is not None:
                return
            value_prefix = prefix + "_value_network"
            # Collect first: the dict must not be mutated while iterating it.
            stale_keys = [
                key for key in state_dict if key.startswith(value_prefix)
            ]
            for key in stale_keys:
                state_dict.pop(key)

        self._register_load_state_dict_pre_hook(_deployment_hook)

    def convert_train_state_to_predict_state(self, state):
        """Drop the value-network state, which prediction does not carry.

        Args:
            state (ActorCriticState): a train state.
        Returns:
            ActorCriticState: ``state`` with ``value`` replaced by ``()``.
        """
        return state._replace(value=())

    def predict_step(self, inputs: TimeStep, state: ActorCriticState):
        """Predict for one step.

        Only the actor network is evaluated; the action is drawn with
        ``dist_utils.epsilon_greedy_sample`` using the configured
        ``epsilon_greedy``.

        Args:
            inputs (TimeStep): the current time step.
            state (ActorCriticState): state whose ``actor`` field is fed to
                the actor network.
        Returns:
            AlgStep: ``output`` is the action, ``state`` holds the new actor
            state and ``info`` holds the action distribution.
        """
        action_dist, actor_state = self._actor_network(
            inputs.observation, state=state.actor)
        action = dist_utils.epsilon_greedy_sample(action_dist,
                                                  self._epsilon_greedy)
        return AlgStep(
            output=action,
            state=ActorCriticState(actor=actor_state),
            info=ActorCriticInfo(action_distribution=action_dist))

    def rollout_step(self, inputs: TimeStep, state: ActorCriticState):
        """Rollout for one step.

        Evaluates both the value network and the actor network on the
        observation and samples an action together with its log probability.

        Args:
            inputs (TimeStep): the current time step.
            state (ActorCriticState): state holding the actor and value
                network states.
        Returns:
            AlgStep: ``output`` is the sampled action, ``state`` holds the new
            actor and value states and ``info`` is a fully populated
            ``ActorCriticInfo``.
        Raises:
            RuntimeError: if the algorithm was constructed without a value
                network (``value_network_ctor=None``).
        """
        if self._value_network is None:
            # Fail with an explicit message instead of the opaque
            # "'NoneType' object is not callable".
            raise RuntimeError(
                "rollout_step() requires a value network, but this algorithm "
                "was constructed with value_network_ctor=None "
                "(evaluation/deployment only).")

        value, value_state = self._value_network(
            inputs.observation, state=state.value)
        action_distribution, actor_state = self._actor_network(
            inputs.observation, state=state.actor)
        action, log_prob = dist_utils.sample_action_distribution(
            action_distribution, return_log_prob=True)

        if self.has_multidim_reward():
            # Repeat the reward weights along a new leading dim of size
            # value.shape[0].
            reward_weights = tensor_utils.tensor_extend_new_dim(
                self.reward_weights, dim=0, n=value.shape[0])
        else:
            reward_weights = ()

        return AlgStep(
            output=action,
            state=ActorCriticState(actor=actor_state, value=value_state),
            info=ActorCriticInfo(
                action=common.detach(action),
                log_prob=common.detach(log_prob),
                value=value,
                step_type=inputs.step_type,
                reward=inputs.reward,
                discount=inputs.discount,
                action_distribution=action_distribution,
                reward_weights=reward_weights))

    def calc_loss(self, info: ActorCriticInfo):
        """Calculate loss.

        Args:
            info (ActorCriticInfo): info passed unchanged to the loss object.
        Returns:
            Whatever the configured loss object returns for ``info``.
        """
        return self._loss(info)