# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoding algorithm."""
from typing import Optional
from alf.algorithms.config import TrainerConfig
import alf
from alf.algorithms.algorithm import Algorithm
from alf.data_structures import AlgStep, LossInfo, TimeStep
from alf.networks import EncodingNetwork
from alf.nest import map_structure, flatten
from alf.nest.utils import get_nested_field
[docs]@alf.configurable
class EncodingAlgorithm(Algorithm):
"""Basic encoding algorithm.
It uses the provided encoding network to computed the representation. It
also supports the training of the encoding network by using some of its output
as losses.
"""
def __init__(self,
observation_spec,
action_spec,
reward_spec=alf.TensorSpec(()),
encoder_cls=EncodingNetwork,
time_step_as_input=False,
output_fields=None,
loss_fields=None,
loss_weights=None,
optimizer=None,
config: Optional[TrainerConfig] = None,
checkpoint=None,
debug_summaries=False,
name="EncodingAlgorithm"):
"""
Args:
observation_spec (nested TensorSpec): representing the observations.
action_spec (nested BoundedTensorSpec): not used
encoder_cls (type): The class or function to create the encoder. It
can be called using ``encoder_cls(input_tensor_spec)``.
time_step_as_input (bool): If True, use the whole TimeStep strucuture
as the input to the encoder instead of the observation.
output_fields (None | nested str): if None, all the output from the
encoder will be used as the output. Otherwise, ``output_fields``
will be used to retrieve the fields from the encoder output.
loss_fields (None | nested str): there is not loss if this is None.
Otherwise, ``loss_fields`` will be used to retrieve fields from
encoder output and use them as loss. Note that those fields must
be scalar.
loss_weights (None | nested str): if provided, must have the same
structure as ``loss_fields`` and will be used as weights for
the corresponding loss values.
config: The trainer config. Present as representation learner
interface to be used with ``Agent``.
optimizer (torch.optim.Optimizer): if provided, will be used to optimize
the parameters of encoder.
checkpoint (None|str): a string in the format of "prefix@path",
where the "prefix" is the multi-step path to the contents in the
checkpoint to be loaded. "path" is the full path to the checkpoint
file saved by ALF. Refer to ``Algorithm`` for more details.
debug_summaries (bool): True if debug summaries should be created.
name (str): Name of this algorithm.
"""
if time_step_as_input:
time_step_spec = alf.data_structures.time_step_spec(
observation_spec, action_spec, reward_spec)
encoder = encoder_cls(input_tensor_spec=time_step_spec)
else:
encoder = encoder_cls(input_tensor_spec=observation_spec)
super().__init__(
train_state_spec=encoder.state_spec,
optimizer=optimizer,
config=config,
checkpoint=checkpoint,
debug_summaries=debug_summaries,
name=name)
self._time_step_as_input = time_step_as_input
self._encoder = encoder
output_spec = encoder.output_spec
if output_fields is not None:
output_spec = get_nested_field(output_spec, output_fields)
self._output_spec = output_spec
if loss_fields is not None:
# make sure loss_fields can be found in output_spec
loss_specs = get_nested_field(output_spec, loss_fields)
assert all(
flatten(
map_structure(lambda spec: spec.shape is (), loss_specs))
), ("The losses should be scalars. Got: %s" % str(loss_specs))
if loss_weights is not None:
alf.nest.assert_same_structure(loss_weights, loss_fields)
self._output_fields = output_fields
self._loss_fields = loss_fields
self._loss_weights = loss_weights
@property
def output_spec(self):
return self._output_spec
[docs] def predict_step(self, inputs: TimeStep, state):
"""override predict_step
Args:
inputs (TimeStep): time step structure
state (nested Tensor): network state for ``encoder``
Returns:
AlgStep:
- output: encoding result
- state: rnn state from ``encoder``
"""
if self._time_step_as_input:
output, state = self._encoder(inputs, state=state)
else:
output, state = self._encoder(inputs.observation, state=state)
if self._output_fields is not None:
output = get_nested_field(output, self._output_fields)
return AlgStep(output=output, state=state)
[docs] def rollout_step(self, inputs, state):
"""override rollout_step
Args:
inputs (TimeStep): time step structure
state (nested Tensor): network state for ``encoder``
Returns:
AlgStep:
- output: encoding result
- state: rnn state from ``encoder``
- info: LossInfo
"""
if not self.on_policy:
return self.predict_step(inputs, state)
else:
return self.train_step(inputs, state, None)
[docs] def train_step(self, inputs: TimeStep, state, rollout_info=None):
"""override train_step
Args:
inputs (TimeStep): time step structure
state (nested Tensor): network state for ``encoder``
rollout_info: not used
Returns:
AlgStep:
- output: encoding result
- state: rnn state from ``encoder``
- info: LossInfo
"""
if self._time_step_as_input:
output, state = self._encoder(inputs, state=state)
else:
output, state = self._encoder(inputs.observation, state=state)
if self._loss_fields is not None:
losses = get_nested_field(output, self._loss_fields)
if self._loss_weights is not None:
loss = sum(
flatten(
map_structure(lambda w, l: w * l, self._loss_weights,
losses)))
else:
loss = sum(flatten(losses))
info = LossInfo(loss=loss, extra=losses)
else:
info = LossInfo()
if self._output_fields is not None:
output = get_nested_field(output, self._output_fields)
return AlgStep(output=output, state=state, info=info)