Source code for alf.networks.critic_networks

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CriticNetworks"""

import functools
import math

import torch

import alf
import alf.utils.math_ops as math_ops
import alf.nest as nest
from alf.initializers import variance_scaling_init
from alf.tensor_specs import TensorSpec

from .encoding_networks import EncodingNetwork, LSTMEncodingNetwork, ParallelEncodingNetwork


def _check_action_specs_for_critic_networks(
        action_spec, action_input_processors, action_preprocessing_combiner):

    if len(nest.flatten(action_spec)) > 1:
        assert action_preprocessing_combiner is not None, (
            "An action combiner is needed when there are multiple action specs:"
            " {}".format(action_spec))

    def _check_individual(spec, proc):
        if spec.is_discrete:
            assert proc is not None, (
                'CriticNetwork only supports continuous actions. One of given '
                + 'action specs {} is discrete. Use QNetwork instead. '.format(
                    spec) +
                'Alternatively, specify `action_input_processors` to transform '
                + 'discrete actions to continuous action embeddings first.')

    if action_input_processors is None:
        action_input_processors = nest.map_structure(lambda _: None,
                                                     action_spec)

    nest.map_structure(_check_individual, action_spec, action_input_processors)


[docs]@alf.configurable class CriticNetwork(EncodingNetwork): """Creates an instance of ``CriticNetwork`` for estimating action-value of continuous or discrete actions. The action-value is defined as the expected return starting from the given input observation and taking the given action. This module takes observation as input and action as input and outputs an action-value tensor with the shape of ``[batch_size]``. The network take a tuple of (observation, action) as input to computes the action-value given an observation. """ def __init__(self, input_tensor_spec, output_tensor_spec=TensorSpec(()), observation_input_processors=None, observation_preprocessing_combiner=None, observation_conv_layer_params=None, observation_fc_layer_params=None, action_input_processors=None, action_preprocessing_combiner=None, action_fc_layer_params=None, observation_action_combiner=None, joint_fc_layer_params=None, activation=torch.relu_, kernel_initializer=None, use_fc_bn=False, use_naive_parallel_network=False, name="CriticNetwork"): """ Args: input_tensor_spec: A tuple of ``TensorSpec``s ``(observation_spec, action_spec)`` representing the inputs. output_tensor_spec (TensorSpec): spec for the output observation_input_preprocessors (nested Network|nn.Module|None): a nest of input preprocessors, each of which will be applied to the corresponding observation input. observation_preprocessing_combiner (NestCombiner): preprocessing called on complex observation inputs. observation_conv_layer_params (tuple[tuple]): a tuple of tuples where each tuple takes a format ``(filters, kernel_size, strides, padding)``, where ``padding`` is optional. observation_fc_layer_params (tuple[int]): a tuple of integers representing hidden FC layer sizes for observations. action_input_processors (nested Network|nn.Module|None): a nest of input preprocessors, each of which will be applied to the corresponding action input. action_preprocessing_combiner (NestCombiner): preprocessing called to combine complex action inputs. action_fc_layer_params (tuple[int]): a tuple of integers representing hidden FC layer sizes for actions. observation_action_combiner (NestCombiner): combiner class for fusing the observation and action. If None, ``NestConcat`` will be used. joint_fc_layer_params (tuple[int]): a tuple of integers representing hidden FC layer sizes FC layers after merging observations and actions. activation (nn.functional): activation used for hidden layers. The last layer will not be activated. kernel_initializer (Callable): initializer for all the layers but the last layer. If none is provided a variance_scaling_initializer with uniform distribution will be used. use_fc_bn (bool): whether use Batch Normalization for the internal FC layers (i.e. FC layers beside the last one). use_naive_parallel_network (bool): if True, will use ``NaiveParallelNetwork`` when ``make_parallel`` is called. This might be useful in cases when the ``NaiveParallelNetwork`` has an advantange in terms of speed over ``ParallelNetwork``. You have to test to see which way is faster for your particular situation. name (str): """ if kernel_initializer is None: kernel_initializer = functools.partial( variance_scaling_init, gain=math.sqrt(1.0 / 3), mode='fan_in', distribution='uniform') observation_spec, action_spec = input_tensor_spec obs_encoder = EncodingNetwork( observation_spec, input_preprocessors=observation_input_processors, preprocessing_combiner=observation_preprocessing_combiner, conv_layer_params=observation_conv_layer_params, fc_layer_params=observation_fc_layer_params, activation=activation, kernel_initializer=kernel_initializer, use_fc_bn=use_fc_bn, name=name + ".obs_encoder") _check_action_specs_for_critic_networks(action_spec, action_input_processors, action_preprocessing_combiner) action_encoder = EncodingNetwork( action_spec, input_preprocessors=action_input_processors, preprocessing_combiner=action_preprocessing_combiner, fc_layer_params=action_fc_layer_params, activation=activation, kernel_initializer=kernel_initializer, use_fc_bn=use_fc_bn, name=name + ".action_encoder") last_kernel_initializer = functools.partial( torch.nn.init.uniform_, a=-0.003, b=0.003) if observation_action_combiner is None: observation_action_combiner = alf.layers.NestConcat(dim=-1) super().__init__( input_tensor_spec=input_tensor_spec, output_tensor_spec=output_tensor_spec, input_preprocessors=(obs_encoder, action_encoder), preprocessing_combiner=observation_action_combiner, fc_layer_params=joint_fc_layer_params, activation=activation, kernel_initializer=kernel_initializer, last_layer_size=output_tensor_spec.numel, last_activation=math_ops.identity, use_fc_bn=use_fc_bn, last_kernel_initializer=last_kernel_initializer, name=name) self._use_naive_parallel_network = use_naive_parallel_network
[docs] def make_parallel(self, n): """Create a parallel critic network using ``n`` replicas of ``self``. The initialized network parameters will be different. If ``use_naive_parallel_network`` is True, use ``NaiveParallelNetwork`` to create the parallel network. """ if self._use_naive_parallel_network: return alf.networks.NaiveParallelNetwork(self, n) else: return super().make_parallel(n, True)
[docs]@alf.configurable class CriticRNNNetwork(LSTMEncodingNetwork): """Creates an instance of ``CriticRNNNetwork`` for estimating action-value of continuous or discrete actions. The action-value is defined as the expected return starting from the given inputs (observation and state) and taking the given action. It takes observation and state as input and outputs an action-value tensor with the shape of [batch_size]. """ def __init__(self, input_tensor_spec, output_tensor_spec=TensorSpec(()), observation_input_processors=None, observation_preprocessing_combiner=None, observation_conv_layer_params=None, observation_fc_layer_params=None, action_input_processors=None, action_preprocessing_combiner=None, action_fc_layer_params=None, joint_fc_layer_params=None, lstm_hidden_size=100, critic_fc_layer_params=None, activation=torch.relu_, kernel_initializer=None, name="CriticRNNNetwork"): """ Args: input_tensor_spec: A tuple of ``TensorSpec``s ``(observation_spec, action_spec)`` representing the inputs. ourput_tensor_spec (TensorSpec): spec for the output observation_input_preprocessors (nested Network|nn.Module|None): a nest of input preprocessors, each of which will be applied to the corresponding observation input. observation_preprocessing_combiner (NestCombiner): preprocessing called on complex observation inputs. observation_conv_layer_params (tuple[tuple]): a tuple of tuples where each tuple takes a format ``(filters, kernel_size, strides, padding)``, where ``padding`` is optional. observation_fc_layer_params (tuple[int]): a tuple of integers representing hidden FC layer sizes for observations. action_input_processors (nested Network|nn.Module|None): a nest of input preprocessors, each of which will be applied to the corresponding action input.a action_preprocessing_combiner (NestCombiner): preprocessing called to combine complex action inputs. action_fc_layer_params (tuple[int]): a tuple of integers representing hidden FC layer sizes for actions. joint_fc_layer_params (tuple[int]): a tuple of integers representing hidden FC layer sizes FC layers after merging observations and actions. lstm_hidden_size (int or tuple[int]): the hidden size(s) of the LSTM cell(s). Each size corresponds to a cell. If there are multiple sizes, then lstm cells are stacked. critic_fc_layer_params (tuple[int]): a tuple of integers representing hidden FC layers that are applied after the lstm cell's output. activation (nn.functional): activation used for hidden layers. The last layer will not be activated. kernel_initializer (Callable): initializer for all the layers but the last layer. If none is provided a ``variance_scaling_initializer`` with uniform distribution will be used. name (str): """ if kernel_initializer is None: kernel_initializer = functools.partial( variance_scaling_init, gain=math.sqrt(1.0 / 3), mode='fan_in', distribution='uniform') observation_spec, action_spec = input_tensor_spec obs_encoder = EncodingNetwork( observation_spec, input_preprocessors=observation_input_processors, preprocessing_combiner=observation_preprocessing_combiner, conv_layer_params=observation_conv_layer_params, fc_layer_params=observation_fc_layer_params, activation=activation, kernel_initializer=kernel_initializer) _check_action_specs_for_critic_networks(action_spec, action_input_processors, action_preprocessing_combiner) action_encoder = EncodingNetwork( action_spec, input_preprocessors=action_input_processors, preprocessing_combiner=action_preprocessing_combiner, fc_layer_params=action_fc_layer_params, activation=activation, kernel_initializer=kernel_initializer) last_kernel_initializer = functools.partial( torch.nn.init.uniform_, a=-0.003, b=0.003) super().__init__( input_tensor_spec=input_tensor_spec, output_tensor_spec=output_tensor_spec, input_preprocessors=(obs_encoder, action_encoder), preprocessing_combiner=alf.layers.NestConcat(dim=-1), pre_fc_layer_params=joint_fc_layer_params, hidden_size=lstm_hidden_size, post_fc_layer_params=critic_fc_layer_params, activation=activation, kernel_initializer=kernel_initializer, last_layer_size=output_tensor_spec.numel, last_activation=math_ops.identity, last_kernel_initializer=last_kernel_initializer)
[docs] def make_parallel(self, n): """Create a parallel critic RNN network using ``n`` replicas of ``self``. The initialized network parameters will be different. If ``use_naive_parallel_network`` is True, use ``NaiveParallelNetwork`` to create the parallel network. """ return super().make_parallel(n, True)