# Source code for alf.algorithms.sarsa_algorithm

# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SARSA Algorithm."""

from absl import logging
import copy
import numpy as np
import torch
import torch.nn as nn

import alf
from alf.algorithms.sac_algorithm import _set_target_entropy
from alf.algorithms.one_step_loss import OneStepTDLoss
from alf.algorithms.rl_algorithm import RLAlgorithm
from alf.data_structures import AlgStep, LossInfo, namedtuple, StepType, TimeStep
from alf.utils import common, dist_utils, losses, math_ops, tensor_utils
import alf.nest.utils as nest_utils
from alf.tensor_specs import TensorSpec

# State carried from one step to the next. ``prev_observation`` and
# ``prev_step_type`` remember the previous time step so that the critics can
# evaluate the previous (observation, action) pair once ``prev_action`` is
# available in the current step. Unset fields default to ``()``.
SarsaState = namedtuple(
    'SarsaState',
    [
        'prev_observation',
        'prev_step_type',
        'actor',
        'critics',
        'target_critics',
        'noise',
    ],
    default_value=(),
)

# Per-step information produced by the train/rollout steps and consumed by
# ``SarsaAlgorithm.calc_loss``. Unset fields default to ``()``.
SarsaInfo = namedtuple(
    'SarsaInfo',
    [
        'reward',
        'step_type',
        'discount',
        'action_distribution',
        'actor_loss',
        'critics',
        'target_critics',
        'neg_entropy',
    ],
    default_value=(),
)

# Breakdown of the individual loss terms reported in ``LossInfo.extra``.
SarsaLossInfo = namedtuple(
    'SarsaLossInfo',
    ['actor', 'critic', 'alpha', 'neg_entropy'],
)

# Short alias used throughout this module to map a function over nests.
nest_map = alf.nest.map_structure


@alf.configurable
class SarsaAlgorithm(RLAlgorithm):
    r"""SARSA Algorithm.

    SARSA updates the Q function using the following loss:

    .. math::

        ||Q(s_t,a_t) - \text{nograd}(r_t + \gamma * Q(s_{t+1}, a_{t+1}))||^2

    See https://en.wikipedia.org/wiki/State-action-reward-state-action

    Currently, this is only implemented for continuous action problems.
    The policy is derived in a DDPG/SAC manner by maximizing
    :math:`Q(a(s_t), s_t)`, where :math:`a(s_t)` is the action.
    """

    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_network_ctor,
                 critic_network_ctor,
                 reward_spec=TensorSpec(()),
                 num_critic_replicas=2,
                 env=None,
                 config=None,
                 critic_loss_cls=OneStepTDLoss,
                 target_entropy=None,
                 epsilon_greedy=None,
                 use_entropy_reward=False,
                 calculate_priority=False,
                 initial_alpha=1.0,
                 ou_stddev=0.2,
                 ou_damping=0.15,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 alpha_optimizer=None,
                 target_update_tau=0.05,
                 target_update_period=10,
                 use_smoothed_actor=False,
                 dqda_clipping=0.,
                 on_policy=False,
                 checkpoint=None,
                 debug_summaries=False,
                 name="SarsaAlgorithm"):
        """
        Args:
            observation_spec (nested TensorSpec): spec for observation.
            action_spec (nested BoundedTensorSpec): representing the actions.
                Every flattened spec must be continuous.
            actor_network_ctor (Callable): Function to construct the actor
                network. ``actor_network_ctor`` needs to accept
                ``input_tensor_spec`` and ``action_spec`` as its arguments and
                return an actor network. The constructed network will be
                called with ``forward(observation, state)``.
            critic_network_ctor (Callable): Function to construct the critic
                network. ``critic_network_ctor`` needs to accept
                ``input_tensor_spec`` which is a tuple of
                ``(observation_spec, action_spec)``. The constructed network
                will be called with ``forward((observation, action), state)``.
            reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec
                representing the reward(s).
            num_critic_replicas (int): number of critics to be used. Default
                is 2.
            env (Environment): The environment to interact with. ``env`` is a
                batched environment, which means that it runs multiple
                simulations simultaneously. Running multiple environments in
                parallel is crucial to on-policy algorithms as it increases
                the diversity of data and decreases temporal correlation.
                ``env`` only needs to be provided to the root ``Algorithm``.
            config (TrainerConfig): config for training. ``config`` only needs
                to be provided to the algorithm which performs
                ``train_iter()`` by itself.
            critic_loss_cls (type): class of the TD loss used for the critics.
                One instance is created per critic replica (only the first
                one gets ``debug_summaries``); its ``gamma`` attribute is also
                used to discount the entropy reward.
            target_entropy (float|Callable|None): If a floating value, it's
                the target average policy entropy, for updating ``alpha``. If
                a callable function, then it will be called on the action spec
                to calculate a target entropy. If ``None``, a default entropy
                will be calculated. Only used when the actor network outputs
                a distribution.
            epsilon_greedy (float): a floating value in [0,1], representing
                the chance of action sampling instead of taking argmax. This
                can help prevent a dead loop in some deterministic environment
                like Breakout. Only used for evaluation (``predict_step``).
                If None, its value is taken from ``config.epsilon_greedy`` and
                then ``alf.get_config_value(TrainerConfig.epsilon_greedy)``.
            use_entropy_reward (bool): If ``True``, will use alpha*entropy as
                additional reward. Only effective when ``initial_alpha`` is
                not ``None`` and the actor network outputs a distribution.
            calculate_priority (bool): whether to calculate priority. This is
                only useful if priority replay is enabled.
            initial_alpha (float|None): If provided (and the actor network
                outputs a distribution), will add ``-alpha*entropy`` to the
                loss to encourage diverse action. Ignored (with a log message)
                for actor networks that do not output a distribution.
            ou_stddev (float): Only used when the actor network does not
                output a distribution. Standard deviation for the
                Ornstein-Uhlenbeck (OU) noise added to the actions.
            ou_damping (float): Only used when the actor network does not
                output a distribution. Damping factor for the OU noise added
                to the actions.
            actor_optimizer (torch.optim.Optimizer): The optimizer for actor.
            critic_optimizer (torch.optim.Optimizer): The optimizer for critic
                networks.
            alpha_optimizer (torch.optim.Optimizer): The optimizer for alpha.
                Only used if ``initial_alpha`` is not ``None``. If ``None``,
                alpha stays fixed at ``initial_alpha``.
            target_update_tau (float): Factor for soft update of the target
                networks.
            target_update_period (int): Period for soft update of the target
                networks.
            use_smoothed_actor (bool): use a smoothed version of actor for
                predict and rollout. This option can only be used if
                ``on_policy`` is ``False``.
            dqda_clipping (float): when computing the actor loss, clips the
                gradient ``dqda`` element-wise between
                ``[-dqda_clipping, dqda_clipping]``. Does not perform clipping
                if ``dqda_clipping == 0``.
            on_policy (bool): whether it is used as an on-policy algorithm.
            checkpoint (None|str): a string in the format of "prefix@path",
                where the "prefix" is the multi-step path to the contents in
                the checkpoint to be loaded. "path" is the full path to the
                checkpoint file saved by ALF. Refer to ``Algorithm`` for more
                details.
            debug_summaries (bool): ``True`` if debug summaries should be
                created.
            name (str): The name of this algorithm.
        """
        self._calculate_priority = calculate_priority
        if epsilon_greedy is None:
            epsilon_greedy = alf.utils.common.get_epsilon_greedy(config)
        self._epsilon_greedy = epsilon_greedy

        critic_network = critic_network_ctor(
            input_tensor_spec=(observation_spec, action_spec))
        actor_network = actor_network_ctor(
            input_tensor_spec=observation_spec, action_spec=action_spec)

        flat_action_spec = alf.nest.flatten(action_spec)
        is_continuous = all(spec.is_continuous for spec in flat_action_spec)
        assert is_continuous, (
            "SarsaAlgorithm only supports continuous action."
            " action_spec: %s" % action_spec)

        critic_networks = critic_network.make_parallel(num_critic_replicas)

        if not actor_network.is_distribution_output:
            # Deterministic actor: exploration comes from an OU noise process
            # whose state is carried in ``SarsaState.noise``.
            noise_process = alf.networks.OUProcess(
                state_spec=action_spec, damping=ou_damping, stddev=ou_stddev)
            noise_state = noise_process.state_spec
        else:
            noise_process = None
            noise_state = ()

        super().__init__(
            observation_spec,
            action_spec,
            reward_spec=reward_spec,
            env=env,
            is_on_policy=on_policy,
            config=config,
            predict_state_spec=SarsaState(
                noise=noise_state,
                prev_observation=observation_spec,
                prev_step_type=alf.TensorSpec((), torch.int32),
                actor=actor_network.state_spec),
            train_state_spec=SarsaState(
                noise=noise_state,
                prev_observation=observation_spec,
                prev_step_type=alf.TensorSpec((), torch.int32),
                actor=actor_network.state_spec,
                critics=critic_networks.state_spec,
                target_critics=critic_networks.state_spec,
            ),
            checkpoint=checkpoint,
            debug_summaries=debug_summaries,
            name=name)
        self._actor_network = actor_network
        self._num_critic_replicas = num_critic_replicas
        self._critic_networks = critic_networks
        self._target_critic_networks = critic_networks.copy(
            name='target_critic_networks')
        self.add_optimizer(actor_optimizer, [actor_network])
        self.add_optimizer(critic_optimizer, [critic_networks])

        self._log_alpha = None
        self._use_entropy_reward = False
        if initial_alpha is not None:
            if actor_network.is_distribution_output:
                self._target_entropy = _set_target_entropy(
                    self.name, target_entropy, flat_action_spec)
                log_alpha = torch.tensor(
                    np.log(initial_alpha), dtype=torch.float32)
                if alpha_optimizer is None:
                    # Fixed alpha: a plain tensor that no optimizer updates.
                    self._log_alpha = log_alpha
                else:
                    self._log_alpha = nn.Parameter(log_alpha)
                    self.add_optimizer(alpha_optimizer, [self._log_alpha])
                self._use_entropy_reward = use_entropy_reward
            else:
                logging.info(
                    "initial_alpha and alpha_optimizer is ignored. "
                    "The `actor_network` needs to output Distribution in "
                    "order to use entropy as regularization or reward")

        # Plain lists (``TargetUpdater`` accepts ``list[Network]``) so that
        # the smoothed-actor pair can be appended below; a Network object has
        # no ``append`` method.
        models = [critic_networks]
        target_models = [self._target_critic_networks]

        self._rollout_actor_network = self._actor_network
        if use_smoothed_actor:
            assert not on_policy, ("use_smoothed_actor can only be used in "
                                   "off-policy training")
            self._rollout_actor_network = actor_network.copy(
                name='rollout_actor_network')
            models.append(self._actor_network)
            target_models.append(self._rollout_actor_network)

        self._update_target = common.TargetUpdater(
            models=models,
            target_models=target_models,
            tau=target_update_tau,
            period=target_update_period)

        self._dqda_clipping = dqda_clipping
        self._noise_process = noise_process

        # One TD loss per critic replica; only the first emits summaries.
        self._critic_losses = [
            critic_loss_cls(debug_summaries=debug_summaries and i == 0)
            for i in range(num_critic_replicas)
        ]

        self._is_rnn = len(alf.nest.flatten(critic_network.state_spec)) > 0

    def _trainable_attributes_to_ignore(self):
        """Names of attributes whose parameters are not trained directly.

        The target critics (and the smoothed rollout actor, when enabled) are
        updated through ``self._update_target`` rather than by an optimizer.
        """
        return ["_target_critic_networks", "_rollout_actor_network"]

    def _get_action(self,
                    actor_network,
                    time_step: TimeStep,
                    state: SarsaState,
                    epsilon_greedy=1.0):
        """Compute the action for ``time_step`` using ``actor_network``.

        For a distribution-output actor, the action is a reparameterized
        sample when ``epsilon_greedy == 1.0`` and an epsilon-greedy sample
        otherwise. For a deterministic actor, OU noise is added to the
        actor output. If ``epsilon_greedy < 1``, the noise is applied only to
        a random subset of the batch (each entry with probability
        ``epsilon_greedy``).

        Returns:
            tuple: ``(action_distribution, action, actor_state, noise_state)``.
            For a deterministic actor, ``action_distribution`` is the raw
            (noise-free) actor output.
        """
        action_distribution, actor_state = actor_network(
            time_step.observation, state=state.actor)
        if actor_network.is_distribution_output:
            if epsilon_greedy == 1.0:
                action = dist_utils.rsample_action_distribution(
                    action_distribution)
            else:
                action = dist_utils.epsilon_greedy_sample(
                    action_distribution, epsilon_greedy)
            noise_state = ()
        else:

            def _sample(a, noise):
                if epsilon_greedy >= 1.0:
                    return a + noise
                else:
                    # Per-sample (leading batch dim) choice of noisy action.
                    choose_random_action = (torch.rand(a.shape[:1]) <
                                            epsilon_greedy)
                    return torch.where(
                        common.expand_dims_as(choose_random_action, a),
                        a + noise, a)

            noise, noise_state = self._noise_process(state.noise)
            action = nest_map(_sample, action_distribution, noise)
        return action_distribution, action, actor_state, noise_state

    def predict_step(self, inputs: TimeStep, state: SarsaState):
        """Select an action for evaluation using the rollout actor.

        Uses ``self._epsilon_greedy`` for exploration. Critic states are not
        part of the predict state.
        """
        action_distribution, action, actor_state, noise_state = self._get_action(
            self._rollout_actor_network, inputs, state, self._epsilon_greedy)
        return AlgStep(
            output=action,
            state=SarsaState(
                noise=noise_state,
                actor=actor_state,
                prev_observation=inputs.observation,
                prev_step_type=inputs.step_type),
            info=SarsaInfo(action_distribution=action_distribution))

    def convert_train_state_to_predict_state(self, state: SarsaState):
        """Drop the critic states, which the predict state does not have."""
        return state._replace(critics=(), target_critics=())

    def rollout_step(self, inputs: TimeStep, state: SarsaState):
        """Collect one step of experience.

        In on-policy mode this is a full training step. Otherwise the action
        comes from the rollout actor; recurrent critic states (if any) are
        still advanced so that they stay in sync with training.
        """
        if self.on_policy:
            return self._train_step(inputs, state)

        if not self._is_rnn:
            critic_states = state.critics
        else:
            # Advance the critic RNN state on the previous (obs, action).
            _, critic_states = self._critic_networks(
                (state.prev_observation, inputs.prev_action), state.critics)

            not_first_step = inputs.step_type != StepType.FIRST
            critic_states = common.reset_state_if_necessary(
                state.critics, critic_states, not_first_step)

        action_distribution, action, actor_state, noise_state = self._get_action(
            self._rollout_actor_network, inputs, state)

        if not self._is_rnn:
            target_critic_states = state.target_critics
        else:
            _, target_critic_states = self._target_critic_networks(
                (inputs.observation, action), state.target_critics)

        info = SarsaInfo(action_distribution=action_distribution)

        rl_state = SarsaState(
            noise=noise_state,
            prev_observation=inputs.observation,
            prev_step_type=inputs.step_type,
            actor=actor_state,
            critics=critic_states,
            target_critics=target_critic_states)

        return AlgStep(action, rl_state, info)

    def train_step(self, inputs: TimeStep, state: SarsaState, rollout_info):
        """Training step; ``rollout_info`` is not used."""
        return self._train_step(inputs, state)

    def _train_step(self, time_step: TimeStep, state: SarsaState):
        """Shared train logic for ``train_step`` and on-policy rollout.

        Evaluates the critics on the previous ``(observation, action)`` pair,
        samples the current action, builds the DDPG-style actor loss from
        ``dQ/da``, and evaluates the target critics on the current
        ``(observation, action)`` pair.
        """
        not_first_step = time_step.step_type != StepType.FIRST
        # Q(s_{t-1}, a_{t-1}): critics for the previous step.
        prev_critics, critic_states = self._critic_networks(
            (state.prev_observation, time_step.prev_action), state.critics)

        critic_states = common.reset_state_if_necessary(
            state.critics, critic_states, not_first_step)

        action_distribution, action, actor_state, noise_state = self._get_action(
            self._actor_network, time_step, state)

        critics, _ = self._critic_networks((time_step.observation, action),
                                           critic_states)
        # Pessimistic estimate: minimum over the critic replicas.
        critic = critics.min(dim=1)[0]
        dqda = nest_utils.grad(action, critic.sum())

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = dqda.clamp(-self._dqda_clipping, self._dqda_clipping)
            # With a detached target ``dqda + action``, the gradient of this
            # squared loss w.r.t. ``action`` is ``-dqda`` (assuming an
            # element-wise squared error), i.e. gradient ascent on Q.
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            loss = loss.sum(list(range(1, loss.ndim)))
            return loss

        actor_loss = nest_map(actor_loss_fn, dqda, action)
        actor_loss = math_ops.add_n(alf.nest.flatten(actor_loss))

        neg_entropy = ()
        if self._log_alpha is not None:
            neg_entropy = dist_utils.compute_log_probability(
                action_distribution, action)

        target_critics, target_critic_states = self._target_critic_networks(
            (time_step.observation, action), state.target_critics)

        info = SarsaInfo(
            reward=time_step.reward,
            step_type=time_step.step_type,
            discount=time_step.discount,
            action_distribution=action_distribution,
            actor_loss=actor_loss,
            critics=prev_critics,
            neg_entropy=neg_entropy,
            target_critics=target_critics.min(dim=1)[0])

        rl_state = SarsaState(
            noise=noise_state,
            prev_observation=time_step.observation,
            prev_step_type=time_step.step_type,
            actor=actor_state,
            critics=critic_states,
            target_critics=target_critic_states)

        return AlgStep(action, rl_state, info)

    def calc_loss(self, info: SarsaInfo):
        """Compute actor, alpha and critic losses from a trajectory of infos.

        ``info`` is not modified.

        Returns:
            LossInfo: ``loss`` holds the actor (+ alpha/entropy) terms;
            ``scalar_loss`` holds the mean critic loss, which is masked on
            FIRST steps here instead of by the caller's ``~is_last`` mask.
        """
        loss = info.actor_loss
        if self._log_alpha is not None:
            alpha = self._log_alpha.exp().detach()
            alpha_loss = self._log_alpha * (
                -info.neg_entropy - self._target_entropy).detach()
            loss = loss + alpha * info.neg_entropy + alpha_loss
        else:
            alpha_loss = ()

        # For sarsa, info.critics is actually the critics for the previous
        # step. And info.target_critics is the critics for the current step.
        # So we need to rearrange ``experience`` to match the requirement for
        # ``OneStepTDLoss``.
        step_type0 = info.step_type[0]
        step_type0 = torch.where(step_type0 == StepType.LAST,
                                 torch.tensor(StepType.MID), step_type0)
        step_type0 = torch.where(step_type0 == StepType.FIRST,
                                 torch.tensor(StepType.LAST), step_type0)

        gamma = self._critic_losses[0].gamma
        reward = info.reward
        if self._use_entropy_reward:
            # Out-of-place on purpose: ``reward`` aliases ``info.reward``,
            # which the caller may still use, so it must not be mutated.
            reward = reward - gamma * (
                self._log_alpha.exp() * info.neg_entropy).detach()
        shifted_experience = info._replace(
            discount=tensor_utils.tensor_prepend_zero(info.discount),
            reward=tensor_utils.tensor_prepend_zero(reward),
            step_type=tensor_utils.tensor_prepend(info.step_type, step_type0))
        critic_losses = []
        for i in range(self._num_critic_replicas):
            critic = tensor_utils.tensor_extend_zero(info.critics[..., i])
            target_critic = tensor_utils.tensor_prepend_zero(
                info.target_critics)
            loss_info = self._critic_losses[i](shifted_experience, critic,
                                               target_critic)
            critic_losses.append(nest_map(lambda l: l[:-1], loss_info.loss))

        critic_loss = math_ops.add_n(critic_losses)

        not_first_step = (info.step_type != StepType.FIRST).to(torch.float32)
        critic_loss = critic_loss * not_first_step
        if self._calculate_priority:
            valid_n = torch.clamp(not_first_step.sum(dim=0), min=1.0)
            priority = (critic_loss.sum(dim=0) / valid_n).sqrt()
        else:
            priority = ()

        # put critic_loss to scalar_loss because loss will be masked by
        # ~is_last at train_complete(). The critic_loss here should be
        # masked by ~is_first instead, which is done above
        scalar_loss = critic_loss.mean()

        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):
                if self._log_alpha is not None:
                    alf.summary.scalar("alpha", alpha)

        return LossInfo(
            loss=loss,
            scalar_loss=scalar_loss,
            priority=priority,
            extra=SarsaLossInfo(
                actor=info.actor_loss,
                critic=critic_loss,
                alpha=alpha_loss,
                neg_entropy=info.neg_entropy))

    def after_update(self, root_inputs, info: SarsaInfo):
        """Update the target networks (and smoothed actor, if enabled)."""
        self._update_target()