# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from alf.data_structures import AlgStep, LossInfo
class AlgorithmInterface(nn.Module):
    """The interface for algorithm.

    It is a generic interface for reinforcement learning (RL) and non-RL
    algorithms. The key interface functions are:

    1. ``predict_step()``: one step of computation of action for evaluation.
    2. ``rollout_step()``: one step of computation for rollout. It is used for
       collecting experiences during training. Different from
       ``predict_step``, ``rollout_step`` may include additional computations
       for training. An algorithm could immediately use the collected
       experiences to update parameters after one rollout (multiple rollout
       steps) is performed; or it can put these collected experiences into a
       replay buffer.
    3. ``train_step()``: only used by algorithms that put experiences into
       replay buffers. The training data are sampled from the replay buffer
       filled by ``rollout_step()``.
    4. ``train_from_unroll()``: perform a training iteration from the unrolled
       result.
    5. ``train_from_replay_buffer()``: perform a training iteration from a
       replay buffer.
    6. ``update_with_gradient()``: do one gradient update based on the loss.
       It is used by the default ``train_from_unroll()`` and
       ``train_from_replay_buffer()`` implementations. You can override to
       implement your own ``update_with_gradient()``.
    7. ``calc_loss()``: calculate loss based on the ``info`` collected from
       ``rollout_step()`` or ``train_step()``. It is used by the default
       implementations of ``train_from_unroll()`` and
       ``train_from_replay_buffer()``. If you want to use these two
       functions, you need to implement ``calc_loss()``.
    8. ``after_update()``: called by ``train_iter()`` after every call to
       ``update_with_gradient()``, mainly for some postprocessing steps such
       as copying a training model to a target model in SAC or DQN.
    9. ``after_train_iter()``: called by ``train_iter()`` after every call to
       ``train_from_unroll()`` (on-policy training iter) or
       ``train_from_replay_buffer`` (off-policy training iter). It's mainly
       for training additional modules that have their own training logic
       (e.g., on/off-policy, replay buffers, etc). Other things might also be
       possible as long as they should be done once every training iteration.

    For algorithms that have additional offline training flows, they can be
    implemented by using the following additional interface functions:

    10. ``train_step_offline()``: only used by algorithms that have offline
        training flows. The training data are sampled from a replay buffer
        that is loaded from an offline replay buffer checkpoint.
    11. ``calc_loss_offline()``: It calculates the loss based on the ``info``
        collected from ``train_step_offline()``.

    The offline training flows can be invoked by specifying a valid path to a
    replay buffer for ``TrainerConfig.offline_buffer_dir``.

    .. note::

        A non-RL algorithm will not directly interact with an
        environment. The interaction loop will always be driven by an
        ``RLAlgorithm`` that outputs actions and gets rewards. So a
        non-RL algorithm is always attached to an ``RLAlgorithm`` and cannot
        change the timing of (when to launch) a training iteration. However,
        it can have its own logic of a training iteration (e.g.,
        ``train_from_unroll()`` and ``train_from_replay_buffer()``) which can
        be triggered by a parent ``RLAlgorithm`` inside its
        ``after_train_iter()``.
    """

    @property
    def path(self):
        """Path from the root algorithm to this algorithm.

        Currently, path is useful when an algorithm needs to directly access
        the data about itself in replay buffer. Two types of data about an
        algorithm are stored in replay buffer: one is ``rollout_info``, which
        is ``AlgStep.info`` returned by ``rollout_step()``, the other is
        ``state``, which is the ``state`` argument used to call
        ``rollout_step()``. The data in replay buffer is organized as
        ``Experience`` which includes ``rollout_info`` and ``state``.

        Given an experience structure, the input state to ``rollout_step()``
        can be retrieved by:

        .. code-block:: python

            nest.get_field(experience.state, self.path)

        The info from ``rollout_step()`` can be retrieved by:

        .. code-block:: python

            nest.get_field(experience.rollout_info, self.path)

        Returns:
            str: path from the root algorithm to this algorithm

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def set_path(self, path):
        """Set the path from the root algorithm to this algorithm.

        See ``AlgorithmInterface.path`` for description about path.

        This function is called by the trainer before training starts.
        It needs to be implemented if the algorithm contains some other
        sub-algorithms.

        If an algorithm does not have any sub-algorithm or its sub-algorithm
        does not need to access the root replay buffer directly, it does not
        need to implement this function.

        Args:
            path (str): path from the root algorithm to this algorithm.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    @property
    def on_policy(self):
        """Whether is on-policy training.

        For on-policy training, ``train_step()`` will not be called. And
        ``info`` passed to ``calc_loss()`` is info collected from
        ``rollout_step()``.

        For off-policy training, ``train_step()`` will be called with the
        experience from replay buffer. And ``info`` passed to ``calc_loss()``
        is info collected from ``train_step``.

        An algorithm can override this to indicate whether it is an on-policy
        or off-policy algorithm. If an algorithm does not override this, it
        needs to support both on-policy and off-policy training, which means
        that ``rollout_step()`` and ``train_step()`` need to have the correct
        behavior for on-policy and off-policy training. It can check whether
        it is on-policy training by reading this property.

        Returns:
            bool | None: True if on-policy training, False if off-policy
            training, None if not set.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def set_on_policy(self, is_on_policy):
        """Set whether this algorithm is on-policy or not.

        Args:
            is_on_policy (bool): True for on-policy training, False for
                off-policy training.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def predict_step(self, inputs, state):
        """Predict for one step of inputs.

        Args:
            inputs (nested Tensor): inputs for prediction.
            state (nested Tensor): network state (for RNN).

        Returns:
            AlgStep:
            - output (nested Tensor): prediction result.
            - state (nested Tensor): should match ``predict_state_spec``.
            - info (nest): information for analyzing the agent. In
              particular, if an element of the info is
              ``alf.summary.render.Image``, it will be rendered during play.
              See alf/summary/render.py for detail.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def rollout_step(self, inputs, state):
        """Rollout for one step of inputs.

        It is called to calculate output for every environment step. For
        on-policy training, it also needs to generate necessary information
        for ``calc_loss()``. For off-policy training, it needs to generate
        necessary information for ``train_step()``.

        Args:
            inputs (nested Tensor): inputs for prediction.
            state (nested Tensor): network state (for RNN).

        Returns:
            AlgStep:
            - output (nested Tensor): prediction result.
            - state (nested Tensor): should match ``rollout_state_spec``.
            - info (nested Tensor): For on-policy training it will be
              temporally batched and passed as ``info`` for ``calc_loss()``.
              For off-policy training, it will be stored into the replay
              buffer and retrieved for ``train_step()`` as ``rollout_info``.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def train_step(self, inputs, state, rollout_info):
        """Perform one step of training computation.

        It is called to calculate output for every time step for a batch of
        experience from replay buffer. It also needs to generate necessary
        information for ``calc_loss()``.

        Args:
            inputs (nested Tensor): inputs for train.
            state (nested Tensor): consistent with ``train_state_spec``.
            rollout_info (nested Tensor): info from ``rollout_step()``. It is
                retrieved from replay buffer.

        Returns:
            AlgStep:
            - output (nested Tensor): prediction result.
            - state (nested Tensor): should match ``train_state_spec``.
            - info (nested Tensor): information for training. It will be
              temporally batched and passed as ``info`` for ``calc_loss()``.
              If this is ``LossInfo``, ``calc_loss()`` in ``Algorithm`` can
              be used. Otherwise, the user needs to override ``calc_loss()``
              to calculate loss or override ``update_with_gradient()`` to do
              customized training.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def calc_loss(self, info):
        """Calculate the loss for one mini-batch.

        Args:
            info (nest): information collected for training. It is batched
                from each ``AlgStep.info`` returned by ``rollout_step()``
                (on-policy training) or ``train_step()`` (off-policy
                training). The shape of the tensors in info is
                ``(T, B, ...)``, where T is the mini-batch length and B is
                the mini-batch size.

        Returns:
            LossInfo: loss at each time step for each sample in the
            batch. The shapes of the tensors in loss info should be
            :math:`(T, B)`.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def preprocess_experience(self, root_inputs, rollout_info, batch_info):
        """This function is called on the experiences obtained from a replay
        buffer. An example usage of this function is to calculate advantages
        and returns in ``PPOAlgorithm``.

        The shapes of tensors in experience are assumed to be
        :math:`(B, T, ...)`.

        The default implementation returns ``root_inputs`` and
        ``rollout_info`` unchanged and does not use ``batch_info``.

        Args:
            root_inputs (nest): input for rollout_step() of the root
                algorithm. This is from replay buffer. Note this is not same
                as the input of rollout_step() of self unless self is the
                root algorithm.
            rollout_info (nested Tensor): ``AlgStep.info`` from
                rollout_step() for this algorithm.
            batch_info (BatchInfo): information about this batch of data

        Returns:
            tuple:
            - processed root_inputs
            - processed rollout_info
        """
        return root_inputs, rollout_info

    def after_update(self, root_inputs, info):
        """Do things after completing one gradient update (i.e.
        ``update_with_gradient()``).

        This function can be used for post-processings following one
        minibatch update, such as copy a training model to a target model in
        SAC, DQN, etc.

        The default implementation does nothing, so overriding this hook is
        optional.

        Args:
            root_inputs (nest): temporally batched inputs for the
                ``rollout_step()`` of the root algorithm collected during
                ``unroll()``.
            info (nest): information collected for training. It is batched
                from each ``AlgStep.info`` returned by ``rollout_step()``
                for on-policy training or ``train_step()`` for off-policy
                training.
        """

    def after_train_iter(self, root_inputs, rollout_info):
        """Do things after completing one training iteration (i.e.
        ``train_iter()`` that consists of one or multiple gradient updates).

        This function can be used for training additional modules that have
        their own training logic (e.g., on/off-policy, replay buffers, etc).
        These modules should be added to ``_trainable_attributes_to_ignore``
        in the parent algorithm.

        Other things might also be possible as long as they should be done
        once every training iteration.

        This function will serve the same purpose with ``after_update`` if
        there is always only one gradient update in each training iteration.
        Otherwise it's less frequently called than ``after_update``.

        The default implementation does nothing, so overriding this hook is
        optional.

        Args:
            root_inputs (nest|None): temporally batched inputs for the
                ``rollout_step()`` of the root algorithm collected during
                ``unroll()``. In the case where no data is available from
                the ``rollout_step()`` (e.g. in an offline pre-training phase
                where the online interaction is not started yet)
                ``root_inputs`` will be None.
            rollout_info (nest|None): information collected from
                ``rollout_step()`` for this algorithm during ``unroll()``.
                In the case where no data is available from the
                ``rollout_step()`` (e.g. in an offline pre-training phase
                where the online interaction is not started yet)
                ``rollout_info`` will be None.
        """

    def train_iter(self):
        """Perform one iteration of training.

        Users may choose to implement their own ``train_iter()``.

        Returns:
            int:
            - number of samples being trained on (including duplicates).

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        # Fail loudly like the other unimplemented members instead of
        # silently returning None where an int is documented.
        raise NotImplementedError()

    def train_from_unroll(self, experience, train_info):
        """Train given the info collected from ``unroll()``. This function
        can be called by any child algorithm that doesn't have the unroll
        logic but has a different training logic with its parent.

        Args:
            experience (Experience): collected during ``unroll()``.
            train_info (nest): ``AlgStep.info`` returned by
                ``rollout_step()``.

        Returns:
            int: number of steps that have been trained

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def train_from_replay_buffer(self, update_global_counter=False):
        """This function can be called by any algorithm that has its own
        replay buffer configured.

        Args:
            update_global_counter (bool): controls whether this function
                changes the global counter for summary. If there are
                multiple algorithms, then only the parent algorithm should
                change this quantity and child algorithms should disable the
                flag. When it's ``True``, it will affect the counter only if
                ``config.update_counter_every_mini_batch=True``.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def train_step_offline(self, inputs, state, rollout_info,
                           pre_train=False):
        """Perform one step of offline training computation.

        It is called to calculate output for every time step for a batch of
        experience from offline replay buffer. It also needs to generate
        necessary information for ``calc_loss_offline()``.

        Args:
            inputs (nested Tensor): inputs for train.
            state (nested Tensor): consistent with ``train_state_spec``.
            rollout_info (nested Tensor): info from ``rollout_step()``. It is
                retrieved from replay buffer.
            pre_train (bool): whether in pre_training phase. This flag
                can be used for algorithms that need to implement different
                training procedures at different phases.

        Returns:
            AlgStep:
            - output (nested Tensor): prediction result.
            - state (nested Tensor): should match ``train_state_spec``.
            - info (nested Tensor): information for training. It will be
              temporally batched and passed as ``info`` for
              ``calc_loss_offline()``.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()

    def calc_loss_offline(self, info, pre_train=False):
        """Calculate the offline training loss for one mini-batch.

        Args:
            info (nest): information collected for training. It is batched
                from each ``AlgStep.info`` returned by
                ``train_step_offline()``. The shape of the tensors in info
                is ``(T, B, ...)``, where T is the mini-batch length and B
                is the mini-batch size.
            pre_train (bool): whether in pre_training phase. This flag
                can be used for algorithms that need to implement different
                training procedures at different phases.

        Returns:
            LossInfo: loss at each time step for each sample in the
            batch. The shapes of the tensors in loss info should be
            :math:`(T, B)`.

        Raises:
            NotImplementedError: always; the interface provides no
                implementation.
        """
        raise NotImplementedError()