Source code for alf.algorithms.agent_helpers

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some helper functions for constructing an Agent instance."""

import os

import torch

import alf
from alf.data_structures import LossInfo
from alf.utils.math_ops import add_ignore_empty


def _make_alg_experience(experience, name):
    """Given an experience, extracts the ``rollout_info`` field for an
    algorithm.
    """
    if experience.rollout_info == ():
        rollout_info = ()
    else:
        rollout_info = getattr(experience.rollout_info, name)
    return experience._replace(
        rollout_info=rollout_info,
        rollout_info_field=experience.rollout_info_field + '.' + name)


[docs]class AgentHelper(object): def __init__(self, state_ctor): """Create three state specs given the state creator.""" self._train_state_spec = state_ctor() self._rollout_state_spec = state_ctor() self._predict_state_spec = state_ctor() self._alg_to_field_mapping = dict()
[docs] def register_algorithm(self, alg, alg_field): """Collect state specs from algorithms. For code conciseness, we collect all three state specs even though some of them will not be used during ``unroll`` or ``train``. This function also registers ``alg`` with ``alg_field``. Args: alg (Algorithm): a child algorithm in the agent. alg_field (str): the corresponding algorithm field in an ``AgentState`` or ``AgentInfo``. """ self._alg_to_field_mapping[alg] = alg_field if alg_field in self._train_state_spec._fields: self._train_state_spec = self._train_state_spec._replace( **{alg_field: alg.train_state_spec}) self._rollout_state_spec = self._rollout_state_spec._replace( **{alg_field: alg.rollout_state_spec}) self._predict_state_spec = self._predict_state_spec._replace( **{alg_field: alg.predict_state_spec})
def _get_algorithm_field(self, alg): assert alg in self._alg_to_field_mapping, \ "Should first register this algorithm %s!" % alg.name return self._alg_to_field_mapping[alg]
[docs] def state_specs(self): """Return the state specs collected from child algorithms.""" return dict( train_state_spec=self._train_state_spec, rollout_state_spec=self._rollout_state_spec, predict_state_spec=self._predict_state_spec)
[docs] @staticmethod def accumulate_algorithm_rewards(rewards, weights, names, summary_prefix, summarize_fn): """Sum a list of rewards by their weights. Also summarize the rewards statistics given their names. Args: rewards (list[Tensor]): a list of rewards tensors weights (list[float]): a list of floating numbers names (list[str]): a list of reward names summary_prefix (str): a string prefix for summary summarize_fn (Callable): a summarize function that accepts a name and a reward. Returns: Tensor: A single reward after accumulation. """ assert len(rewards) > 0 reward = torch.zeros_like(rewards[0]) for r, w, name in zip(rewards, weights, names): summarize_fn(os.path.join(summary_prefix, name), r) reward += w * r if len(rewards) > 1: summarize_fn(os.path.join(summary_prefix, "overall"), reward) return reward
[docs] def accumulate_loss_info(self, algorithms, train_info, offline=False, pre_train=False): """Given an overall Agent training info that contains various training infos for different algorithms, compute the accumulated loss info for updating parameters. Args: algorithms (list[Algorithm]): the list of algorithms whose loss infos are to be accumulated. experience (Experience): experience used for gradient update. train_info (nested Tensor): information collected for training algorithms. It is batched from each ``AlgStep.info`` returned by ``train_step()`` or ``rollout_step()``. offline (bool): whether the accumulation is done for offline RL part or the online RL part. pre_train (bool): whether in pre_training phase. This flag can be used for algorithms that need to implement different training procedures at different phases. Returns: LossInfo: the accumulated loss info. """ def _update_loss(loss_info, algorithm, name): info = getattr(train_info, name) if not offline: new_loss_info = algorithm.calc_loss(info) else: new_loss_info = algorithm.calc_loss_offline(info, pre_train) if loss_info is None: return new_loss_info._replace( extra={name: new_loss_info.extra}) else: loss_info.extra[name] = new_loss_info.extra return LossInfo( loss=add_ignore_empty(loss_info.loss, new_loss_info.loss), scalar_loss=add_ignore_empty(loss_info.scalar_loss, new_loss_info.scalar_loss), extra=loss_info.extra, priority=add_ignore_empty(loss_info.priority, new_loss_info.priority)) loss_info = None for alg in algorithms: field = self._get_algorithm_field(alg) loss_info = _update_loss(loss_info, alg, field) assert loss_info is not None, "No loss info is calculated!" return loss_info
[docs] def after_update(self, algorithms, root_inputs, train_info): """For each provided algorithm, call its ``after_update()`` to do things after the agent completes one gradient update (i.e. ``update_with_gradient()``). Args: algorithms (list[Algorithm]): the list of algorithms whose ``after_update`` is to be called. root_inputs (TimeStep): experience used for the gradient update. train_info (AgentInfo): information collected for training algorithms. It is batched from each ``AlgStep.info`` returned by ``train_step()`` or ``rollout_step()``. """ for alg in algorithms: field = self._get_algorithm_field(alg) info = getattr(train_info, field) alg.after_update(root_inputs, info)
[docs] def after_train_iter(self, algorithms, root_inputs, rollout_info=None): """For each provided algorithm, call its ``after_train_iter()`` to do things after the agent finishes one training iteration (i.e., ``train_iter()``). Args: algorithms (list[Algorithm]): the list of algorithms whose ``after_train_iter`` is to be called. root_inputs (TimeStep): experience collected from ``rollout_step()``. rollout_info (AgentInfo): information collected for training algorithms. It is batched from each ``AlgStep.info`` returned by ``rollout_step()``. """ for alg in algorithms: field = self._get_algorithm_field(alg) info = (None if rollout_info is None else getattr( rollout_info, field)) alg.after_train_iter(root_inputs, info)
[docs] def set_path(self, path): """Set the path for the sub-algorithms.""" prefix = path if path: prefix = prefix + '.' for alg, name in self._alg_to_field_mapping.items(): alg.set_path(path + name)