# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some nest utils functions."""
import abc
import torch
import torch.nn as nn
from functools import reduce
import numpy as np
from typing import Callable
import alf
from . import nest
from .nest import get_field, map_structure
from alf.tensor_specs import TensorSpec
class NestCombiner(abc.ABC, nn.Module):
    """Abstract base for modules that merge every leaf of a nest into one
    result.

    Subclasses implement ``_combine_flat``; calling the combiner flattens the
    nest and hands the resulting list of leaves to that method.
    """

    def __init__(self, name: str, batch_dims: int = 1):
        """
        Args:
            name: name of the combiner.
            batch_dims: how many leading dims are treated as batch dims
                (default 1). Only combiners whose result depends on where the
                batch dims end (e.g., outer product) need this argument.
        """
        super().__init__()
        self._name = name
        self._batch_dims = batch_dims

    @abc.abstractmethod
    def _combine_flat(self, tensors):
        """Define how a flat list of tensors taken from the nest is combined.

        Args:
            tensors (list[Tensor]): the flattened leaves of the nest.
        Returns:
            Tensor: the combined result.
        """
        pass

    def __call__(self, nested):
        """Combine all leaves of ``nested`` with ``_combine_flat``.

        Args:
            nested (nest): a nested structure whose leaves are ``Tensor`` or
                ``TensorSpec``. Only the first leaf is inspected to decide
                which case applies.
        Returns:
            Tensor or TensorSpec: the combined tensor when the leaves are
            tensors. Otherwise, the spec of the combined result, obtained by
            combining zero tensors built from the specs (each with
            ``batch_dims`` outer dims of size 1).
        """
        leaves = nest.flatten(nested)
        assert len(leaves) > 0, "The nest is empty!"
        if not isinstance(leaves[0], TensorSpec):
            return self._combine_flat(leaves)
        # Spec input: run the combiner on dummy zero tensors, then strip the
        # dummy batch dims to recover the output spec.
        dummy_outer_dims = (1, ) * self._batch_dims
        dummies = nest.map_structure(
            lambda spec: spec.zeros(outer_dims=dummy_outer_dims), leaves)
        combined = self._combine_flat(dummies)
        return TensorSpec.from_tensor(combined, from_dim=self._batch_dims)
@alf.configurable
@alf.repr_wrapper
class NestConcat(NestCombiner):
    def __init__(self, nest_mask=None, dim=-1, name="NestConcat"):
        """Select tensors from a nest and concatenate them along one axis.

        If ``nest_mask`` is None, every tensor in the nest is selected. The
        selected tensors are assumed to share the same tensor spec. This can
        be used as a preprocessing combiner of a network.

        The batch dimension is not counted for ``dim``: ``dim=0`` means the
        first dimension after the batch dimension.

        Args:
            nest_mask (nest|None): nest-structured mask whose True/False (1/0)
                values choose which tensors to concatenate. It must have the
                same structure as the data nest this combiner is applied to.
                If None, all tensors are selected.
            dim (int): the dim along which the tensors are concatenated.
            name (str): name of the combiner.
        """
        super(NestConcat, self).__init__(name)
        self._nest_mask = nest_mask
        if nest_mask:
            self._flat_mask = nest.flatten(nest_mask)
        else:
            self._flat_mask = nest_mask
        # Non-negative dims are relative to the first non-batch dim, so shift
        # them past the (single) batch dim; negative dims count from the end.
        self._dim = dim + 1 if dim >= 0 else dim

    def _combine_flat(self, tensors):
        if self._flat_mask is None:
            return torch.cat(tensors, dim=self._dim)
        assert len(self._flat_mask) == len(tensors), (
            "incompatible structures "
            "between mask and data nest")
        chosen = [
            tensors[idx] for idx, keep in enumerate(self._flat_mask) if keep
        ]
        return torch.cat(chosen, dim=self._dim)

    def make_parallel(self, n):
        """Create a ``NestConcat`` layer to handle parallel batch.

        A parallel batch is assumed to have shape [B, n, ...]; neither the
        batch dim nor the replica dim is counted for concat.

        Args:
            n (int): the number of replicas (unused here).
        Returns:
            a ``NestConcat`` layer to handle parallel batch.
        """
        # ``self._dim`` is already shifted past the batch dim; passing it as
        # ``dim`` shifts it once more (if non-negative), past the replica dim.
        return NestConcat(self._nest_mask, self._dim, "parallel_" + self._name)
@alf.configurable
class NestSum(NestCombiner):
    def __init__(self, average=False, activation=None, name="NestSum"):
        """Add all tensors in a nest together.

        All tensors are assumed to have the same shape. This can be used as a
        preprocessing combiner of a network.

        Args:
            average (bool): if True, the tensors are averaged instead of
                summed.
            activation (Callable): optional activation function applied to
                the result; identity if None.
            name (str): name of the combiner.
        """
        super(NestSum, self).__init__(name)
        self._average = average
        if activation is None:
            activation = lambda x: x
        self._activation = activation

    def _combine_flat(self, tensors):
        total = sum(tensors)
        if self._average:
            # Out-of-place multiply: the former in-place ``*=`` with a float
            # factor raised for integer tensors (a float result cannot be
            # written back into an integer tensor). Multiplying by the same
            # reciprocal keeps float results bit-identical to before.
            total = total * (1 / float(len(tensors)))
        return self._activation(total)

    def make_parallel(self, n):
        """Create a ``NestSum`` with the same settings for parallel batches.

        Args:
            n (int): the number of replicas (unused here).
        """
        return NestSum(self._average, self._activation,
                       "parallel_" + self._name)
@alf.configurable
class NestMultiply(NestCombiner):
    def __init__(self, activation=None, name="NestMultiply"):
        """Multiply all tensors in a nest element-wise.

        All tensors are assumed to have the same shape. This can be used as a
        preprocessing combiner of a network.

        Args:
            activation (Callable): optional activation function applied to the
                product; identity if None.
            name (str): name of the combiner.
        """
        super(NestMultiply, self).__init__(name)
        self._activation = (lambda x: x) if activation is None else activation

    def _combine_flat(self, tensors):
        product = alf.utils.math_ops.mul_n(tensors)
        return self._activation(product)

    def make_parallel(self, n):
        """Create a ``NestMultiply`` with the same activation for parallel
        batches.

        Args:
            n (int): the number of replicas (unused here).
        """
        return NestMultiply(self._activation, "parallel_" + self._name)
@alf.configurable
@alf.repr_wrapper
class NestOuterProduct(NestCombiner):
    def __init__(self,
                 activation: Callable = None,
                 batch_dims: int = 1,
                 padding: bool = False,
                 name: str = "NestOuterProduct"):
        """Perform outer-product operations across a nested structure.

        This can be used as a preprocessing combiner of a network.

        Sometimes combining tensors using outer product might be more
        expressive than concatenating, e.g., when one tensor is one-hot. See
        the discussions in
        ::

            "STOCHASTIC NEURAL NETWORKS FOR HIERARCHICAL REINFORCEMENT LEARNING",
            Florensa, et al., ICLR 2017, https://arxiv.org/pdf/1704.03012.pdf.

        This implementation also supports padding each tensor with a 1 before
        the outer product, which effectively combines outer product and
        concatenation in one combiner.

        .. warning::

            Due to the outer product, this combiner might produce a very long
            output vector. Estimate the output size before using it.

        Args:
            activation: optional activation function applied after the outer
                product; ``alf.layers.identity`` if None.
            batch_dims: number of batch dims. Defaults to 1. If the total input
                dim is ``N``, then the last ``N-batch_dims`` dims are flattened
                for the outer product.
            padding: if True, each tensor is padded by 1 before performing the
                outer product. This effectively also places the concatenation
                of all tensors in the output tensor.
            name: name of the combiner.
        """
        super(NestOuterProduct, self).__init__(name, batch_dims=batch_dims)
        if activation is None:
            activation = alf.layers.identity
        self._activation = activation
        self._padding = padding

    def _combine_flat(self, tensors):
        batch_shape = tensors[0].shape[:self._batch_dims]
        for t in tensors:
            assert batch_shape == t.shape[:self._batch_dims], (
                "Different batch shapes %s vs. %s" %
                (batch_shape, t.shape[:self._batch_dims]))
        B = int(np.prod(batch_shape))
        # Flatten everything after the batch dims into one feature dim.
        tensors = [t.reshape(B, -1) for t in tensors]
        if self._padding:
            # ``new_ones`` matches both dtype AND device of ``t``; the previous
            # ``torch.ones(..., dtype=t.dtype)`` used the default device and
            # made ``torch.cat`` fail when inputs lived on another device.
            tensors = [torch.cat([t, t.new_ones((B, 1))], dim=1) for t in tensors]
        # Pairwise batched outer product, flattened after each step.
        out = reduce(
            lambda x, y: torch.einsum('bn,bm->bnm', x, y).reshape(B, -1),
            tensors)
        out = out.reshape(*batch_shape, -1)
        return self._activation(out)

    def make_parallel(self, n):
        """Create a ``NestOuterProduct`` for parallel batches; the replica dim
        is treated as one extra batch dim.

        Args:
            n (int): the number of replicas (unused here).
        """
        return NestOuterProduct(self._activation, self._batch_dims + 1,
                                self._padding, "parallel_" + self._name)
def stack_nests(nests, dim=0):
    """Stack corresponding tensors of several nests along a new dimension.

    All nests must have the same structure and shapes. Each tensor in the
    result is ``torch.stack`` of the corresponding tensors, so with the default
    ``dim=0`` it has shape :math:`[T, ...]`, where ``T = len(nests)``.

    Args:
        nests (list[nest]): list of nests with same structure and shape.
        dim (int): dimension to insert. Has to be between 0 and the number of
            dimensions of the stacked tensors (inclusive).
    Returns:
        a nest with same structure as ``nests[0]``.
    """
    if len(nests) != 1:
        return nest.map_structure(lambda *parts: torch.stack(parts, dim),
                                  *nests)
    # Only one nest: inserting the new dim directly is all that is needed.
    return nest.map_structure(lambda part: part.unsqueeze(dim), nests[0])
def get_outer_rank(tensors, specs):
    """Compare tensors to specs to determine the number of batch dimensions.

    For each tensor, it checks the dimensions with respect to the specs and
    returns the number of batch dimensions if all nested tensors and specs
    agree with each other.

    Args:
        tensors (nested Tensors): Nested list/tuple/dict of Tensors.
        specs (nested TensorSpecs): Nested list/tuple/dict of TensorSpecs,
            describing the shape of unbatched tensors.
    Returns:
        int: The number of outer dimensions for all tensors (zero if all are
        unbatched or the nest is empty).
    Raises:
        AssertionError: If the shape of Tensors are not compatible with specs,
            a mix of batched and unbatched tensors are provided, or the tensors
            are batched but have an incorrect number of outer dims.
    """
    outer_ranks = []

    def _get_outer_rank(tensor, spec):
        outer_rank = len(tensor.shape) - len(spec.shape)
        assert outer_rank >= 0, (
            "Tensor shape %s has fewer dims than spec shape %s" %
            (tensor.shape, spec.shape))
        assert tensor.shape[outer_rank:] == spec.shape, (
            "Tensor shape %s is incompatible with spec shape %s" %
            (tensor.shape, spec.shape))
        outer_ranks.append(outer_rank)

    nest.map_structure(_get_outer_rank, tensors, specs)
    if not outer_ranks:
        # Empty nest: documented to have zero outer dims (previously this
        # raised IndexError on ``outer_ranks[0]``).
        return 0
    outer_rank = outer_ranks[0]
    assert all(r == outer_rank for r in outer_ranks), (
        "Tensors have different "
        "outer_ranks %s" % outer_ranks)
    return outer_rank
def convert_device(nests, device=None):
    """Move every tensor in ``nests`` to the given device, or to the default
    device if none is given.

    Tensors already on the target device type are returned unchanged.

    Args:
        nests (nested Tensors): Nested list/tuple/dict of Tensors.
        device (None|str): the target device, either `cuda` or `cpu`. If None,
            ``alf.get_default_device()`` is used as the target device.
    Returns:
        nests (nested Tensors): Nested list/tuple/dict of Tensors after device
        conversion.
    Raises:
        NotImplementedError: if the target device is not `cpu` or `cuda`.
        AssertionError: if the target device is `cuda` but cuda is
            unavailable.
    """
    target = alf.get_default_device() if device is None else device

    if target == 'cpu':
        return nest.map_structure(
            lambda t: t if t.device.type == 'cpu' else t.cpu(), nests)

    if target == 'cuda':
        assert torch.cuda.is_available(), "cuda is unavailable"
        return nest.map_structure(
            lambda t: t if t.device.type == 'cuda' else t.cuda(), nests)

    raise NotImplementedError("Unknown device %s" % target)
def grad(nested, objective, retain_graph=False):
    """Compute the gradients of ``objective`` w.r.t. every variable in
    ``nested``.

    The nest is flattened, ``torch.autograd.grad`` is called on the flat list,
    and the gradients are packed back into the structure of ``nested``.

    Args:
        nested (nest): a nest of variables that require grads.
        objective (Tensor): a tensor whose gradients will be computed.
        retain_graph (bool): if True, the computational graph is not freed
            after autograd.
    Returns:
        nest: gradients with the same structure as ``nested``.
    """
    flat_vars = nest.flatten(nested)
    flat_grads = torch.autograd.grad(
        objective, flat_vars, retain_graph=retain_graph)
    return nest.pack_sequence_as(nested, list(flat_grads))
def zeros_like(nested):
    """Return a nest shaped like ``nested`` with every tensor replaced by
    ``torch.zeros_like`` of it.

    Args:
        nested (nested Tensor): the reference nested structure.
    Returns:
        nested Tensor: a nest of all-zero tensors.
    """
    return nest.map_structure(lambda ref: torch.zeros_like(ref), nested)
def make_nested_module(nested, ignore_non_module_element=True):
    """Convert a nest of modules to nn.Module using nn.ModuleList or
    nn.ModuleDict.

    A plain nest of Modules is not registered as submodules, so it will not be
    trained or checkpointed. nn.ModuleList / nn.ModuleDict are used to hold the
    individual modules in the nest.

    Args:
        nested (nested nn.Module): a nest of nn.Module.
        ignore_non_module_element (bool): If True, non-module leaves are
            replaced with None. If False, an AssertionError is raised if there
            are any non-module leaves, at any depth of the nest.
    Returns:
        nn.Module
    """
    if isinstance(nested, (tuple, list)):
        # NOTE(review): if ``nest.is_namedtuple`` only matches tuple
        # instances, namedtuples are caught here (becoming a ModuleList) and
        # the namedtuple check below never fires. Confirm the intended
        # container before reordering, since it changes state_dict keys.
        module = torch.nn.ModuleList()
        for m in nested:
            # Propagate the flag; previously it was dropped, so nested
            # non-modules were silently replaced by None even when
            # ``ignore_non_module_element`` was False.
            module.append(make_nested_module(m, ignore_non_module_element))
    elif nest.is_namedtuple(nested) or isinstance(nested, dict):
        module = torch.nn.ModuleDict()
        for field, value in nest.extract_fields_from_nest(nested):
            module[field] = make_nested_module(value,
                                               ignore_non_module_element)
    else:
        module = nested
        if not ignore_non_module_element:
            assert isinstance(
                nested,
                torch.nn.Module), ("Unsupported type %s" % type(nested))
        elif not isinstance(nested, torch.nn.Module):
            module = None
    return module
def get_nested_field(nested, nest_fields):
    """Look up several (possibly dotted) field paths in a nest at once.

    Example:

    .. code-block:: python

        x = get_nested_field(nest, ('a.b', 'c'))
        y = (get_field(nest, 'a.b'), get_field(nest, 'c'))
        # x and y are the same

    Args:
        nested (nest): a nested structure.
        nest_fields (nested str): nested strings. Each string is a path used to
            retrieve a value from ``nested``.
    Returns:
        a nest with same structure as ``nest_fields``.
    """
    def _lookup(path):
        return get_field(nested, path)

    return map_structure(_lookup, nest_fields)