# Source code for alf.networks.mdq_critic_networks

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MdqCriticNetworks"""

import functools
import math
import numpy as np

import torch
import torch.nn.functional as f
import torch.nn as nn

import alf
import alf.layers as layers
import alf.nest as nest
from alf.networks import Network, EncodingNetwork, ParallelEncodingNetwork
from alf.initializers import variance_scaling_init
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import common, spec_utils, tensor_utils
import alf.utils.math_ops as math_ops
from alf.utils.action_quantizer import ActionQuantizer


@alf.configurable
class MdqCriticNetwork(Network):
    """Create an instance of MdqCriticNetwork for estimating action-value of
    continuous actions and action sampling used in the MDQ algorithm."""

    def __init__(self,
                 input_tensor_spec,
                 action_qt: ActionQuantizer = None,
                 num_critic_replicas=2,
                 obs_encoding_layer_params=None,
                 pre_encoding_layer_params=None,
                 mid_encoding_layer_params=None,
                 post_encoding_layer_params=None,
                 free_form_fc_layer_params=None,
                 activation=torch.relu_,
                 kernel_initializer=None,
                 debug_summaries=False,
                 name="MdqCriticNetwork"):
        """Creates an instance of ``MdqCriticNetwork`` for estimating
        action-value of continuous actions and action sampling.

        Currently there are two branches of networks:

        - free-form branch: a plain MLP for Q-learning
        - adv-form branch: an advantage form of network for action
          generation. It is trained by a target from the free-form net.

        The adv-form branch has the following structure for flexibility::

            obs -> [obs_encoding_net] -> encoded_obs
            encoded_obs, action ->
                [pre_encoding_nets] ->
                [mid_shared_encoding_nets] ->
                [post_encoding_nets] -> outputs

        The ``pre_encoding_nets`` and ``post_encoding_nets`` do not share
        parameters across action dimensions. The ``mid_shared_encoding_nets``
        share parameters across action dimensions. If the
        ``*_encoding_layer_params`` for a sub-net is None, that sub-net is
        effectively neglected.

        To enable parallel computation across action dimensions for value
        computation, the nets without parameter sharing come in two versions:
        a parallel one and an individual one. For example,
        ``post_encoding_nets`` has a counterpart ``post_encoding_parallel_net``.
        It is the equivalent form of ``post_encoding_nets`` but supports
        parallel forwarding. Parameters are copied from the parallel version
        to the individual version at construction time and whenever
        ``sync_net()`` is called. The partial actions (a[0:i]) are
        zero-padded for both parallel and individual networks to enable
        parallel computation.

        For conciseness, the following notations are used when convenient:

        - B: batch size
        - d: dimensionality of feature
        - n: number of network replicas
        - action_dim: the dimensionality of actions
        - action_bin: number of discrete bins for each action dim

        Args:
            input_tensor_spec: A tuple of TensorSpecs
                (observation_spec, action_spec) representing the inputs.
            action_qt (ActionQuantizer): action quantization module. If None,
                ``ActionQuantizer(action_spec, "uniform", 15)`` is used.
            num_critic_replicas (int): number of critic networks.
            obs_encoding_layer_params (tuple[int]): a tuple of integers
                representing hidden FC layer sizes for encoding observations.
            pre_encoding_layer_params (tuple[int]): a tuple of integers
                representing hidden FC layer sizes for encoding the
                concatenated [encoded_observation, actions]. Parameters are
                not shared across action dimensions.
            mid_encoding_layer_params (tuple[int]): a tuple of integers
                representing hidden FC layers for further encoding the outputs
                from pre_encoding_net. The parameters are shared across action
                dimensions.
            post_encoding_layer_params (tuple[int]): a tuple of integers
                representing hidden FC layers for further encoding the outputs
                from mid_encoding_net. The parameters are not shared across
                action dimensions.
            free_form_fc_layer_params (tuple[int]): a tuple of integers
                representing hidden FC layers for Q-learning. It is called the
                free form to distinguish it from the structured mdq-form of
                network. Required.
            activation (nn.functional): activation used for hidden layers. The
                last layer will not be activated.
            kernel_initializer (Callable): initializer for all the layers but
                the last layer. If None is provided, a
                variance_scaling_initializer with uniform distribution will be
                used.
            debug_summaries (bool): stored as ``self._debug_summaries``; not
                otherwise used by this class.
            name (str): name of the network.

        Raises:
            ValueError: if the action spec contains more than one action, or
                if ``free_form_fc_layer_params`` is None.
        """
        super().__init__(input_tensor_spec, name=name)

        observation_spec, action_spec = input_tensor_spec

        flat_action_spec = nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise ValueError(
                'Only a single action is supported by this network')

        self._single_action_spec = flat_action_spec[0]

        # Validate explicitly (an ``assert`` would be stripped under -O and
        # silently build a free-form net without hidden layers).
        if free_form_fc_layer_params is None:
            raise ValueError('free_form_fc_layer_params must be provided')

        if action_qt is None:
            action_qt = ActionQuantizer(action_spec, "uniform", 15)
        self._action_qt = action_qt
        self._action_bins = self._action_qt._action_bins

        # the logpi of the uniform prior used for KL computation
        self._log_pi_uniform_prior = -np.log(self._action_bins)

        # control vector dim
        self._action_dim = self._single_action_spec.shape[0]
        self._num_critic_replicas = num_critic_replicas

        self._obs_encoding_net = ParallelEncodingNetwork(
            observation_spec,
            self._num_critic_replicas,
            fc_layer_params=obs_encoding_layer_params,
            activation=activation,
            kernel_initializer=kernel_initializer)

        last_activation = math_ops.identity
        last_kernel_initializer = functools.partial(
            torch.nn.init.uniform_, a=-0.003, b=0.003)

        # The pre-encoding input is [encoded_obs, zero-padded full action],
        # so the action part always has ``action_dim`` entries.
        in_size = self._action_dim

        enc_obs_size = self._obs_encoding_net.output_spec.shape[-1]

        # NOTE(review): the individual nets are kept in plain Python lists, so
        # they are not registered as sub-modules (they are not included in
        # parameters()/state_dict() and are not moved by .to()). Their
        # weights are copied from the parallel nets by
        # ``_parallel_to_individual_network_sync``.
        self._pre_encoding_nets = []
        for i in range(self._action_dim):
            # output_spec.shape: [n, d]
            self._pre_encoding_nets.append(
                ParallelEncodingNetwork(
                    TensorSpec((enc_obs_size + in_size, )),
                    self._num_critic_replicas,
                    fc_layer_params=pre_encoding_layer_params,
                    activation=activation,
                    kernel_initializer=kernel_initializer))

        # parallel along both critic and action dims without sharing
        # parameters for each action dimension.
        # input: [B, action_dim*n, d]: need to stack over dim1
        # output: [B, action_dim*n, d']: need to unstack over dim1 for
        # splitting over networks
        self._pre_encoding_parallel_net = ParallelEncodingNetwork(
            TensorSpec((enc_obs_size + in_size, )),
            self._num_critic_replicas * self._action_dim,
            fc_layer_params=pre_encoding_layer_params,
            activation=activation,
            kernel_initializer=kernel_initializer)

        # parallel along both critic and action dims with sharing parameters
        # for each action dimension.
        # input: [action_dim*B, n, d]: need to stack over dim0
        # output: [action_dim*B, n, d']: need to unstack over dim0 for
        # splitting over networks
        self._mid_shared_encoding_nets = ParallelEncodingNetwork(
            TensorSpec(
                (self._pre_encoding_parallel_net.output_spec.shape[-1], )),
            self._num_critic_replicas,
            fc_layer_params=mid_encoding_layer_params,
            activation=activation,
            kernel_initializer=kernel_initializer)
        out_size = self._mid_shared_encoding_nets.output_spec.shape[-1]

        post_enc_out_size = self._action_qt.action_bins

        self._post_encoding_nets = []
        for i in range(self._action_dim):
            self._post_encoding_nets.append(
                ParallelEncodingNetwork(
                    TensorSpec((out_size, )),
                    self._num_critic_replicas,
                    fc_layer_params=post_encoding_layer_params,
                    activation=activation,
                    kernel_initializer=kernel_initializer,
                    last_layer_size=post_enc_out_size,
                    last_activation=last_activation,
                    last_kernel_initializer=last_kernel_initializer))

        # parallel along both critic and action dims without sharing
        # parameters for each action dimension.
        # input: [B, action_dim*n, d]: need to stack over dim1
        # output: [B, action_dim*n, d']: need to unstack over dim1 for
        # splitting over networks
        self._post_encoding_parallel_net = ParallelEncodingNetwork(
            TensorSpec((out_size, )),
            self._num_critic_replicas * self._action_dim,
            fc_layer_params=post_encoding_layer_params,
            activation=activation,
            kernel_initializer=kernel_initializer,
            last_layer_size=post_enc_out_size,
            last_activation=last_activation,
            last_kernel_initializer=last_kernel_initializer)

        # The free-form net takes a single vector of size
        # obs_dim + action_dim.
        self._free_form_q_net = ParallelEncodingNetwork(
            TensorSpec((observation_spec.shape[-1] + self._action_dim, )),
            self._num_critic_replicas,
            fc_layer_params=free_form_fc_layer_params,
            activation=activation,
            kernel_initializer=kernel_initializer,
            last_layer_size=1,
            last_activation=math_ops.identity,
            last_kernel_initializer=last_kernel_initializer)

        MdqCriticNetwork._parallel_to_individual_network_sync(
            self._pre_encoding_parallel_net,
            self._pre_encoding_nets,
            step=self._num_critic_replicas)

        MdqCriticNetwork._parallel_to_individual_network_sync(
            self._post_encoding_parallel_net,
            self._post_encoding_nets,
            step=self._num_critic_replicas)

        self._output_spec = TensorSpec(())
        self._debug_summaries = debug_summaries

    @torch.no_grad()
    def get_action(self, inputs, alpha, greedy):
        """Sample action from the distribution induced by the mdq-net.

        Actions are sampled one dimension at a time. Dimension ``i`` is
        conditioned on the already-sampled dimensions ``0..i-1``; dimensions
        ``>= i`` are zero-padded.

        Args:
            inputs (torch.Tensor): a batch of observations (the first element
                of ``input_tensor_spec``) whose first dim is the batch size B.
                Note this is NOT the (observation, action) tuple.
            alpha: the temperature used for the advantage computation.
            greedy (bool): If True, do greedy sampling by taking the mode of
                the distribution. If False, do direct sampling from the
                distribution.

        Returns:
            actions (torch.Tensor): a tensor of the shape [B, n, action_dim].
            log_pi_per_dim (torch.Tensor): a tensor of the shape
                [B, n, action_dim] representing the log_pi for each dimension
                of the sampled multi-dimensional action.
        """
        observations = inputs

        # [B, n, action_dim]
        t_shape = (observations.shape[0], self._num_critic_replicas,
                   self._action_dim)

        actions = torch.zeros(t_shape)
        log_pi_per_dim = torch.zeros(t_shape)

        # [B, n, d']
        encoded_obs, _ = self._obs_encoding_net(observations)

        action_padded = torch.zeros(t_shape)

        for i in range(self._action_dim):
            # expose only the dims sampled so far; the rest stay zero
            action_padded[..., 0:i] = actions[..., 0:i]
            joint = torch.cat((encoded_obs, action_padded.detach()), -1)

            # [B, n, action_bin]
            action_values_i, _ = self._net_forward_individual(joint, alpha, i)

            # alpha * log_pi; dividing by alpha below yields log_pi logits
            trans_action_values_i = self._transform_action_value(
                action_values_i, alpha)

            sampled_indices, sampled_log_pi = self._sample_action_from_value(
                trans_action_values_i / alpha, alpha, greedy)

            # convert index to action
            actions[..., i] = self._action_qt.ind_to_action(sampled_indices)
            log_pi_per_dim[..., i] = sampled_log_pi

        return actions, log_pi_per_dim

    def forward(self, inputs, alpha, state=(), free_form=False):
        """Computes action-value given an observation.

        Args:
            inputs: if ``free_form`` is False, a tuple
                ``(observations, actions)`` where actions are either
                [B, action_dim] or [B, n, action_dim]. If ``free_form`` is
                True, ``inputs`` is passed directly to the free-form net,
                which is built for a single vector input of size
                obs_dim + action_dim.
            alpha: the temperature used for the advantage computation.
            state: empty for API consistency.
            free_form (bool): use the free-form branch for computation if
                True; default value is False.

        Returns:
            Q_values (torch.Tensor):
                - if free_form is True, its shape is [B, n]
                - if free_form is False, its shape is [B, n, action_dim]; the
                  entry for dim i accumulates the values of dims 0..i, each
                  shifted by ``-alpha * log(1 / action_bin)`` (KL to the
                  uniform prior).
            state: empty
        """
        if free_form:
            Q_values, state = self._free_form_q_net(inputs)
            # [B, n, 1] -> [B, n]
            Q_values = Q_values.squeeze(2)
            return Q_values, state

        observations, actions = inputs
        batch_size = observations.shape[0]

        # observations: [B, d]
        # encoded_obs: [B, n, d']
        # Note that when obs_encoding_net is a dummy network
        # (i.e., layer_params is None), d' is the same as d.
        encoded_obs, _ = self._obs_encoding_net(observations)

        if actions.ndim == 2:
            # [B, action_dim] -> [B, n, action_dim]
            actions = tensor_utils.tensor_extend_new_dim(
                actions, dim=1, n=self._num_critic_replicas)

        # [B, n, action_dim]
        t_shape = (batch_size, self._num_critic_replicas, self._action_dim)
        action_padded = torch.zeros(t_shape)

        # prepare parallel-forwarding inputs
        inputs_per_dim = []
        for i in range(self._action_dim):
            action_padded[..., 0:i] = actions[..., 0:i]
            # concat (obs, action) for each action dimension; torch.cat
            # copies, so later in-place updates of action_padded are safe
            inputs_per_dim.append(
                torch.cat((encoded_obs, action_padded.detach()), dim=-1))

        # concat per dim input batch to a joint batch along dim1
        # [B, action_dim*n, d]
        joint = torch.cat(inputs_per_dim, dim=1)

        # forward the joint batch
        # action_values_per_dim: [action_dim, B, n, action_bin]
        action_values_per_dim, _ = self._net_forward_parallel(
            joint, alpha, batch_size=batch_size)

        trans_action_values_per_dim = self._transform_action_value(
            action_values_per_dim, alpha)

        # KL-divergence term w.r.t. the uniform prior (loop invariant)
        kl_offset = alpha * self._log_pi_uniform_prior

        # each entry: [B, n, 1]
        q_per_dim = []
        for i in range(self._action_dim):
            action_ind = self._action_qt.action_to_ind(actions[..., i]).long()
            if i == 0:
                # the first dim uses the raw (untransformed) action value
                q_i = self._batched_index_select(action_values_per_dim[0], -1,
                                                 action_ind)
            else:
                # later dims add alpha * log_pi of the selected bin
                q_i = q_per_dim[-1] + self._batched_index_select(
                    trans_action_values_per_dim[i], -1, action_ind)
            q_per_dim.append(q_i - kl_offset)

        # action_dim x [B, n, 1] -> [B, n, action_dim]
        Q_values = torch.cat(q_per_dim, dim=-1)

        return Q_values, state

    def _net_forward_individual(self, inputs, alpha, i, state=()):
        """Individual forwarding for a specified action dim for value
        computation.

        Args:
            inputs (torch.Tensor): a tensor of the shape [B, n, d].
            alpha: the temperature used for the advantage computation
                (unused here; kept for interface symmetry).
            i (int): the specified action dim to perform forwarding.

        Returns:
            action_values_i (torch.Tensor): a tensor of the shape
                [B, n, action_bin].
            state: empty
        """
        inputs, _ = self._pre_encoding_nets[i](inputs)
        action_values_i, state = self._mid_shared_encoding_nets(inputs)
        action_values_i, state = self._post_encoding_nets[i](action_values_i)
        return action_values_i, state

    def _net_forward_parallel(self, inputs, alpha, batch_size, state=()):
        """Parallel forwarding across action dims for value computation.

        Args:
            inputs (torch.Tensor): a tensor of the shape [B, action_dim*n, d]
                with the data for each action dimension concatenated along
                dim1 for parallel computation.
            alpha: the temperature used for the advantage computation
                (unused here; kept for interface symmetry).
            batch_size: the size of the original batch without stacking all
                action dimensions.

        Returns:
            action_values (torch.Tensor): a tensor of the shape
                [action_dim, B, n, action_bin].
            state: empty
        """
        # [B, action_dim*n, d]
        action_values_pre, _ = self._pre_encoding_parallel_net(inputs)

        # [B, action_dim*n, d] -> [action_dim*B, n, d]
        action_values_pre = self._reshape_from_ensemble_to_batch(
            action_values_pre, batch_size)

        action_values_mid, state = self._mid_shared_encoding_nets(
            action_values_pre)

        # [action_dim*B, n, d] -> [B, action_dim*n, d]
        action_values_mid = self._reshape_from_batch_to_ensemble(
            action_values_mid, batch_size)

        action_values_final, _ = self._post_encoding_parallel_net(
            action_values_mid)

        # [B, action_dim*n, d] -> [B, action_dim, n, d] -> [action_dim, B, n, d]
        action_values = action_values_final.view(
            batch_size, self._action_dim, self._num_critic_replicas,
            -1).transpose(0, 1)

        return action_values, state

    def _reshape_from_batch_to_ensemble(self, joint_batch, batch_size):
        """Reshape the joint batch of the shape [action_dim*B, n, d] to
        [B, action_dim*n, d], i.e., separate and move the action dimension
        axis from the batch dimension (dim0) to the ensemble dimension (dim1).

        Args:
            joint_batch (torch.Tensor): a tensor of the shape
                [action_dim*B, n, d] with the data for each action dimension
                concatenated along the batch dimension (dim0).
            batch_size: the size of the original batch without stacking all
                action dimensions.

        Returns:
            reshaped_batch (torch.Tensor): a tensor of the shape
                [B, action_dim*n, d].
        """
        assert len(joint_batch.shape) == 3 and joint_batch.shape[:-1] == \
            (self._action_dim * batch_size, self._num_critic_replicas)

        d = joint_batch.shape[-1]
        # [action_dim*B, n, d] -> [action_dim, B, n, d]
        reshaped_batch = joint_batch.view(self._action_dim, batch_size,
                                          self._num_critic_replicas, d)

        # [action_dim, B, n, d] -> [B, action_dim, n, d] -> [B, action_dim*n, d]
        reshaped_batch = reshaped_batch.transpose(0, 1).reshape(
            batch_size, -1, d)
        return reshaped_batch

    def _reshape_from_ensemble_to_batch(self, joint_batch, batch_size):
        """Reshape the joint batch of the shape [B, action_dim*n, d] to
        [action_dim*B, n, d], i.e., separate and move the action dimension
        axis from the ensemble dimension (dim1) to the batch dimension (dim0).

        Args:
            joint_batch (torch.Tensor): a tensor of the shape
                [B, action_dim*n, d] with the data for each action dimension
                concatenated along the ensemble dimension (dim1).
            batch_size: the size of the original batch without stacking all
                action dimensions.

        Returns:
            reshaped_batch (torch.Tensor): a tensor of the shape
                [action_dim*B, n, d].
        """
        assert len(joint_batch.shape) == 3 and joint_batch.shape[:-1] == \
            (batch_size, self._action_dim * self._num_critic_replicas)

        d = joint_batch.shape[-1]
        # [B, action_dim*n, d] -> [B, action_dim, n, d]
        reshaped_batch = joint_batch.view(batch_size, self._action_dim,
                                          self._num_critic_replicas, d)

        # [B, action_dim, n, d] -> [action_dim, B, n, d] -> [action_dim*B, n, d]
        reshaped_batch = reshaped_batch.transpose(0, 1).reshape(
            -1, self._num_critic_replicas, d)
        return reshaped_batch

    def _transform_action_value(self, action_values, alpha):
        """Transform raw action values to valid alpha * log_pi.

        Computes ``action_values - alpha * logsumexp(action_values / alpha)``
        over the last dim.

        Args:
            action_values (torch.Tensor): raw action values computed from a
                network, with the last dim as the distribution dimension.
            alpha: the temperature used for the transformation.

        Returns:
            transformed_value (torch.Tensor): a tensor with value equal to
                alpha * log_pi computed from the input action_values.
        """
        v_value = alpha * torch.logsumexp(
            action_values / alpha, dim=-1, keepdim=True)
        transformed_value = action_values - v_value
        return transformed_value

    def _sample_action_from_value(self, logits, alpha, greedy=False):
        """Sample a discrete action from the given logits.

        Args:
            logits (torch.Tensor): log pi of the discrete distribution with
                the last dim as the distribution dimension.
            alpha: the temperature used for the transformation (unused here).
            greedy (bool): if True, do greedy sampling by taking the mode of
                the distribution; otherwise, sample according to the
                probability of the distribution.

        Returns:
            sampled_ind (torch.Tensor): the indices of the sampled action.
            sampled_log_pi (torch.Tensor): the log prob of the sampled action.
        """
        if greedy:
            sampled_log_pi, sampled_ind = torch.max(logits, dim=-1)
        else:
            batch_size = logits.shape[0]
            # logits [B, n, action_bin] -> [B*n, action_bin]
            batched_logits = logits.reshape(-1, self._action_bins)
            dist = torch.distributions.categorical.Categorical(
                logits=batched_logits)

            # [1, B*n] -> [B, n]
            sampled_ind = dist.sample((1, ))
            sampled_log_pi = dist.log_prob(sampled_ind)

            sampled_ind = sampled_ind.view(batch_size, -1)
            sampled_log_pi = sampled_log_pi.view(batch_size, -1)

        return sampled_ind, sampled_log_pi

    def _batched_index_select(self, t, dim, inds):
        """Gather ``t`` along ``dim`` at ``inds``, keeping a trailing size-1
        dim (e.g. [B, n, action_bin] with [B, n] indices -> [B, n, 1]).

        ``inds`` gets a new last dim, so this matches ``gather`` semantics
        only when ``dim`` is the last dim of ``t``.
        """
        expanded_ind = inds.unsqueeze(-1)
        out = t.gather(dim, expanded_ind)
        return out

    @staticmethod
    def _parallel_to_individual_network_sync(p_net, np_net, step):
        """Sync parameters from the parallel version to the individual
        version.

        Parameters are paired by iteration order of ``parameters()``. The
        individual net ``i`` receives the slice ``[i * step:(i + 1) * step]``
        along the first (replica) dim of each parallel parameter.

        Args:
            p_net (ParallelNetwork): the parallel version of network.
            np_net (list[Network|ParallelNetwork]): a list of the individual
                networks. Note that each individual network can also be an
                instance of ParallelNetwork.
            step (int): the number of replicas contained in each individual
                network. For example:

                - if the individual net is a plain network, step=1
                - if the individual net is a parallel network, step = replica
                  of the individual net
        """
        split_num = len(np_net)
        for i in range(split_num):
            for ws, wt in zip(p_net.parameters(), np_net[i].parameters()):
                wt.data.copy_(ws[i * step:(i + 1) * step])

    def get_uniform_prior_logpi(self):
        """Return ``-log(action_bins)``, the log-probability of a single bin
        under the uniform prior used for the KL term."""
        return self._log_pi_uniform_prior

    def sync_net(self):
        """Copy the current parameters of the parallel pre/post encoding nets
        into their individual counterparts used by ``get_action``."""
        MdqCriticNetwork._parallel_to_individual_network_sync(
            self._pre_encoding_parallel_net,
            self._pre_encoding_nets,
            step=self._num_critic_replicas)

        MdqCriticNetwork._parallel_to_individual_network_sync(
            self._post_encoding_parallel_net,
            self._post_encoding_nets,
            step=self._num_critic_replicas)