Source code for alf.algorithms.causal_bc_algorithm

# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Causal Behavior Cloning Algorithm."""

import torch
import alf
from alf.algorithms.config import TrainerConfig
from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
from alf.data_structures import TimeStep, LossInfo, namedtuple
from alf.data_structures import AlgStep
from alf.networks import ActorNetwork, EncodingNetwork
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import dist_utils, tensor_utils

BcState = namedtuple("BcState", ["actor"], default_value=())

BcInfo = namedtuple(
    "BcInfo", ["actor", "discriminator", "target"], default_value=())

BcLossInfo = namedtuple(
    "LossInfo", ["actor", "discriminator"], default_value=())


[docs]@alf.configurable class CausalBcAlgorithm(OffPolicyAlgorithm): r"""Causal behavior cloning algorithm. This is the implementation of ResiduIL algorithm proposed in the following paper: :: Swamy et al. Causal Imitation Learning under Temporally Correlated Noise, ICML 2022 """ def __init__(self, observation_spec, action_spec: BoundedTensorSpec, reward_spec=TensorSpec(()), actor_network_cls=ActorNetwork, discriminator_network_cls=EncodingNetwork, actor_optimizer=None, discriminator_optimizer=None, f_norm_penalty_weight=1e-3, bc_regulatization_weight=5e-2, env=None, config: TrainerConfig = None, checkpoint=None, debug_summaries=False, epsilon_greedy=None, name="CausalBcAlgorithm"): """ Args: observation_spec (nested TensorSpec): representing the observations. action_spec (nested BoundedTensorSpec): representing the actions; can be a mixture of discrete and continuous actions. The number of continuous actions can be arbitrary while only one discrete action is allowed currently. If it's a mixture, then it must be a tuple/list ``(discrete_action_spec, continuous_action_spec)``. reward_spec (Callable): a rank-1 or rank-0 tensor spec representing the reward(s). For interface compatiblity purpose. Not actually used in CausalBcAlgorithm. actor_network_cls (Callable): is used to construct the actor network. The constructed actor network is a determinstic network and will be used to generate continuous actions. discriminator_network_cls (Callable): is used to construct the discriminator network. The discrimonator is trained in a way that is adversarial to the training of the policy, to help with the learning of a robust policy. It takes the observation from the previous time step to generate the lagrange multiplier for the current step. actor_optimizer (torch.optim.optimizer): The optimizer for actor. discriminator_optimizer (torch.optim.optimizer): the optimizer for discriminator. f_norm_penalty_weight (float): penalty weight for the output of the discriminator. bc_regulatization_weight (float): weight for the squared prediction error based regularization term. env (Environment): The environment to interact with. ``env`` is a batched environment, which means that it runs multiple simulations simultateously. ``env` only needs to be provided to the root algorithm. config (TrainerConfig): config for training. It only needs to be provided to the algorithm which performs ``train_iter()`` by itself. checkpoint (None|str): a string in the format of "prefix@path", where the "prefix" is the multi-step path to the contents in the checkpoint to be loaded. "path" is the full path to the checkpoint file saved by ALF. Refer to ``Algorithm`` for more details. debug_summaries (bool): True if debug summaries should be created. epsilon_greedy (float): a floating value in [0,1], representing the chance of action sampling instead of taking argmax. This can help prevent a dead loop in some deterministic environment like Breakout. Only used for evaluation. If None, its value is taken from ``config.epsilon_greedy`` and then ``alf.get_config_value(TrainerConfig.epsilon_greedy)``. name (str): The name of this algorithm. """ if epsilon_greedy is None: epsilon_greedy = alf.utils.common.get_epsilon_greedy(config) self._epsilon_greedy = epsilon_greedy actor_network = actor_network_cls( input_tensor_spec=observation_spec, action_spec=action_spec) discriminator_network = discriminator_network_cls( input_tensor_spec=observation_spec) action_state_spec = actor_network.state_spec super().__init__( observation_spec=observation_spec, action_spec=action_spec, reward_spec=reward_spec, train_state_spec=BcState(actor=action_state_spec), predict_state_spec=BcState(actor=action_state_spec), reward_weights=None, env=env, config=config, checkpoint=checkpoint, debug_summaries=debug_summaries, name=name) self._actor_network = actor_network self._discriminator_network = discriminator_network if actor_optimizer is not None and actor_network is not None: self.add_optimizer(actor_optimizer, [actor_network]) self._actor_optimizer = actor_optimizer if discriminator_optimizer is not None and discriminator_network is not None: self.add_optimizer(discriminator_optimizer, [discriminator_network]) self._discriminator_optimizer = discriminator_optimizer self._bc_regulatization_weight = bc_regulatization_weight self._f_norm_penalty_weight = f_norm_penalty_weight def _predict_action(self, observation, state): action_dist, actor_network_state = self._actor_network( observation, state=state) return action_dist, actor_network_state
[docs] def predict_step(self, inputs: TimeStep, state: BcState): action_dist, new_state = self._predict_action( inputs.observation, state=state.actor) action = dist_utils.epsilon_greedy_sample(action_dist, self._epsilon_greedy) return AlgStep(output=action, state=BcState(actor=new_state))
[docs] def residuIL_loss(self, targets, predictions, pred_residuals): # train policy (detach discriminator) target_prediction_differences = targets - predictions # bc_regularization is optionally used according to Appendix of the # paper (Tabel 4 and 5) policy_loss = ( 2 * (target_prediction_differences) * pred_residuals.detach() ).mean(-1) + self._bc_regulatization_weight * ( torch.square(target_prediction_differences)).mean(-1) # train discriminator (detach policy) discriminator_loss = -( 2 * (target_prediction_differences).detach() * pred_residuals - pred_residuals * pred_residuals).mean(-1) # f_norm_penalty is used according to Appendix of the paper (Tabel 4 and 5) discriminator_loss = ( discriminator_loss + self._f_norm_penalty_weight * torch.linalg.norm(pred_residuals)) return policy_loss, discriminator_loss
[docs] def train_step_offline(self, inputs: TimeStep, state, rollout_info, pre_train=False): action_dist, new_state = self._predict_action( inputs.observation, state=state.actor) predictions = dist_utils.get_rmode(action_dist) pred_residuals, _ = self._discriminator_network(inputs.observation) info = BcInfo( actor=predictions, discriminator=pred_residuals, target=rollout_info.action) return AlgStep( rollout_info.action, state=BcState(actor=new_state), info=info)
[docs] def calc_loss_offline(self, info, pre_train=False): #[T, B, action_dim] predictions = info.actor pred_residuals = info.discriminator targets = info.target actor_loss, discriminator_loss = self.residuIL_loss( targets[1:], predictions[1:], pred_residuals[:-1]) if self._debug_summaries and alf.summary.should_record_summaries(): with alf.summary.scope(self._name): alf.summary.scalar("actor_loss", actor_loss.mean()) alf.summary.scalar("discriminator_loss", discriminator_loss.mean()) loss = actor_loss + discriminator_loss loss = tensor_utils.tensor_extend_zero(loss) return LossInfo( loss=loss, extra=BcLossInfo( actor=tensor_utils.tensor_extend_zero(actor_loss), discriminator=tensor_utils.tensor_extend_zero( discriminator_loss)))