# Source code for alf.algorithms.predictive_representation_learner

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PredictiveRepresentationLearner."""

from typing import Optional
from functools import partial
import torch

import alf
from alf.algorithms.algorithm import Algorithm
from alf.algorithms.config import TrainerConfig
from alf.data_structures import AlgStep, TimeStep, LossInfo, namedtuple
from alf.experience_replayers.replay_buffer import BatchInfo, ReplayBuffer
from alf.nest.utils import convert_device
from alf.networks import Network, LSTMEncodingNetwork, wrap_as_network
from alf.utils import common, dist_utils, tensor_utils
from alf.utils.normalizers import AdaptiveNormalizer
from alf.utils.summary_utils import safe_mean_hist_summary, safe_mean_summary
from alf.tensor_specs import TensorSpec

# Training info built by ``PredictiveRepresentationLearner.preprocess_experience``
# and read by its ``train_step``. Fields:
#   action: actual actions taken in the next unroll_steps steps,
#           shape [B, unroll_steps, ...]
#   mask:   flag indicating whether each target is included in the loss,
#           shape [B, unroll_steps + 1]
#   target: nest of prediction targets, shape [B, unroll_steps + 1, ...]
PredictiveRepresentationLearnerInfo = namedtuple(
    'PredictiveRepresentationLearnerInfo', ['action', 'mask', 'target'])


@alf.configurable
class SimpleDecoder(Algorithm):
    """A simple decoder with elementwise loss between the target and the
    predicted value.

    It is used to predict the target value from the given representation. Its
    loss can be used to train the representation.
    """

    def __init__(self,
                 input_tensor_spec,
                 target_field,
                 decoder_net_ctor,
                 loss_ctor=partial(torch.nn.SmoothL1Loss, reduction='none'),
                 loss_weight=1.0,
                 summarize_each_dimension=False,
                 optimizer=None,
                 normalize_target=False,
                 append_target_field_to_name=True,
                 debug_summaries=False,
                 name="SimpleDecoder"):
        """
        Args:
            input_tensor_spec (TensorSpec): describing the input tensor.
            target_field (str): name of the field in the experience to be used
                as the decoding target.
            decoder_net_ctor (Callable): called as
                ``decoder_net_ctor(input_tensor_spec=input_tensor_spec)`` to
                construct an instance of ``Network`` for decoding. The network
                should take the latent representation as input and output the
                predicted value of the target. It must not be an RNN (its
                ``state_spec`` must be ``()``).
            loss_ctor (Callable): called with no arguments to construct the
                loss function, which is then called as ``loss(y_pred, y_true)``.
                Note that it should not reduce to a scalar. It should at least
                keep the batch dimension in the returned loss.
            loss_weight (float): weight for the loss.
            summarize_each_dimension (bool): only relevant when debug summaries
                are enabled and the per-element loss has more than two
                dimensions. If True, separate summaries are generated for each
                index along dimension 2 of the prediction; otherwise the
                summaries are aggregated over all dimensions.
            optimizer (Optimizer|None): if provided, it will be used to optimize
                the parameter of decoder_net
            normalize_target (bool): whether to normalize target.
                Note that the effect of this is to change the loss. The
                predicted value itself is not normalized.
            append_target_field_to_name (bool): whether append target field
                to the name of the decoder. If True, the actual name used will
                be ``name.target_field``
            debug_summaries (bool): whether to generate debug summaries
            name (str): name of this instance
        """
        if append_target_field_to_name:
            name = name + "." + target_field

        super().__init__(
            optimizer=optimizer, debug_summaries=debug_summaries, name=name)
        self._decoder_net = decoder_net_ctor(
            input_tensor_spec=input_tensor_spec)
        assert self._decoder_net.state_spec == (
        ), "RNN decoder is not supported"
        self._summarize_each_dimension = summarize_each_dimension
        self._target_field = target_field
        self._loss = loss_ctor()
        self._loss_weight = loss_weight
        if normalize_target:
            # auto_update=False: the normalizer is updated explicitly with the
            # target in ``calc_loss``.
            self._target_normalizer = AdaptiveNormalizer(
                self._decoder_net.output_spec,
                auto_update=False,
                name=name + ".target_normalizer")
        else:
            self._target_normalizer = None

    def get_target_fields(self):
        """Return the name of the experience field used as the target."""
        return self._target_field

    def _decode(self, repr, state):
        """Run the decoder network on ``repr`` and wrap the result in an
        ``AlgStep``, using the prediction as both output and info."""
        predicted_target = self._decoder_net(repr)[0]
        return AlgStep(
            output=predicted_target, state=state, info=predicted_target)

    def train_step(self, repr, state=()):
        return self._decode(repr, state)

    def predict_step(self, repr, state=()):
        return self._decode(repr, state)

    def calc_loss(self, target, predicted, mask=None):
        """Calculate the loss between ``target`` and ``predicted``.

        Args:
            target (Tensor): target to be predicted. Its shape is [T, B, ...]
            predicted (Tensor): predicted target. Its shape is [T, B, ...]
            mask (bool Tensor|None): indicating which target should be
                predicted. Its shape is [T, B]. If None, all targets are used.
        Returns:
            LossInfo: ``loss`` is the weighted loss and ``extra`` the
            unweighted loss, both of shape [T, B]. Any per-element dimensions
            beyond [T, B] returned by the loss function are averaged.
        """
        if self._target_normalizer is not None:
            self._target_normalizer.update(target)
            target = self._target_normalizer.normalize(target)
            predicted = self._target_normalizer.normalize(predicted)

        # self._loss() is not guaranteed to correctly handle more than one batch
        # dimension (e.g. CrossEntropyLoss), so we need to do some reshaping here.
        b = predicted.shape[0] * predicted.shape[1]
        loss = self._loss(
            predicted.reshape(b, *predicted.shape[2:]),
            target.reshape(b, *target.shape[2:]))
        loss = loss.reshape(*predicted.shape[:2], *loss.shape[1:])

        if self._debug_summaries and alf.summary.should_record_summaries():
            self._summarize_loss(predicted, target, loss, mask)

        if loss.ndim > 2:
            # Reduce every per-element dimension so the loss is [T, B] and can
            # be multiplied by the [T, B] mask (for a 3-D loss this equals the
            # mean over dim 2).
            loss = loss.mean(dim=tuple(range(2, loss.ndim)))
        if mask is not None:
            loss = loss * mask
        return LossInfo(loss=loss * self._loss_weight, extra=loss)

    def _summarize_loss(self, predicted, target, loss, mask):
        """Generate debug summaries for ``calc_loss``.

        Summaries for time step 0 get the suffix ``/current`` and those for
        the remaining steps get ``/future``.

        Args:
            predicted (Tensor): [T, B, ...] prediction (normalized if target
                normalization is enabled).
            target (Tensor): [T, B, ...] target (normalized likewise).
            loss (Tensor): [T, B, ...] unreduced loss.
            mask (bool Tensor|None): [T, B] mask, or None for no mask.
        """

        def _select(m, index):
            # ``mask`` is optional: propagate None instead of indexing it.
            # NOTE(review): relies on the summary helpers and
            # ``explained_variance`` accepting a None mask.
            return None if m is None else m[index]

        def _summarize1(pred, tgt, lss, m, suffix):
            if pred.shape == tgt.shape:
                alf.summary.scalar(
                    "explained_variance" + suffix,
                    tensor_utils.explained_variance(pred, tgt, m))
            safe_mean_hist_summary('predict' + suffix, pred, m)
            safe_mean_hist_summary('target' + suffix, tgt, m)
            safe_mean_summary("loss" + suffix, lss, m)

        def _summarize(pred, tgt, lss, m, suffix):
            _summarize1(pred[0], tgt[0], lss[0], _select(m, 0),
                        suffix + "/current")
            if pred.shape[0] > 1:
                _summarize1(pred[1:], tgt[1:], lss[1:],
                            _select(m, slice(1, None)), suffix + "/future")

        with alf.summary.scope(self._name):
            if loss.ndim == 2:
                _summarize(predicted, target, loss, mask, '')
            elif not self._summarize_each_dimension:
                m = mask
                if m is not None:
                    m = m.unsqueeze(-1).expand_as(predicted)
                _summarize(predicted, target, loss, m, '')
            else:
                for i in range(predicted.shape[2]):
                    suffix = '/' + str(i)
                    _summarize(predicted[..., i], target[..., i],
                               loss[..., i], mask, suffix)
@alf.configurable
class PredictiveRepresentationLearner(Algorithm):
    """Learn representation based on the prediction of future values.

    ``PredictiveRepresentationLearner`` contains 3 ``Module``s:

    * encoding_net: it is a ``Network`` that encodes the raw observation to a
      latent vector.
    * dynamics_net: it is a ``Network`` that generates the future latent states
      from the current latent state.
    * decoder: it is an ``Algorithm`` that decodes the target values from the
      latent state and calculates the loss.
    """

    def __init__(self,
                 observation_spec,
                 action_spec,
                 num_unroll_steps,
                 decoder_ctor,
                 encoding_net_ctor,
                 dynamics_net_ctor,
                 reward_spec=TensorSpec(()),
                 config: Optional[TrainerConfig] = None,
                 postprocessor=None,
                 encoding_optimizer=None,
                 dynamics_optimizer=None,
                 postprocessor_optimizer=None,
                 checkpoint=None,
                 debug_summaries=False,
                 name="PredictiveRepresentationLearner"):
        """
        Args:
            observation_spec (nested TensorSpec): describing the observation.
            action_spec (nested BoundedTensorSpec): describing the action.
            num_unroll_steps (int): the number of future steps to predict.
                ``num_unroll_steps`` of 0 means no future prediction and hence
                ``dynamics_net_ctor`` is ignored.
            decoder_ctor (Callable|[Callable]): each individual constructor is
                called as ``decoder_ctor(repr_spec, debug_summaries=debug_summaries,
                append_target_field_to_name=True)``, where ``repr_spec`` is the
                output spec of the encoding net, to construct the decoder
                algorithm. It should follow the ``Algorithm`` interface. In
                addition to the interface of ``Algorithm``, it should also
                implement a member function ``get_target_fields()``, which
                returns a nest of the names of target fields. See
                ``SimpleDecoder`` for an example of decoder. Decoders must not
                be RNNs.
            encoding_net_ctor (Callable): called as
                ``encoding_net_ctor(observation_spec)`` to construct the
                encoding ``Network``. The network takes raw observation as
                input and output the latent representation. encoding_net can
                be an RNN.
            dynamics_net_ctor (Callable): called as
                ``dynamics_net_ctor(action_spec)`` to construct the dynamics
                ``Network``. It must be an RNN. The constructed network takes
                action as input and outputs the future latent representation.
                If the state_spec of the dynamics net is exactly same as the
                state_spec of the encoding net, the current state of the
                encoding net will be used as the initial state of the dynamics
                net. Otherwise, a linear projection will be used to convert the
                current latent representation to the initial state for the
                dynamics net.
            reward_spec: NOT USED. Only present as representation learner
                interface to be used with ``Agent``.
            config: The trainer config. Present as representation learner
                interface to be used with ``Agent``.
            postprocessor (None|Callable): If provided, will be called as
                ``postprocessor(latent)`` to get the actual representation,
                where ``latent`` is the output from encoding_net.
            encoding_optimizer (Optimizer|None): if provided, will be used to optimize
                the parameter for the encoding net.
            dynamics_optimizer (Optimizer|None): if provided, will be used to optimize
                the parameter for the dynamics net.
            postprocessor_optimizer (Optimizer|None): if provided, will be used to optimize
                the parameter for the postprocessor.
            checkpoint (None|str): a string in the format of "prefix@path",
                where the "prefix" is the multi-step path to the contents in the
                checkpoint to be loaded. "path" is the full path to the checkpoint
                file saved by ALF. Refer to ``Algorithm`` for more details.
            debug_summaries (bool): whether to generate debug summaries
            name (str): name of this instance.
        """
        encoding_net = encoding_net_ctor(observation_spec)
        super().__init__(
            train_state_spec=encoding_net.state_spec,
            config=config,
            checkpoint=checkpoint,
            debug_summaries=debug_summaries,
            name=name)

        self._encoding_net = encoding_net
        if encoding_optimizer is not None:
            self.add_optimizer(encoding_optimizer, [self._encoding_net])
        repr_spec = self._encoding_net.output_spec

        decoder_ctors = common.as_list(decoder_ctor)
        self._decoders = torch.nn.ModuleList()
        self._target_fields = []
        for ctor in decoder_ctors:
            decoder = ctor(
                repr_spec,
                debug_summaries=debug_summaries,
                append_target_field_to_name=True)
            target_field = decoder.get_target_fields()
            self._decoders.append(decoder)
            assert len(alf.nest.flatten(decoder.train_state_spec)) == 0, (
                "RNN decoder is not supported")
            self._target_fields.append(target_field)

        # With a single decoder the target is stored unwrapped (not as a list),
        # which is the layout ``preprocess_experience`` produces.
        if len(self._target_fields) == 1:
            self._target_fields = self._target_fields[0]

        self._num_unroll_steps = num_unroll_steps
        # Defined unconditionally so that these attributes exist even when
        # ``num_unroll_steps == 0`` (no dynamics net is constructed then).
        self._dynamics_net = None
        self._latent_to_dstate_fc = None
        if num_unroll_steps > 0:
            self._dynamics_net = dynamics_net_ctor(action_spec)
            self._dynamics_state_dims = alf.nest.map_structure(
                lambda spec: spec.numel,
                alf.nest.flatten(self._dynamics_net.state_spec))
            assert sum(
                self._dynamics_state_dims) > 0, ("dynamics_net should be RNN")
            try:
                alf.nest.assert_same_structure(self._dynamics_net.state_spec,
                                               self._encoding_net.state_spec)
                compatible_state = all(
                    alf.nest.flatten(
                        alf.nest.map_structure(lambda s1, s2: s1 == s2,
                                               self._dynamics_net.state_spec,
                                               self._encoding_net.state_spec)))
            except Exception:
                # Deliberate best-effort: any structural mismatch simply means
                # the encoder state cannot be reused as the dynamics state.
                compatible_state = False

            modules = [self._dynamics_net]
            if not compatible_state:
                # Project the latent representation to a flat vector that is
                # later split into the dynamics net's initial state.
                self._latent_to_dstate_fc = alf.layers.FC(
                    repr_spec.numel, sum(self._dynamics_state_dims))
                modules.append(self._latent_to_dstate_fc)
            if dynamics_optimizer is not None:
                self.add_optimizer(dynamics_optimizer, modules)

        if postprocessor is not None:
            self._postprocessor = postprocessor
        else:
            self._postprocessor = alf.math.identity
        if postprocessor_optimizer is not None:
            self.add_optimizer(postprocessor_optimizer, [postprocessor])
        self._output_spec = wrap_as_network(self._postprocessor,
                                            repr_spec).output_spec

    @property
    def output_spec(self):
        return self._output_spec

    def predict_step(self, inputs: TimeStep, state):
        latent, state = self._encoding_net(inputs.observation, state)
        latent = self._postprocessor(latent)
        return AlgStep(output=latent, state=state)

    def rollout_step(self, inputs: TimeStep, state):
        latent, state = self._encoding_net(inputs.observation, state)
        latent = self._postprocessor(latent)
        return AlgStep(output=latent, state=state)

    def predict_multi_step(self,
                           init_latent,
                           actions,
                           target_field=None,
                           state=None):
        """Perform multi-step predictions based on the initial latent
        representation and actions sequences.

        Args:
            init_latent (Tensor): the latent representation for the initial
                step of the prediction
            actions (Tensor): [B, unroll_steps, action_dim]. ``unroll_steps``
                is taken from ``actions.shape[1]`` and must be positive.
            target_field (None|str|[str]): the name or a list of names of the
                quantities to be predicted. It is used for selecting the
                corresponding decoder. If None, all the available decoders will
                be used for generating predictions.
            state: initial state of the dynamics net. It is only used when the
                dynamics net reuses the encoding net's state layout; otherwise
                the initial state is computed from ``init_latent``.
        Returns:
            prediction (Tensor|[Tensor]): predicted target whose first
            dimension is ``(unroll_steps + 1) * B``, ordered step-major (the
            ``B`` predictions for the initial step first, then those for step 1,
            etc.), i.e. it can be reshaped to ``[unroll_steps + 1, B, ...]``.
            The return is a list of Tensors when there are multiple targets to
            be predicted.
        """
        num_unroll_steps = actions.shape[1]
        assert num_unroll_steps > 0
        assert self._dynamics_net is not None, (
            "Multi-step prediction requires a dynamics net, which is only "
            "created when num_unroll_steps > 0")

        sim_latent = self._multi_step_latent_rollout(
            init_latent, num_unroll_steps, actions, state)

        if target_field is None:
            decoders = list(self._decoders)
        else:
            decoders = [
                self.get_decoder(field)
                for field in common.as_list(target_field)
            ]
        predictions = [
            decoder.predict_step(sim_latent).info for decoder in decoders
        ]
        return predictions[0] if len(predictions) == 1 else predictions

    def get_decoder(self, target_field):
        """Get the decoder which predicts the target specified by
        ``target_field``.

        Args:
            target_field (str): the name of the prediction quantity corresponding
                to the decoder
        Returns:
            decoder (Algorithm)
        """
        decoder_ind = common.as_list(self._target_fields).index(target_field)
        return self._decoders[decoder_ind]

    def _multi_step_latent_rollout(self, init_latent, num_unroll_steps,
                                   actions, state):
        """Perform multi-step latent rollout based on the initial
        latent representation and action sequences.

        Args:
            init_latent (Tensor): the latent representation for the initial
                step of the prediction
            num_unroll_steps (int): the number of dynamics steps to unroll.
                ``actions`` must provide at least this many steps.
            actions (Tensor): [B, unroll_steps, action_dim]
            state: initial dynamics state, used only when no latent-to-state
                projection exists.
        Returns:
            sim_latent (Tensor): a tensor of the shape [(num_unroll_steps+1)*B, ...],
            obtained by concatenating all the latent states during rollout,
            including the input initial latent representation.
        """
        sim_latents = [init_latent]
        if num_unroll_steps > 0:
            if self._latent_to_dstate_fc is not None:
                dstate = self._latent_to_dstate_fc(init_latent)
                dstate = dstate.split(self._dynamics_state_dims, dim=1)
                dstate = alf.nest.pack_sequence_as(
                    self._dynamics_net.state_spec, dstate)
            else:
                dstate = state

            # Use the requested number of steps (not self._num_unroll_steps) so
            # that predict_multi_step honors the length of ``actions``.
            for i in range(num_unroll_steps):
                sim_latent, dstate = self._dynamics_net(
                    actions[:, i, ...], dstate)
                sim_latents.append(sim_latent)

        return alf.nest.map_structure(
            lambda *tensors: torch.cat(tensors, dim=0), *sim_latents)

    def train_step(self, root_inputs: TimeStep, state, rollout_info):
        # [B, num_unroll_steps + 1]
        info = rollout_info
        targets = common.as_list(info.target)
        batch_size = root_inputs.step_type.shape[0]
        latent, state = self._encoding_net(root_inputs.observation, state)
        sim_latent = self._multi_step_latent_rollout(
            latent, self._num_unroll_steps, info.action, state)

        loss = 0
        extra = {}
        for i, decoder in enumerate(self._decoders):
            # [(num_unroll_steps + 1)*B, ...]
            train_info = decoder.train_step(sim_latent).info
            train_info_spec = dist_utils.extract_spec(train_info)
            train_info = dist_utils.distributions_to_params(train_info)
            train_info = alf.nest.map_structure(
                lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                    *x.shape[1:]), train_info)
            # [num_unroll_steps + 1, B, ...]
            train_info = dist_utils.params_to_distributions(
                train_info, train_info_spec)
            # Targets are stored batch-major; make them time-major to match.
            target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                            targets[i])
            loss_info = decoder.calc_loss(target, train_info, info.mask.t())
            # Average over the (num_unroll_steps + 1) time dimension.
            loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0),
                                               loss_info)
            loss += loss_info.loss
            extra[decoder.name] = loss_info.extra

        loss_info = LossInfo(loss=loss, extra=extra)
        latent = self._postprocessor(latent)
        return AlgStep(output=latent, state=state, info=loss_info)

    @torch.no_grad()
    def preprocess_experience(self, root_inputs, rollout_info,
                              batch_info: BatchInfo):
        """Fill experience.rollout_info with PredictiveRepresentationLearnerInfo

        Note that the shape of experience is [B, T, ...].

        The target is a Tensor (or a nest of Tensors) when there is only one
        decoder. When there are multiple decoders, the target is a list,
        and each of its element is a Tensor (or a nest of Tensors), which is
        used as the target for the corresponding decoder.

        Positions past the end of an episode are clamped to the episode's last
        position and marked as invalid in ``mask``.
        """
        assert batch_info != ()
        replay_buffer: ReplayBuffer = batch_info.replay_buffer
        mini_batch_length = root_inputs.step_type.shape[1]

        with alf.device(replay_buffer.device):
            # [B, 1]
            positions = convert_device(batch_info.positions).unsqueeze(-1)
            # [B, 1]
            env_ids = convert_device(batch_info.env_ids).unsqueeze(-1)

            # [B, T]
            positions = positions + torch.arange(mini_batch_length)

            # [B, T]
            steps_to_episode_end = replay_buffer.steps_to_episode_end(
                positions, env_ids)
            # [B, T]
            episode_end_positions = positions + steps_to_episode_end

            # [B, T, unroll_steps+1]
            positions = positions.unsqueeze(-1) + torch.arange(
                self._num_unroll_steps + 1)
            # [B, 1, 1]
            env_ids = env_ids.unsqueeze(-1)
            # [B, T, 1]
            episode_end_positions = episode_end_positions.unsqueeze(-1)

            # [B, T, unroll_steps+1]
            mask = positions <= episode_end_positions

            # [B, T, unroll_steps+1]
            positions = torch.min(positions, episode_end_positions)

            # [B, T, unroll_steps+1, ...]
            target = replay_buffer.get_field(self._target_fields, env_ids,
                                             positions)

            # [B, T, unroll_steps]
            # ``prev_action`` at position p+1 is the action taken at p.
            action = replay_buffer.get_field('prev_action', env_ids,
                                             positions[:, :, 1:])

            rollout_info = PredictiveRepresentationLearnerInfo(
                action=action, mask=mask, target=target)

        rollout_info = convert_device(rollout_info)

        return root_inputs, rollout_info