Source code for alf.algorithms.dynamics_learning_algorithm

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
from typing import Callable, Optional, Any
import torch

import alf
from alf.algorithms.algorithm import Algorithm
from alf.data_structures import (AlgStep, Experience, LossInfo, namedtuple,
                                 StepType, TimeStep)
from alf.nest import nest
from alf.nest.utils import NestConcat, get_outer_rank
from alf.networks import Network, EncodingNetwork, DynamicsNetwork
from alf.tensor_specs import TensorSpec
from alf.utils import dist_utils, losses, math_ops, spec_utils, tensor_utils

# State carried between steps by the dynamics learning algorithms in this file.
#   feature: feature tensor read as ``state.feature`` when predicting the next
#       step (the previous observation's feature or the latest prediction).
#   network: state passed to / returned from the dynamics network.
# NOTE(review): ``default_value=()`` presumably makes every unset field default
# to an empty tuple -- confirm against ``alf.data_structures.namedtuple``.
DynamicsState = namedtuple(
    "DynamicsState", ["feature", "network"], default_value=())
# Per-step info produced by ``predict_step`` / ``train_step``.
#   loss: a ``LossInfo`` filled in by ``train_step``.
#   dist: the predictive distribution filled in by the stochastic algorithm.
DynamicsInfo = namedtuple("DynamicsInfo", ["loss", "dist"], default_value=())


@alf.configurable
class DynamicsLearningAlgorithm(Algorithm):
    """Base Dynamics Learning Module

    This module learns the dynamics of environment with a deterministic model.
    Subclasses are expected to implement ``get_state_specs``, ``predict_step``
    and ``train_step``.
    """

    def __init__(self,
                 train_state_spec,
                 action_spec,
                 feature_spec,
                 hidden_size=256,
                 num_replicas=1,
                 dynamics_network: DynamicsNetwork = None,
                 checkpoint=None,
                 name="DynamicsLearningAlgorithm"):
        """Create a DynamicsLearningAlgorithm.

        Args:
            train_state_spec (nested TensorSpec): forwarded unchanged to
                ``Algorithm.__init__`` as ``train_state_spec``.
            action_spec (nested TensorSpec): spec of the action. It must
                flatten to exactly one spec (nested actions are not
                supported).
            feature_spec (nested TensorSpec): spec of the feature. It must
                flatten to exactly one spec (nested features are not
                supported).
            hidden_size (int|tuple): size of hidden layer(s). Only used when
                ``dynamics_network`` is None.
            num_replicas (int): number of network replicas. When larger than
                1, the dynamics network is replicated with
                ``make_parallel(num_replicas)``.
            dynamics_network (Network): network for predicting the change of
                the next feature based on the previous feature and action.
                It should accept input with spec of the format
                [feature_spec, encoded_action_spec] and output a tensor of the
                shape feature_spec. For discrete action case, encoded_action
                is a one-hot representation of the action. For continuous
                action, encoded action is the original action. If None, a
                default ``DynamicsNetwork`` is constructed.
            checkpoint (None|str): a string in the format of "prefix@path",
                where the "prefix" is the multi-step path to the contents in
                the checkpoint to be loaded. "path" is the full path to the
                checkpoint file saved by ALF. Refer to ``Algorithm`` for more
                details.
            name (str): name of this algorithm.
        """
        super().__init__(
            train_state_spec=train_state_spec,
            checkpoint=checkpoint,
            name=name)

        flat_action_spec = nest.flatten(action_spec)
        assert len(flat_action_spec) == 1, "doesn't support nested action_spec"

        flat_feature_spec = nest.flatten(feature_spec)
        assert len(
            flat_feature_spec) == 1, "doesn't support nested feature_spec"

        action_spec = flat_action_spec[0]

        # For discrete actions this is the number of classes of the one-hot
        # encoding; for continuous actions it is the size of the last
        # dimension of the action.
        if action_spec.is_discrete:
            self._num_actions = action_spec.maximum - action_spec.minimum + 1
        else:
            self._num_actions = action_spec.shape[-1]

        self._action_spec = action_spec
        self._feature_spec = feature_spec
        self._num_replicas = num_replicas

        if isinstance(hidden_size, int):
            hidden_size = (hidden_size, )

        if dynamics_network is None:
            encoded_action_spec = TensorSpec((self._num_actions, ),
                                             dtype=torch.float32)
            dynamics_network = DynamicsNetwork(
                name="dynamics_net",
                input_tensor_spec=(feature_spec, encoded_action_spec),
                preprocessing_combiner=NestConcat(),
                fc_layer_params=hidden_size,
                output_tensor_spec=flat_feature_spec[0])

        if num_replicas > 1:
            self._dynamics_network = dynamics_network.make_parallel(
                num_replicas)
        else:
            self._dynamics_network = dynamics_network

    @property
    def num_replicas(self):
        """The number of dynamics network replicas given at construction."""
        return self._num_replicas

    def _encode_action(self, action):
        """Encode the action into the form consumed by the dynamics network.

        Args:
            action (Tensor): the raw action.
        Returns:
            Tensor: a float32 one-hot encoding with ``self._num_actions``
            classes for a discrete action spec; the unchanged ``action``
            otherwise.
        """
        if self._action_spec.is_discrete:
            # ``one_hot`` returns an integer tensor, while the encoded action
            # spec built in ``__init__`` is float32, so cast explicitly.
            # NOTE(review): ``one_hot`` requires values in
            # [0, num_actions); this presumably assumes
            # ``action_spec.minimum == 0`` -- confirm.
            return torch.nn.functional.one_hot(
                action, num_classes=self._num_actions).to(torch.float32)
        else:
            return action

    def update_state(self, time_step: TimeStep, state: DynamicsState):
        """Update the state based on TimeStep data.

        This function is mainly used during rollout together with a planner.
        The base implementation performs no update and returns ``state`` as
        is, so that the documented return contract holds; subclasses override
        it to refresh the feature.

        Args:
            time_step (TimeStep): input data for dynamics learning
            state (DynamicsState): state for DynamicsLearningAlgorithm
                (previous observation)
        Returns:
            state (DynamicsState): updated dynamics state
        """
        return state

    def get_state_specs(self):
        """Get the state specs of the current module.

        This function is mainly used for constructing the nested state specs
        by the upper-level module.

        Raises:
            NotImplementedError: always; to be implemented by subclasses.
        """
        raise NotImplementedError

    def predict_step(self, time_step: TimeStep, state: DynamicsState):
        """Predict the current observation using ``time_step.prev_action``
        and the feature of the previous observation from ``state``.

        Args:
            time_step (TimeStep): input data for dynamics learning
            state (DynamicsState): state for dynamics learning
        Returns:
            AlgStep:
                output:
                state (DynamicsState):
                info (DynamicsInfo):
        Raises:
            NotImplementedError: always; to be implemented by subclasses.
        """
        raise NotImplementedError

    def train_step(self, time_step: TimeStep, state: DynamicsState):
        """Perform one training step.

        Args:
            time_step (TimeStep): input data for dynamics learning
            state (DynamicsState): state for dynamics learning (previous
                observation)
        Returns:
            AlgStep:
                output:
                state (DynamicsState): state for training
                info (DynamicsInfo):
        Raises:
            NotImplementedError: always; to be implemented by subclasses.
        """
        raise NotImplementedError

    def calc_loss(self, info: DynamicsInfo):
        """Reduce the per-sample loss in ``info`` to a scalar loss.

        Args:
            info (DynamicsInfo): its ``loss`` field is a nest of tensors (a
                ``LossInfo`` in the ``train_step`` implementations of this
                file).
        Returns:
            LossInfo: with ``scalar_loss`` and ``extra`` both set to the mean
            of ``info.loss.loss``.
        """
        # Here we take mean over the loss to avoid undesired additional
        # masking from base algorithm's ``update_with_gradient``.
        scalar_loss = nest.map_structure(torch.mean, info.loss)
        # NOTE(review): ``extra`` is filled with the averaged ``loss`` rather
        # than ``scalar_loss.extra`` -- confirm this is intended.
        return LossInfo(scalar_loss=scalar_loss.loss, extra=scalar_loss.loss)
@alf.configurable
class DeterministicDynamicsAlgorithm(DynamicsLearningAlgorithm):
    """Deterministic Dynamics Learning Module

    This module tries to learn the dynamics of environment with a
    deterministic model.
    """

    def __init__(self,
                 action_spec,
                 feature_spec,
                 hidden_size=256,
                 num_replicas=1,
                 dynamics_network_ctor: Optional[
                     Callable[[Any, Any], DynamicsNetwork]] = None,
                 name="DeterministicDynamicsAlgorithm"):
        """Create a DeterministicDynamicsAlgorithm.

        Args:
            action_spec (nested TensorSpec): spec of the action.
            feature_spec (nested TensorSpec): spec of the feature. When
                ``num_replicas > 1``, ``feature_spec.shape[0]`` is used as
                the feature size of the ensemble state spec.
            hidden_size (int|tuple): size of hidden layer(s)
            num_replicas (int): number of network replicas to be used in the
                ensemble for dynamics learning
            dynamics_network_ctor: Used to construct a network for predicting
                the change of the next feature based on the previous feature
                and action. It should accept input with spec of the format
                [feature_spec, encoded_action_spec] and output a tensor of the
                shape feature_spec. For discrete action case, encoded_action
                is a one-hot representation of the action. For continuous
                action, encoded action is the original action. If None, the
                base class constructs a default network.
            name (str): name of this algorithm.
        """
        dynamics_network = None
        # Default used when no constructor is given. Without it the name
        # would be unbound below and the default configuration would crash.
        # NOTE(review): assumes the default network built by the base class
        # is stateless (empty state spec) -- confirm.
        dynamics_network_state_spec = ()
        if dynamics_network_ctor is not None:
            dynamics_network = dynamics_network_ctor(
                input_tensor_spec=(feature_spec, action_spec),
                output_tensor_spec=feature_spec)
            dynamics_network_state_spec = dynamics_network.state_spec

        if num_replicas > 1:
            ens_feature_spec = TensorSpec(
                (num_replicas, feature_spec.shape[0]), dtype=torch.float32)
        else:
            ens_feature_spec = feature_spec

        super().__init__(
            train_state_spec=DynamicsState(
                feature=ens_feature_spec, network=dynamics_network_state_spec),
            action_spec=action_spec,
            feature_spec=feature_spec,
            num_replicas=num_replicas,
            hidden_size=hidden_size,
            dynamics_network=dynamics_network,
            name=name)

    def _expand_to_replica(self, inputs, spec):
        """Expand the inputs of shape [B, ...] to [B, n, ...] if n > 1,
        where n is the number of replicas. When n = 1, the unexpanded
        inputs will be returned.

        Args:
            inputs (Tensor): the input tensor to be expanded
            spec (TensorSpec): the spec of the unexpanded inputs. It is used
                to determine whether the inputs is already an expanded one.
                If it is already expanded, inputs will be returned without
                any further processing.
        Returns:
            Tensor: the expanded inputs or the original inputs.
        """
        outer_rank = get_outer_rank(inputs, spec)
        if outer_rank == 1 and self._num_replicas > 1:
            # ``expand`` broadcasts the inserted replica dimension.
            return inputs.unsqueeze(1).expand(-1, self._num_replicas,
                                              *inputs.shape[1:])
        else:
            return inputs

    def predict_step(self, time_step: TimeStep, state: DynamicsState):
        """Predict the next observation given the current time_step.

        The next step is predicted using the ``prev_action`` from time_step
        and the ``feature`` from state.

        Args:
            time_step (TimeStep): time step structure. The ``prev_action``
                from time_step will be used for predicting feature of the
                next step. It should be a Tensor of the shape [B, ...], or
                [B, n, ...] when n > 1, where n denotes the number of
                dynamics network replicas. When the input tensor has the
                shape of [B, ...] and n > 1, it will be first expanded to
                [B, n, ...] to match the number of dynamics network replicas.
            state (DynamicsState): state for dynamics learning with the
                following fields:

                - feature (Tensor): features of the previous observation of
                  the shape [B, ...], or [B, n, ...] when n > 1. When
                  ``state.feature`` has the shape of [B, ...] and n > 1, it
                  will be first expanded to [B, n, ...] to match the number
                  of dynamics network replicas. It is used for predicting the
                  feature of the next step together with
                  ``time_step.prev_action``.
                - network: the input state of the dynamics network
        Returns:
            AlgStep:
                outputs (Tensor): predicted feature of the next step, of the
                    shape [B, ...], or [B, n, ...] when n > 1.
                state (DynamicsState): with the following fields

                    - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                      shape tensor representing the predicted feature of
                      the next step
                    - network: the updated state of the dynamics network
                info: empty tuple ()
        """
        action = self._encode_action(time_step.prev_action)
        obs = state.feature

        # perform preprocessing
        observations = self._expand_to_replica(obs, self._feature_spec)
        actions = self._expand_to_replica(action, self._action_spec)

        # The network predicts the change of the feature, not the feature
        # itself, so the prediction is the previous feature plus the delta.
        forward_deltas, network_state = self._dynamics_network(
            (observations, actions), state=state.network)
        forward_pred = observations + forward_deltas

        state = state._replace(feature=forward_pred, network=network_state)
        return AlgStep(output=forward_pred, state=state, info=())

    def update_state(self, time_step: TimeStep, state: DynamicsState):
        """Update the state based on TimeStep data.

        This function is mainly used during rollout together with a planner.
        This function is necessary as we need to update the feature in
        DynamicsState with those of the current observation, after each step
        of rollout.

        Args:
            time_step (TimeStep): input data for dynamics learning
            state (DynamicsState): state for DeterministicDynamicsAlgorithm
                (previous observation)
        Returns:
            state (DynamicsState): updated dynamics state
        """
        feature = time_step.observation
        if feature.shape == state.feature.shape:
            updated_state = state._replace(feature=feature)
        else:
            # feature [B, d], state.feature: [B, n, d]
            updated_state = state._replace(
                feature=self._expand_to_replica(feature, self._feature_spec))
        return updated_state

    def train_step(self, time_step: TimeStep, state: DynamicsState):
        """Compute the squared prediction error for one step.

        Args:
            time_step (TimeStep): time step structure. The ``prev_action``
                from time_step will be used for predicting feature of the
                next step. It should be a Tensor of the shape [B, ...] or
                [B, n, ...] when n > 1, where n denotes the number of
                dynamics network replicas. When the input tensor has the
                shape of [B, ...] and n > 1, it will be first expanded to
                [B, n, ...] to match the number of dynamics network replicas.
            state (DynamicsState): state for dynamics learning with the
                following fields:

                - feature (Tensor): features of the previous observation of
                  the shape [B, ...] or [B, n, ...] when n > 1. When
                  ``state.feature`` has the shape of [B, ...] and n > 1, it
                  will be first expanded to [B, n, ...] to match the number
                  of dynamics network replicas. It is used for predicting the
                  feature of the next step together with
                  ``time_step.prev_action``.
                - network: the input state of the dynamics network
        Returns:
            AlgStep:
                outputs: empty tuple ()
                state (DynamicsState): with the following fields

                    - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                      shape tensor holding the (replica-expanded) feature of
                      the current observation
                    - network: the updated state of the dynamics network
                info (DynamicsInfo): with the following fields being updated:

                    - loss (LossInfo):
        """
        feature = time_step.observation
        feature = self._expand_to_replica(feature, self._feature_spec)
        dynamics_step = self.predict_step(time_step, state)
        forward_pred = dynamics_step.output
        forward_loss = (feature - forward_pred)**2

        if forward_loss.ndim > 2:
            # [B, n, ...] -> [B, ...]
            forward_loss = forward_loss.sum(1)
        if forward_loss.ndim > 1:
            forward_loss = 0.5 * forward_loss.mean(
                list(range(1, forward_loss.ndim)))

        # we mask out FIRST as its state is invalid
        valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)
        forward_loss = forward_loss * valid_masks

        info = DynamicsInfo(
            loss=LossInfo(
                loss=forward_loss, extra=dict(forward_loss=forward_loss)))
        # Start from the state returned by ``predict_step`` so that the
        # updated network state is carried forward, while the feature is
        # reset to the actual observation instead of the prediction.
        state = dynamics_step.state._replace(feature=feature)

        return AlgStep(output=(), state=state, info=info)
@alf.configurable
class StochasticDynamicsAlgorithm(DeterministicDynamicsAlgorithm):
    """Stochastic Dynamics Learning Module

    This module learns the dynamics of environment with a stochastic model.
    """

    def __init__(self,
                 action_spec,
                 feature_spec,
                 hidden_size=256,
                 num_replicas=1,
                 dynamics_network_ctor: Optional[
                     Callable[[Any, Any], DynamicsNetwork]] = None,
                 name="StochasticDynamicsAlgorithm"):
        """Create a StochasticDynamicsAlgorithm.

        Args:
            action_spec (nested TensorSpec): spec of the action.
            feature_spec (nested TensorSpec): spec of the feature.
            hidden_size (int|tuple): size of hidden layer(s)
            num_replicas (int): number of network replicas to be used in the
                ensemble for dynamics learning
            dynamics_network_ctor: used to construct network for predicting
                next feature based on the previous feature and action. It
                should accept input with spec [feature_spec,
                encoded_action_spec] and output a tensor of shape
                feature_spec. For discrete action, encoded_action is a
                one-hot representation of the action. For continuous action,
                encoded action is the original action.
            name (str): name of this algorithm.
        """
        # ``name`` must be forwarded explicitly; otherwise the parent's
        # default name would silently be used instead.
        super().__init__(
            action_spec=action_spec,
            feature_spec=feature_spec,
            hidden_size=hidden_size,
            num_replicas=num_replicas,
            dynamics_network_ctor=dynamics_network_ctor,
            name=name)
        assert self._dynamics_network._prob, "should use probabilistic network"

    def predict_step(self, time_step: TimeStep, state: DynamicsState):
        """Predict the next observation given the current time_step.

        The next step is predicted using the ``prev_action`` from time_step
        and the ``feature`` from state. The predicted change of the feature
        is sampled from the distribution output by the dynamics network.

        Args:
            time_step (TimeStep): time step structure. The ``prev_action``
                from time_step will be used for predicting feature of the
                next step. It should be a Tensor of the shape [B, ...], or
                [B, n, ...] when n > 1, where n denotes the number of
                dynamics network replicas. When the input tensor has the
                shape of [B, ...] and n > 1, it will be first expanded to
                [B, n, ...] to match the number of dynamics network replicas.
            state (DynamicsState): state for dynamics learning with the
                following fields:

                - feature (Tensor): features of the previous observation of
                  the shape [B, ...], or [B, n, ...] when n > 1. When
                  ``state.feature`` has the shape of [B, ...] and n > 1, it
                  will be first expanded to [B, n, ...] to match the number
                  of dynamics network replicas. It is used for predicting the
                  feature of the next step together with
                  ``time_step.prev_action``.
                - network: the input state of the dynamics network
        Returns:
            AlgStep:
                outputs (Tensor): predicted feature of the next step, of the
                    shape [B, ...], or [B, n, ...] when n > 1.
                state (DynamicsState): with the following fields

                    - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                      shape tensor representing the predicted feature of
                      the next step
                    - network: the updated state of the dynamics network
                info (DynamicsInfo): with the following fields being updated:

                    - dist (td.Distribution): the predictive distribution
                      which can be used for further calculation or
                      summarization.
        """
        action = self._encode_action(time_step.prev_action)
        obs = state.feature

        # perform preprocessing
        observations = self._expand_to_replica(obs, self._feature_spec)
        actions = self._expand_to_replica(action, self._action_spec)

        dist, network_states = self._dynamics_network((observations, actions),
                                                      state=state.network)
        # The distribution is over the change of the feature; the prediction
        # is the previous feature plus a sampled delta.
        forward_deltas = dist.sample()
        forward_preds = observations + forward_deltas

        state = state._replace(feature=forward_preds, network=network_states)
        return AlgStep(
            output=forward_preds, state=state, info=DynamicsInfo(dist=dist))

    def train_step(self, time_step: TimeStep, state: DynamicsState):
        """Compute the negative log-likelihood of the observed feature change.

        Args:
            time_step (TimeStep): time step structure. The ``prev_action``
                from time_step will be used for predicting feature of the
                next step. It should be a Tensor of the shape [B, ...] or
                [B, n, ...] when n > 1, where n denotes the number of
                dynamics network replicas. When the input tensor has the
                shape of [B, ...] and n > 1, it will be first expanded to
                [B, n, ...] to match the number of dynamics network replicas.
            state (DynamicsState): state for dynamics learning with the
                following fields:

                - feature (Tensor): features of the previous observation of
                  the shape [B, ...] or [B, n, ...] when n > 1. When
                  ``state.feature`` has the shape of [B, ...] and n > 1, it
                  will be first expanded to [B, n, ...] to match the number
                  of dynamics network replicas. It is used for predicting the
                  feature of the next step together with
                  ``time_step.prev_action``.
                - network: the input state of the dynamics network
        Returns:
            AlgStep:
                outputs: empty tuple ()
                state (DynamicsState): with the following fields

                    - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                      shape tensor holding the (replica-expanded) feature of
                      the current observation
                    - network: the updated state of the dynamics network
                info (DynamicsInfo): with the following fields being updated:

                    - loss (LossInfo):
                    - dist (td.Distribution): the predictive distribution
                      which can be used for further calculation or
                      summarization.
        """
        feature = time_step.observation
        feature = self._expand_to_replica(feature, self._feature_spec)
        # Expand the previous feature the same way ``predict_step`` does so
        # that the target delta has the same shape as the distribution's
        # samples; this is a no-op when it is already expanded or n == 1.
        prev_feature = self._expand_to_replica(state.feature,
                                               self._feature_spec)
        dynamics_step = self.predict_step(time_step, state)

        dist = dynamics_step.info.dist
        # The distribution models the feature change, hence the target is the
        # difference between the current and the previous feature.
        forward_loss = -dist.log_prob(feature - prev_feature)

        if forward_loss.ndim > 2:
            # [B, n, ...] -> [B, ...]
            forward_loss = forward_loss.sum(1)
        if forward_loss.ndim > 1:
            forward_loss = forward_loss.mean(list(range(1, forward_loss.ndim)))

        # we mask out FIRST as its state is invalid
        valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)
        forward_loss = forward_loss * valid_masks

        info = DynamicsInfo(
            loss=LossInfo(
                loss=forward_loss, extra=dict(forward_loss=forward_loss)),
            dist=dist)
        # Start from the state returned by ``predict_step`` so that the
        # updated network state is carried forward, while the feature is
        # reset to the actual observation instead of the sampled prediction.
        state = dynamics_step.state._replace(feature=feature)

        return AlgStep(output=(), state=state, info=info)