# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some basic layers."""
from absl import logging
import copy
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Callable, Dict, Iterable, Optional, Tuple, Union
import alf
from alf.initializers import variance_scaling_init
from alf.nest.utils import (get_nested_field, get_outer_rank, NestConcat,
NestMultiply, NestOuterProduct, NestSum)
from alf.nest import map_structure, get_field
from alf.tensor_specs import TensorSpec
from alf.utils import common
from alf.utils.math_ops import identity
from alf.utils.summary_utils import summarize_tensor_gradients
from alf.utils.tensor_utils import BatchSquash, tensor_extend_new_dim
from .norm_layers import BatchNorm1d, BatchNorm2d, prepare_rnn_batch_norm
from .norm_layers import ParamLayerNorm1d, ParamLayerNorm2d
def normalize_along_batch_dims(x, mean, variance, variance_epsilon):
    """Normalize ``x`` element-wise with ``mean`` and ``variance``.

    ``mean`` and ``variance`` must have the same tensor spec, and that spec
    describes the inner (trailing) dims of ``x``. Any leading dims of ``x``
    are treated as batch dims: they are flattened into one dim for the
    computation and restored afterwards.

    Args:
        x (Tensor): a tensor of (``[D1, D2, ..] + shape``), where ``D1``,
            ``D2``, .. are arbitrary leading batch dims (can be empty).
        mean (Tensor): a tensor of ``shape``
        variance (Tensor): a tensor of ``shape``
        variance_epsilon (float): A small float number to avoid dividing by 0.
    Returns:
        Tensor: ``(x - mean) * rsqrt(variance + variance_epsilon)``, computed
        in the dtype of ``x``, with the batch dims of ``x`` restored.
    """
    inner_spec = TensorSpec.from_tensor(mean)
    assert inner_spec == TensorSpec.from_tensor(variance), \
        "The specs of mean and variance must be equal!"
    squasher = BatchSquash(get_outer_rank(x, inner_spec))
    flat_x = squasher.flatten(x)
    # mean/variance may use a different dtype than x; cast to x's dtype
    scale = torch.rsqrt(variance + variance_epsilon).to(flat_x.dtype)
    centered = flat_x - mean.to(flat_x.dtype)
    return squasher.unflatten(centered * scale)
class ElementwiseLayerBase(nn.Module):
    """Base class for parameter-free layers applying an elementwise operation.

    Such a layer owns no parameters, so one instance can process a parallel
    batch as-is.
    """

    def make_parallel(self, n: int):
        """Return a layer doing the same operation on a parallel batch.

        A parallel batch is assumed to have shape ``[B, n, ...]``. Because the
        operation is elementwise and parameterless, ``self`` is reused.

        Args:
            n (int): the number of replicas.
        Returns:
            ``self``, which already handles the parallel batch.
        """
        first_param = next(self.parameters(), None)
        assert first_param is None
        return self
class Identity(ElementwiseLayerBase):
    """A layer whose output is its input, untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself (the same object, not a copy)."""
        output = x
        return output
class Cast(ElementwiseLayerBase):
    """A layer converting the elements of its input tensor to a fixed dtype."""

    def __init__(self, dtype=torch.float32):
        """
        Args:
            dtype (torch.dtype): dtype that inputs are converted to.
        """
        super().__init__()
        self._dtype = dtype

    def forward(self, x):
        """Return ``x.to(dtype)`` for the dtype given at construction."""
        target_dtype = self._dtype
        return x.to(target_dtype)
class Transpose(nn.Module):
    """A layer that swaps two non-batch dimensions of its input.

    Dimension indices exclude the batch dimension, i.e. ``dim0=0`` names the
    first dimension after the batch dimension. Negative indices count from
    the end and are therefore kept as they are.
    """

    def __init__(self, dim0=0, dim1=1):
        """
        Args:
            dim0 (int): the first dimension to be transposed.
            dim1 (int): the second dimension to be transposed
        """
        super().__init__()
        # Shift non-negative indices past the leading batch dimension.
        self._dim0 = dim0 + 1 if dim0 >= 0 else dim0
        self._dim1 = dim1 + 1 if dim1 >= 0 else dim1

    def forward(self, x):
        first, second = self._dim0, self._dim1
        return x.transpose(first, second)

    def make_parallel(self, n: int):
        """Create a Transpose layer to handle parallel batch.

        It is assumed that a parallel batch has shape [B, n, ...] and both the
        batch dimension and replica dimension are not considered for transpose.
        The stored indices already skip the batch dimension; the new layer's
        constructor shifts non-negative ones once more, past the replica
        dimension.

        Args:
            n (int): the number of replicas.
        Returns:
            a ``Transpose`` layer to handle parallel batch.
        """
        return Transpose(self._dim0, self._dim1)
class Permute(nn.Module):
    """A layer that reorders the non-batch dimensions of its input.

    The batch dimension always stays in front.
    """

    def __init__(self, *dims):
        """
        Args:
            *dims: The desired ordering of dimensions (not including batch
                dimension). Every entry must be non-negative.
        """
        super().__init__()
        assert all(d >= 0 for d in dims), (
            "dims should be non-negative. Got %s" % str(dims))
        # Keep the batch dim first and shift every other index past it.
        self._dims = [0, *(d + 1 for d in dims)]

    def forward(self, x):
        order = self._dims
        return x.permute(*order)

    def make_parallel(self, n: int):
        """Create a Permute layer to handle parallel batch.

        It is assumed that a parallel batch has shape [B, n, ...] and both the
        batch dimension and replica dimension are not considered for permute.
        Passing the stored (already batch-shifted) order to a new ``Permute``
        shifts it once more, keeping the replica dimension in place as well.

        Args:
            n (int): the number of replicas.
        Returns:
            a ``Permute`` layer to handle parallel batch.
        """
        return Permute(*self._dims)
@alf.configurable
class OneHot(nn.Module):
    """A layer that one-hot encodes class indices and returns float32."""

    def __init__(self, num_classes):
        """
        Args:
            num_classes (int): number of classes passed to ``one_hot``.
        """
        super().__init__()
        self._num_classes = num_classes

    def forward(self, input):
        """
        Args:
            input (Tensor): class indices (presumably an integer tensor, as
                required by ``one_hot`` -- confirm at call sites).
        Returns:
            Tensor: the one-hot encoding of ``input`` cast to float32.
        """
        encoded = nn.functional.one_hot(input, num_classes=self._num_classes)
        return encoded.to(torch.float32)

    def make_parallel(self, n: int):
        """Return a fresh ``OneHot`` with the same number of classes.

        The layer has no parameters, so the replica count ``n`` is not needed.
        """
        return OneHot(self._num_classes)
@alf.configurable
class FixedDecodingLayer(nn.Module):
    """A layer that uses a set of fixed basis for decoding the inputs.

    The basis matrix is built once in the constructor, copied into a bias-free
    ``nn.Linear`` and excluded from training (``requires_grad = False``).
    """

    def __init__(self,
                 input_size,
                 output_size,
                 basis_type="rbf",
                 sigma=1.,
                 tau=0.5):
        """
        Args:
            input_size (int): the size of input to be decoded, representing the
                number of representation coefficients
            output_size (int): the size of the decoded output
            basis_type (str): the type of basis to be used for decoding

                - "poly": polynomial basis using Vandermonde matrix
                - "cheb": polynomial basis using Chebyshev polynomials
                - "rbf": radial basis functions (requires
                  ``input_size == output_size``)
                - "haar": Haar wavelet basis (requires
                  ``input_size == output_size`` and a power of 2)
            sigma (float): the bandwidth parameter used for RBF basis.
                If None, a default value of 1. will be used.
            tau (float): a factor for weighting the basis exponentially
                according to the order (``n``) of the basis, i.e., ``tau**n``.
                It is not used by the "rbf" basis.
        """
        # get the argument list with vals
        self._kwargs = copy.deepcopy(locals())
        self._kwargs.pop('self')
        self._kwargs.pop('__class__')
        super(FixedDecodingLayer, self).__init__()
        assert input_size > 0, "input_size should be at least one"
        assert basis_type in {"poly", "cheb", "rbf", "haar"
                              }, ("the specified method "
                                  "{} is not supported".format(basis_type))
        if sigma is None:
            # Honor the documented default; dividing by None would raise.
            sigma = 1.
        self._B = nn.Linear(input_size, output_size, bias=False)

        def _polyvander_matrix(n, D, tau=tau):
            # non-square matrix [n, D + 1]
            x = torch.linspace(-1, 1, n)
            B = torch.as_tensor(
                np.polynomial.polynomial.polyvander(x.cpu(), D))
            # weight for encoding the preference to low-frequency basis
            exp_factor = torch.arange(D + 1).float()
            basis_weight = tau**exp_factor
            return B * basis_weight

        def _chebvander_matrix(n, D, tau=tau):
            # non-square matrix [n, D + 1]
            x = np.linspace(-1, 1, n)
            B = torch.as_tensor(np.polynomial.chebyshev.chebvander(x, D))
            # weight for encoding the preference to low-frequency basis
            exp_factor = torch.arange(D + 1).float()
            basis_weight = tau**exp_factor
            return B * basis_weight

        def _rbf_matrix(n, sigma=1.0):
            # square matrix [n, n] with B[i, d] = exp(-(x[i] - x[d])**2 / sigma)
            x = torch.linspace(-1, 1, n)
            return torch.exp(-(x.unsqueeze(1) - x.unsqueeze(0))**2 / sigma)

        def _haar_matrix(n, tau=tau):
            # square matrix [n, n]
            def _is_power_of_two(x):
                return (x & (x - 1)) == 0

            # allow only size n to be the power of 2
            assert _is_power_of_two(n), "n is required to be the power of 2"

            def _kron(A, B):
                # Kronecker product of two 2-D tensors
                return torch.einsum("ab,cd->acbd", A, B).view(
                    A.size(0) * B.size(0),
                    A.size(1) * B.size(1))

            def _get_haar_matrix(n):
                if n == 1:
                    # Trivial 1x1 basis; the 2x2 base case below would not
                    # fit the [1, 1] weight of the linear layer.
                    return torch.ones(1, 1)
                if n == 2:
                    return torch.Tensor([[1, 1], [1, -1]])
                h = _get_haar_matrix(n // 2)
                # calculate upper haar part
                h_n = _kron(h, torch.Tensor([[1], [1]]))
                # calculate lower haar part
                h_i = torch.sqrt(torch.Tensor([n / 2])) * _kron(
                    torch.eye(len(h)), torch.Tensor([[1], [-1]]))
                # combine both parts
                return torch.cat((h_n, h_i), dim=1)

            B = _get_haar_matrix(n) / torch.sqrt(torch.Tensor([n]))
            # weight for encoding the preference to low-frequency basis
            exp_factor = torch.ceil(torch.log2(torch.arange(n).float() + 1))
            basis_weight = tau**exp_factor
            return B * basis_weight

        if basis_type == "poly":
            B = _polyvander_matrix(output_size, input_size - 1)
        elif basis_type == "cheb":
            B = _chebvander_matrix(output_size, input_size - 1)
        elif basis_type == "rbf":
            assert input_size == output_size
            B = _rbf_matrix(input_size, sigma=sigma)
        elif basis_type == "haar":
            assert input_size == output_size
            B = _haar_matrix(input_size)
        # assign the constructed transformation matrix and set it to be non-trainable
        self._B.weight.requires_grad = False
        self._B.weight.copy_(B)

    def forward(self, inputs):
        """Decode ``inputs`` with the fixed basis (a bias-free linear map)."""
        return self._B(inputs)

    @property
    def weight(self):
        """The frozen basis matrix of shape ``[output_size, input_size]``."""
        return self._B.weight
@alf.configurable
@alf.repr_wrapper
class FC(nn.Module):
    """Fully connected layer with optional normalization and activation."""

    def __init__(self,
                 input_size,
                 output_size,
                 activation=identity,
                 use_bias=True,
                 use_bn=False,
                 use_ln=False,
                 bn_ctor=nn.BatchNorm1d,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_value=0.0,
                 bias_initializer=None,
                 weight_opt_args: Optional[Dict] = None,
                 bias_opt_args: Optional[Dict] = None):
        """A fully connected layer that also applies the activation and takes
        care of customized weight initialization. The automatic gain
        calculation may depend on the activation that follows the linear
        transform, so prefer this wrapper over ``nn.Linear`` when the weight
        std right after init matters.

        Args:
            input_size (int): input size
            output_size (int): output size
            activation (Callable): activation applied to the output; it is
                also passed to the default kernel initializer.
            use_bias (bool): whether use bias. Forced to False when
                ``use_bn`` is True.
            use_bn (bool): whether use batch normalization.
            use_ln (bool): whether use layer normalization
            bn_ctor (Callable): will be called as ``bn_ctor(num_features)`` to
                create the BN layer.
            kernel_initializer (Callable): initializer for the FC layer kernel.
                If none is provided a ``variance_scaling_initializer`` with gain
                as ``kernel_init_gain`` will be used.
            kernel_init_gain (float): a scaling factor (gain) applied to
                the std of kernel init distribution. It will be ignored if
                ``kernel_initializer`` is not None.
            bias_init_value (float): a constant for the initial bias value.
                This is ignored if ``bias_initializer`` is provided.
            bias_initializer (Callable): initializer for the bias parameter.
            weight_opt_args: optimizer arguments for weight
            bias_opt_args: optimizer arguments for bias
        """
        # Snapshot the constructor arguments before any other local exists;
        # ``make_parallel`` rebuilds an equivalent ``ParallelFC`` from them.
        self._kwargs = copy.deepcopy(locals())
        for key in ('self', '__class__'):
            self._kwargs.pop(key)
        super(FC, self).__init__()

        self._input_size = input_size
        self._output_size = output_size
        self._activation = activation
        self._weight = nn.Parameter(torch.Tensor(output_size, input_size))
        # A bias in front of batch normalization is useless, so drop it.
        self._use_bias = use_bias and not use_bn
        self._bias = (nn.Parameter(torch.Tensor(output_size))
                      if self._use_bias else None)

        self._kernel_initializer = kernel_initializer
        self._kernel_init_gain = kernel_init_gain
        self._bias_init_value = bias_init_value
        self._bias_initializer = bias_initializer
        self._use_bn = use_bn
        self._use_ln = use_ln
        self._bn = bn_ctor(output_size) if use_bn else None
        self._ln = nn.LayerNorm(output_size) if use_ln else None

        self.reset_parameters()
        if weight_opt_args:
            self._weight.opt_args = weight_opt_args
        if bias_opt_args and self._bias is not None:
            self._bias.opt_args = bias_opt_args

    @property
    def input_size(self):
        return self._input_size

    @property
    def output_size(self):
        return self._output_size

    def reset_parameters(self):
        """(Re)initialize the kernel, the bias and the normalization layers."""
        if self._kernel_initializer is not None:
            self._kernel_initializer(self._weight.data)
        else:
            variance_scaling_init(
                self._weight.data,
                gain=self._kernel_init_gain,
                nonlinearity=self._activation)
        if self._use_bias:
            if self._bias_initializer is None:
                nn.init.constant_(self._bias.data, self._bias_init_value)
            else:
                self._bias_initializer(self._bias.data)
        if self._use_ln:
            self._ln.reset_parameters()
        if self._use_bn:
            self._bn.reset_parameters()

    def forward(self, inputs):
        """Forward computation.

        Args:
            inputs (Tensor): its shape should be ``[batch_size, input_size]`` or
                ``[batch_size, ..., input_size]``
        Returns:
            Tensor: with shape as ``inputs.shape[:-1] + (output_size,)``
        """
        if self._use_bias and inputs.dim() == 2:
            # fused ``bias + inputs @ weight^T`` for the plain 2-D case
            y = torch.addmm(self._bias, inputs, self._weight.t())
        else:
            y = inputs.matmul(self._weight.t())
            if self._use_bias:
                y += self._bias
        if self._use_ln:
            if not self._use_bias:
                # presumably keeps the whole layer bias-free when the FC
                # bias is disabled
                self._ln.bias.data.zero_()
            y = self._ln(y)
        if self._use_bn:
            y = self._bn(y)
        return self._activation(y)

    @property
    def weight(self):
        return self._weight

    @property
    def bias(self):
        return self._bias

    def make_parallel(self, n: int):
        """Create a ``ParallelFC`` using ``n`` replicas of ``self``.

        The replicas are freshly initialized, so their parameters differ from
        this layer's.
        """
        return ParallelFC(n=n, **self._kwargs)
@alf.configurable
class FCBatchEnsemble(FC):
    r"""The BatchEnsemble for FC layer.

    BatchEnsemble is proposed in `Wen et al. BatchEnsemble: An Alternative Approach
    to Efficient Ensemble and Lifelong Learning <https://arxiv.org/abs/2002.06715>`_

    In a nutshell, a tuple of vector :math:`(r_k, s_k)` is maintained for ensemble
    member k in addition to the original FC weight matrix w. For input x, the
    result for ensemble member k is calculated as :math:`(W \circ (s_k r_k^T)) x`.
    This can be more efficiently calculated as :math:`(W (x \circ r_k)) \circ s_k`.

    Note that for each sample in a batch, a random ensemble member will be used
    for it if ``ensemble_ids`` is not provided to ``forward()``.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 ensemble_size,
                 output_ensemble_ids=True,
                 activation=identity,
                 use_bias=True,
                 use_bn=False,
                 use_ln=False,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_range=0.,
                 ensemble_group=0):
        """
        Args:
            input_size (int): input size
            output_size (int): output size
            ensemble_size (int): ensemble size
            output_ensemble_ids (bool): If True, the forward() function will return
                a tuple of (result, ensemble_ids). If False, the forward() function
                will return result only.
            activation (Callable): activation function
            use_bias (bool): whether use (per-member) bias
            use_bn (bool): whether use batch normalization.
            use_ln (bool): whether use layer normalization
            kernel_initializer (Callable): initializer for the FC layer kernel.
                If none is provided a ``variance_scaling_initializer`` with gain as
                ``kernel_init_gain`` will be used.
            kernel_init_gain (float): a scaling factor (gain) applied to
                the std of kernel init distribution. It will be ignored if
                ``kernel_initializer`` is not None.
            bias_init_range (float): biases are initialized uniformly in
                [-bias_init_range, bias_init_range]
            ensemble_group (int): the extra attribute ``ensemble_group`` added
                to ``self._r``, ``self._s``, and ``self._ensemble_bias``,
                default value is 0.
                For alf.optimizers whose ``parvi`` is not ``None``, all parameters
                with the same ``ensemble_group`` will be updated by the
                particle-based VI algorithm specified by ``parvi``, options are
                [``svgd``, ``gfsf``],

                * Stein Variational Gradient Descent (SVGD)

                  Liu, Qiang, and Dilin Wang. "Stein Variational Gradient Descent:
                  A General Purpose Bayesian Inference Algorithm." NIPS. 2016.

                * Wasserstein Gradient Flow with Smoothed Functions (GFSF)

                  Liu, Chang, et al. "Understanding and accelerating particle-based
                  variational inference." ICML, 2019.
        """
        # The ensemble parameters must exist before FC.__init__ runs, because
        # it calls the overridden reset_parameters() which initializes them.
        # NOTE(review): FC.__init__ calls nn.Module.__init__ a second time;
        # confirm with the installed torch version that this does not drop
        # the parameters registered here.
        nn.Module.__init__(self)
        self._r = nn.Parameter(torch.Tensor(ensemble_size, input_size))
        self._s = nn.Parameter(torch.Tensor(ensemble_size, output_size))
        self._ensemble_bias = nn.Parameter(
            torch.Tensor(ensemble_size, output_size))
        assert isinstance(ensemble_group,
                          int), ("ensemble_group has to be an integer!")
        self._r.ensemble_group = ensemble_group
        self._s.ensemble_group = ensemble_group
        self._ensemble_bias.ensemble_group = ensemble_group
        self._use_ensemble_bias = use_bias
        self._ensemble_size = ensemble_size
        self._output_ensemble_ids = output_ensemble_ids
        self._bias_init_range = bias_init_range
        # The shared FC kernel has no bias; the per-member bias replaces it.
        super().__init__(
            input_size,
            output_size,
            activation=activation,
            use_bias=False,
            use_bn=use_bn,
            use_ln=use_ln,
            kernel_initializer=kernel_initializer,
            kernel_init_gain=kernel_init_gain)

    def reset_parameters(self):
        """Reinitialize parameters."""
        super().reset_parameters()
        # Both r and s are initialized to +1/-1 according to Appendix B
        torch.randint(
            2, size=self._r.shape, dtype=torch.float32, out=self._r.data)
        torch.randint(
            2, size=self._s.shape, dtype=torch.float32, out=self._s.data)
        self._r.data.mul_(2)
        self._r.data.sub_(1)
        self._s.data.mul_(2)
        self._s.data.sub_(1)
        if self._use_ensemble_bias:
            nn.init.uniform_(
                self._ensemble_bias.data,
                a=-self._bias_init_range,
                b=self._bias_init_range)

    def forward(self, inputs):
        """Forward computation.

        Args:
            inputs (Tensor|tuple): if a Tensor, its shape should be
                ``[batch_size, input_size]`` or ``[batch_size, ..., input_size]``.
                And a random ensemble id will be generated for each sample in
                the batch. If a tuple, it should contain two tensors. The first
                one is the data tensor with shape ``[batch_size, input_size]``
                or ``[batch_size, ..., input_size]``. The second one is
                ensemble_ids indicating which ensemble member each sample
                should use. Its shape should be [batch_size], and all elements
                should be in [0, ensemble_size).
        Returns:
            tuple if ``output_ensemble_ids`` is True,

            - Tensor: with shape as ``inputs.shape[:-1] + (output_size,)``
            - LongTensor: if ensemble_ids is provided, this is same as
              ``ensemble_ids``, otherwise a randomly generated ensemble_ids
              is returned

            Tensor if ``output_ensemble_ids`` is False. The result of FC.
        """
        if isinstance(inputs, tuple):
            inputs, ensemble_ids = inputs
        else:
            # Sample the ids on the same device as the data they index.
            ensemble_ids = torch.randint(
                self._ensemble_size,
                size=(inputs.shape[0], ),
                device=inputs.device)
        batch_size = inputs.shape[0]
        output_size, input_size = self._weight.shape
        r = self._r[ensemble_ids]  # [batch_size, input_size]
        s = self._s[ensemble_ids]  # [batch_size, output_size]
        if inputs.ndim > 2:
            # broadcast the per-sample vectors over the middle dims
            ones = [1] * (inputs.ndim - 2)
            r = r.reshape(batch_size, *ones, input_size)
            s = s.reshape(batch_size, *ones, output_size)
        y = (inputs * r).matmul(self._weight.t())
        y = y * s
        if self._use_ensemble_bias:
            bias = self._ensemble_bias[ensemble_ids]
            if inputs.ndim > 2:
                bias = bias.reshape(batch_size, *ones, output_size)
            y += bias
        if self._use_ln:
            if not self._use_ensemble_bias:
                self._ln.bias.data.zero_()
            y = self._ln(y)
        if self._use_bn:
            if not self._use_ensemble_bias:
                self._bn.bias.data.zero_()
            y = self._bn(y)
        y = self._activation(y)
        if self._output_ensemble_ids:
            return y, ensemble_ids
        else:
            return y
@alf.configurable
@alf.repr_wrapper
class ParallelFC(nn.Module):
    """Parallel FC layer."""

    def __init__(self,
                 input_size,
                 output_size,
                 n,
                 activation=identity,
                 use_bias=True,
                 use_bn=False,
                 use_ln=False,
                 bn_ctor=nn.BatchNorm1d,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_value=0.,
                 bias_initializer=None,
                 weight_opt_args: Optional[Dict] = None,
                 bias_opt_args: Optional[Dict] = None):
        """
        It is equivalent to ``n`` separate FC layers with the same
        ``input_size`` and ``output_size``.

        Args:
            input_size (int): input size
            output_size (int): output size
            n (int): n independent ``FC`` layers
            activation (Callable): activation applied to the output; it is
                also passed to the default kernel initializer.
            use_bias (bool): whether use bias
            use_bn (bool): whether use Batch Normalization.
            use_ln (bool): whether use layer normalization. It is implemented
                as a ``GroupNorm`` with ``n`` groups so that each replica is
                normalized separately.
            bn_ctor (Callable): will be called as ``bn_ctor(num_features)`` to
                create the BN layer.
            kernel_initializer (Callable): initializer for the FC layer kernel.
                If none is provided a ``variance_scaling_initializer`` with gain
                as ``kernel_init_gain`` will be used.
            kernel_init_gain (float): a scaling factor (gain) applied to
                the std of kernel init distribution. It will be ignored if
                ``kernel_initializer`` is not None.
            bias_init_value (float): a constant for the initial bias value.
                This is ignored if ``bias_initializer`` is provided.
            bias_initializer (Callable): initializer for the bias parameter.
            weight_opt_args: optimizer arguments for weight
            bias_opt_args: optimizer arguments for bias
        """
        super().__init__()
        self._input_size = input_size
        self._output_size = output_size
        self._activation = activation
        self._weight = nn.Parameter(torch.Tensor(n, output_size, input_size))
        if use_bias:
            self._bias = nn.Parameter(torch.Tensor(n, output_size))
        else:
            self._bias = None
        self._n = n
        self._kernel_initializer = kernel_initializer
        self._kernel_init_gain = kernel_init_gain
        self._bias_init_value = bias_init_value
        self._bias_initializer = bias_initializer
        self._use_bias = use_bias
        self._use_bn = use_bn
        self._use_ln = use_ln
        if use_bn:
            self._bn = bn_ctor(n * output_size)
        else:
            self._bn = None
        if use_ln:
            # one normalization group per replica
            self._ln = nn.GroupNorm(n, n * output_size)
        else:
            self._ln = None
        self.reset_parameters()
        if weight_opt_args:
            self._weight.opt_args = weight_opt_args
        if bias_opt_args and self._bias is not None:
            self._bias.opt_args = bias_opt_args

    def reset_parameters(self):
        """Initialize each replica's kernel and bias, and the norm layers."""
        for i in range(self._n):
            if self._kernel_initializer is None:
                variance_scaling_init(
                    self._weight.data[i],
                    gain=self._kernel_init_gain,
                    nonlinearity=self._activation)
            else:
                self._kernel_initializer(self._weight.data[i])
        if self._use_bias:
            if self._bias_initializer is not None:
                for i in range(self._n):
                    self._bias_initializer(self._bias.data[i])
            else:
                nn.init.constant_(self._bias.data, self._bias_init_value)
        if self._use_ln:
            self._ln.reset_parameters()
        if self._use_bn:
            self._bn.reset_parameters()

    def forward(self, inputs):
        """Forward

        Args:
            inputs (torch.Tensor): with shape ``[B, n, input_size]`` or
                ``[B, input_size]``. A 2-D input is shared by all replicas.
        Returns:
            torch.Tensor with shape ``[B, n, output_size]``
        """
        n, k, l = self._weight.shape
        if inputs.ndim == 2:
            assert inputs.shape[1] == l, (
                "inputs has wrong shape %s. Expecting (B, %d)" % (inputs.shape,
                                                                 l))
            inputs = inputs.unsqueeze(0).expand(n, *inputs.shape)
        elif inputs.ndim == 3:
            assert (inputs.shape[1] == n and inputs.shape[2] == l), (
                "inputs has wrong shape %s. Expecting (B, %d, %d)" %
                (inputs.shape, n, l))
            inputs = inputs.transpose(0, 1)  # [n, B, l]
        else:
            raise ValueError("Wrong inputs.ndim=%d" % inputs.ndim)

        if self._bias is not None:
            y = torch.baddbmm(
                self._bias.unsqueeze(1), inputs,
                self._weight.transpose(1, 2))  # [n, B, k]
        else:
            y = torch.bmm(inputs, self._weight.transpose(1, 2))  # [n, B, k]
        y = y.transpose(0, 1)  # [B, n, k]

        if self._ln is not None:
            if self._bias is None:
                self._ln.bias.data.zero_()
            y1 = y.reshape(-1, n * k)
            # Keep the normalized result (it used to be discarded, leaving
            # ``use_ln=True`` without any effect on the output).
            y1 = self._ln(y1)
            y = y1.view(-1, n, k)
        if self._bn is not None:
            if self._bias is None:
                self._bn.bias.data.zero_()
            y1 = y.reshape(-1, n * k)
            y1 = self._bn(y1)
            y = y1.view(-1, n, k)
        return self._activation(y)

    @property
    def weight(self):
        """Get the weight Tensor.

        Returns:
            Tensor: with shape (n, output_size, input_size). ``weight[i]`` is
            the weight for the i-th FC layer. ``weight[i]`` can be used for
            ``FC`` layer with the same ``input_size`` and ``output_size``
        """
        return self._weight

    @property
    def bias(self):
        """Get the bias Tensor.

        Returns:
            Tensor: with shape (n, output_size). ``bias[i]`` is the bias for the
            i-th FC layer. ``bias[i]`` can be used for ``FC`` layer with
            the same ``input_size`` and ``output_size``
        """
        return self._bias
@alf.configurable
class CompositionalFC(nn.Module):
    """Compositional FC layer."""

    def __init__(self,
                 input_size,
                 output_size,
                 n,
                 activation=identity,
                 output_comp_weight=True,
                 use_bias=True,
                 use_bn=False,
                 use_ln=False,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_value=0.0):
        """
        It maintains a set of ``n`` FC parameters for learning. During forward
        computation, it composes the set of parameters using weighted average
        with the compositional weight provided as input and then performs the
        FC computation, which is equivalent to combine the pre-activation output
        from each of the ``n`` FC layers using the compositional weight, and
        then apply normalization and activation.

        Args:
            input_size (int): input size
            output_size (int): output size
            n (int): the size of the parameter set
            activation (Callable): activation applied to the output; it is
                also passed to the default kernel initializer.
            output_comp_weight (bool): If True, the forward() function will
                return a tuple of (result, comp_weight) for easy chaining of
                multiple layers in the case when the same compositional weight
                is used. If False, the forward() function will return result
                only.
            use_bias (bool): whether use bias
            use_bn (bool): whether use Batch Normalization.
            use_ln (bool): whether use layer normalization
            kernel_initializer (Callable): initializer for the FC layer kernel.
                If none is provided a ``variance_scaling_initializer`` with gain
                as ``kernel_init_gain`` will be used.
            kernel_init_gain (float): a scaling factor (gain) applied to
                the std of kernel init distribution. It will be ignored if
                ``kernel_initializer`` is not None.
            bias_init_value (float): a constant for the initial bias value.
        """
        super().__init__()
        self._activation = activation
        self._weight = nn.Parameter(torch.Tensor(n, output_size, input_size))
        if use_bias:
            self._bias = nn.Parameter(torch.Tensor(n, output_size))
        else:
            self._bias = None
        self._kernel_initializer = kernel_initializer
        self._kernel_init_gain = kernel_init_gain
        self._bias_init_value = bias_init_value
        self._output_comp_weight = output_comp_weight
        self._use_bias = use_bias
        self._use_bn = use_bn
        self._use_ln = use_ln
        self._n = n
        if use_bn:
            self._bn = nn.BatchNorm1d(output_size)
        else:
            self._bn = None
        if use_ln:
            self._ln = nn.LayerNorm(output_size)
        else:
            self._ln = None
        self.reset_parameters()

    def forward(self, inputs):
        """Forward

        Args:
            inputs (torch.Tensor|tuple): If a Tensor, its shape should be
                ``[B, input_size]``. If a tuple, it should contain two elements.
                The first is a Tensor with the shape of ``[B, input_size]``, the
                second is a compositional weight Tensor with the shape of ``[B, n]``
                or None. If the compositional weight is not specified (i.e. when
                inputs is not a tuple) or None, a uniform weight of one will be
                used (i.e. the ``n`` outputs are summed).
        Returns:
            torch.Tensor representing the final activation with shape
            ``[B, output_size]`` if ``output_comp_weight`` is False.
            Otherwise, return a tuple consisted of the final activation and the
            compositional weight used.
        """
        if isinstance(inputs, tuple):
            inputs, comp_weight = inputs
        else:
            comp_weight = None

        n, k, l = self._weight.shape
        if inputs.ndim == 2:
            assert inputs.shape[1] == l, (
                "inputs has wrong shape %s. Expecting (B, %d)" % (inputs.shape,
                                                                 l))
            # share the same input across all n parameter sets
            inputs = inputs.unsqueeze(0).expand(n, *inputs.shape)
        else:
            raise ValueError("Wrong inputs.ndim=%d" % inputs.ndim)

        if self._bias is not None:
            y = torch.baddbmm(
                self._bias.unsqueeze(1), inputs,
                self._weight.transpose(1, 2))  # [n, B, k]
        else:
            y = torch.bmm(inputs, self._weight.transpose(1, 2))  # [n, B, k]
        y = y.transpose(0, 1)  # [B, n, k]

        if comp_weight is not None:
            assert comp_weight.ndim == 2, (
                "Wrong comp_weight.ndim=%d" % comp_weight.ndim)
            # [B, 1, n] x [B, n, k] -> [B, 1, k] -> [B, k]
            y = torch.bmm(comp_weight.unsqueeze(1), y).squeeze(1)
        else:
            y = y.sum(dim=1)

        if self._use_ln:
            if not self._use_bias:
                self._ln.bias.data.zero_()
            y = self._ln(y)
        if self._use_bn:
            if not self._use_bias:
                self._bn.bias.data.zero_()
            y = self._bn(y)
        y = self._activation(y)
        if self._output_comp_weight:
            return (y, comp_weight)
        else:
            return y

    def reset_parameters(self):
        """Initialize the parameters."""
        for i in range(self._n):
            if self._kernel_initializer is None:
                variance_scaling_init(
                    self._weight.data[i],
                    gain=self._kernel_init_gain,
                    nonlinearity=self._activation)
            else:
                self._kernel_initializer(self._weight.data[i])
        if self._use_bias:
            nn.init.constant_(self._bias.data, self._bias_init_value)
        if self._use_ln:
            self._ln.reset_parameters()
        if self._use_bn:
            self._bn.reset_parameters()

    @property
    def weight(self):
        """Get the weight Tensor.

        Returns:
            Tensor: with shape (n, output_size, input_size). ``weight[i]`` is
            the weight for the i-th FC layer. ``weight[i]`` can be used for
            ``FC`` layer with the same ``input_size`` and ``output_size``
        """
        return self._weight

    @property
    def bias(self):
        """Get the bias Tensor.

        Returns:
            Tensor: with shape (n, output_size). ``bias[i]`` is the bias for the
            i-th FC layer. ``bias[i]`` can be used for ``FC`` layer with
            the same ``input_size`` and ``output_size``
        """
        return self._bias
[docs]@alf.configurable
class CausalConv1D(nn.Module):
"""1D (Dilated) Causal Convolution layer.
1D Dilated Causal Convolution is proposed in `Aaron et al. WaveNet:
A generative model for raw audio <https://arxiv.org/abs/1609.03499>`_
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
dilation=1,
hide_current=False,
activation=torch.relu_,
use_bias=None,
use_bn=False,
kernel_initializer=None,
kernel_init_gain=1.0,
bias_init_value=0.0):
"""A layer implementing the 1D (Dilated) Causal Convolution.
It is also responsible for activation and customized weights
initialization. An auto gain calculation might depend on the activation
following the causal conv1d layer.
Note that the main difference of causal conv v.s. standard conv is that
each temporal element in the convolutional output is causal w.r.t.
the temporal elements from input. For example, for a length ``L``
sequence ``x`` with the shape of ``[B, C, L]``, and
``y = causal_conv(x)``, where the shape of ``y`` is
``[B, C', L]``, by causal we mean ``y[..., l]`` only depends on
``X[..., :l]`` (i.e. the past), and there is no dependency on
``X[..., l:]`` (i.e. future) as in the standard non-causal
convolution.
This can implemented by using an asymmetric padding, which in effect
shift the input to the right (future) according to kernel size.
Args:
in_channels (int): channels of the input
out_channels (int): channels of the output
kernel_size (int): size of the kernel
dilation (int): controls the spacing between the kernel points.
Please refer to here for a visual illustration:
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
hide_current (bool): whether to hide the current by shifting the
input to the right (future) by one. This is typically needed
in the first layer of a causal conv net.
activation (torch.nn.functional): activation to be applied to output
use_bias (bool|None): whether use bias. If None, will use ``not use_bn``
use_bn (bool): whether use batch normalization
kernel_initializer (Callable): initializer for the conv layer kernel.
If None is provided a variance_scaling_initializer with gain as
``kernel_init_gain`` will be used.
kernel_init_gain (float): a scaling factor (gain) applied to the
std of kernel init distribution. It will be ignored if
``kernel_initializer`` is not None.
bias_init_value (float): a constant
"""
super(CausalConv1D, self).__init__()
if use_bias is None:
use_bias = not use_bn
self._activation = activation
# use F.pad for asymmetric padding
if hide_current:
assert dilation == 1, "the dilation should be 1 for hiding current"
asymmetric_padding = (kernel_size, -1)
else:
asymmetric_padding = ((kernel_size - 1) * dilation, 0)
self._pad = partial(
F.pad, pad=asymmetric_padding, mode='constant', value=0)
self._causal_conv1d = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=dilation,
bias=use_bias)
self._kernel_initializer = kernel_initializer
self._kernel_init_gain = kernel_init_gain
self._bias_init_value = bias_init_value
self._use_bias = use_bias
if use_bn:
self._bn = nn.BatchNorm1d(out_channels)
else:
self._bn = None
self.reset_parameters()
[docs] def reset_parameters(self):
"""Initialize the parameters."""
if self._kernel_initializer is None:
variance_scaling_init(
self._causal_conv1d.weight.data,
gain=self._kernel_init_gain,
nonlinearity=self._activation)
else:
self._kernel_initializer(self._causal_conv1d.weight.data)
if self._use_bias:
nn.init.constant_(self._causal_conv1d.bias.data,
self._bias_init_value)
if self._bn is not None:
self._bn.reset_parameters()
[docs] def forward(self, x):
"""
Args:
x (tensor): input of the shape [B, C, L] where B is the batch size,
C denotes the number of input channels, and L is the length of
the signal.
Returns:
A tensor of the shape [B, C', L], where C' denotes the number of
output channels.
"""
y = self._causal_conv1d(self._pad(x))
if self._bn is not None:
y = self._bn(y)
return self._activation(y)
@property
def weight(self):
    """The kernel of the wrapped ``nn.Conv1d``."""
    conv = self._causal_conv1d
    return conv.weight
@property
def bias(self):
    """The bias of the wrapped ``nn.Conv1d``."""
    conv = self._causal_conv1d
    return conv.bias
@alf.configurable
class Conv2D(nn.Module):
    """2D convolution layer with built-in normalization and activation.

    Wraps ``nn.Conv2d`` and adds optional batch/layer normalization, a
    customizable kernel initializer and the output activation.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 activation=torch.relu_,
                 strides=1,
                 padding=0,
                 use_bias=None,
                 use_bn=False,
                 use_ln=False,
                 weight_opt_args: Optional[Dict] = None,
                 bn_ctor=nn.BatchNorm2d,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_value=0.0):
        """A 2D Conv layer that is also responsible for the activation and a
        customized weight initialization. Because the automatic gain
        calculation can depend on the activation following the conv, prefer
        this wrapper over ``nn.Conv2d`` whenever the weight std after init
        matters.

        Args:
            in_channels (int): channels of the input image
            out_channels (int): channels of the output image
            kernel_size (int or tuple): size of the conv kernel
            activation (Callable): applied to the (normalized) conv output
            strides (int or tuple): passed as ``stride`` to ``nn.Conv2d``
            padding (int or tuple): passed as ``padding`` to ``nn.Conv2d``
            use_bias (bool|None): whether to use a bias. ``None`` means
                ``not use_bn``.
            use_bn (bool): whether to use batch normalization
            use_ln (bool): whether to use layer normalization (implemented
                as ``nn.GroupNorm`` with a single group)
            weight_opt_args: optimizer arguments for the weight (not for the
                bias); attached to the kernel as ``opt_args``
            bn_ctor (Callable): called as ``bn_ctor(num_features)`` to
                create the BN layer.
            kernel_initializer (Callable): initializer for the conv kernel.
                If None, a variance-scaling initializer with gain
                ``kernel_init_gain`` is used.
            kernel_init_gain (float): scaling factor (gain) applied to the
                std of the kernel init distribution. Ignored when
                ``kernel_initializer`` is not None.
            bias_init_value (float): constant the bias is initialized to
        """
        # Snapshot the constructor arguments so ``make_parallel`` can rebuild
        # an equivalent parallel layer. ``__class__`` is present because of
        # the ``super()`` call below.
        init_args = copy.deepcopy(locals())
        for key in ('self', '__class__'):
            del init_args[key]
        self._kwargs = init_args

        super().__init__()
        if use_bias is None:
            use_bias = not use_bn
        self._activation = activation
        self._conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=strides,
            padding=padding,
            bias=use_bias)
        self._kernel_initializer = kernel_initializer
        self._kernel_init_gain = kernel_init_gain
        self._bias_init_value = bias_init_value
        self._use_bias = use_bias
        self._bn = bn_ctor(out_channels) if use_bn else None
        self._ln = nn.GroupNorm(1, out_channels) if use_ln else None
        if weight_opt_args is not None:
            self._conv2d.weight.opt_args = weight_opt_args
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the kernel, bias and normalization layers."""
        kernel = self._conv2d.weight.data
        if self._kernel_initializer is not None:
            self._kernel_initializer(kernel)
        else:
            variance_scaling_init(
                kernel,
                gain=self._kernel_init_gain,
                nonlinearity=self._activation)
        if self._use_bias:
            nn.init.constant_(self._conv2d.bias.data, self._bias_init_value)
        for norm in (self._bn, self._ln):
            if norm is not None:
                norm.reset_parameters()

    def forward(self, img):
        """Conv, then LN (if any), then BN (if any), then the activation."""
        out = self._conv2d(img)
        for norm in (self._ln, self._bn):
            if norm is not None:
                out = norm(out)
        return self._activation(out)

    @property
    def weight(self):
        """The conv kernel ``Parameter``."""
        conv = self._conv2d
        return conv.weight

    @property
    def bias(self):
        """The conv bias ``Parameter`` (as held by ``nn.Conv2d``)."""
        conv = self._conv2d
        return conv.bias

    def make_parallel(self, n: int):
        """Build a ``ParallelConv2D`` of ``n`` replicas with this layer's
        constructor arguments."""
        return ParallelConv2D(n=n, **self._kwargs)
@alf.configurable
class Conv2DBatchEnsemble(Conv2D):
    r"""The BatchEnsemble for 2D Conv layer.

    BatchEnsemble is proposed in `Wen et al. BatchEnsemble: An Alternative Approach
    to Efficient Ensemble and Lifelong Learning <https://arxiv.org/abs/2002.06715>`_

    In a nutshell, a tuple of vector :math:`(r_k, s_k)` is maintained for ensemble
    member k in addition to the conv2d kernel W of shape ``[C_out, C_in, K_h, K_w]``.
    For input x of shape ``[B, C, H, W]``, the result for ensemble member k is
    calculated as :math:`(W \circ (s_k r_k^T).unsqueeze(-1).unsqueeze(-1)) * x`.
    This can be more efficiently calculated as
    :math:`(W*(x \circ r_k.unsqueeze(-1).unsqueeze(-1))) \circ s_k.unsqueeze(-1).unsqueeze(-1)`

    Note that for each sample in a batch, a random ensemble member will be used
    for it if ``ensemble_ids`` is not provided to ``forward()``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 ensemble_size,
                 output_ensemble_ids=True,
                 activation=torch.relu_,
                 strides=1,
                 padding=0,
                 use_bias=None,
                 use_bn=False,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_range=0.0,
                 ensemble_group=0):
        """
        Args:
            in_channels (int): channels of the input image
            out_channels (int): channels of the output image
            kernel_size (int or tuple): size of the conv kernel
            ensemble_size (int): ensemble size
            output_ensemble_ids (bool): If True, the forward() function will return
                a tuple of (result, ensemble_ids). If False, the forward() function
                will return result only.
            activation (Callable): applied to the output
            strides (int or tuple): stride of the shared conv
            padding (int or tuple): padding of the shared conv
            use_bias (bool|None): whether to use the per-member ensemble bias.
                If None, will use ``not use_bn``.
            use_bn (bool): whether use batch normalization
            kernel_initializer (Callable): initializer for the conv layer kernel.
                If None is provided a variance_scaling_initializer with gain as
                ``kernel_init_gain`` will be used.
            kernel_init_gain (float): a scaling factor (gain) applied to the
                std of kernel init distribution. It will be ignored if
                ``kernel_initializer`` is not None.
            bias_init_range (float): biases are initialized uniformly in
                [-bias_init_range, bias_init_range]
            ensemble_group (int): the extra attribute ``ensemble_group`` added
                to ``self._r``, ``self._s``, and ``self._ensemble_bias``,
                default value is 0.
                For alf.optimizers whose ``parvi`` is not ``None``, all parameters
                with the same ``ensemble_group`` will be updated by the
                particle-based VI algorithm specified by ``parvi``, options are
                [``svgd``, ``gfsf``],

                * Stein Variational Gradient Descent (SVGD)

                  Liu, Qiang, and Dilin Wang. "Stein Variational Gradient Descent:
                  A General Purpose Bayesian Inference Algorithm." NIPS. 2016.

                * Wasserstein Gradient Flow with Smoothed Functions (GFSF)

                  Liu, Chang, et al. "Understanding and accelerating particle-based
                  variational inference." ICML, 2019.
        """
        assert isinstance(ensemble_group,
                          int), ("ensemble_group has to be an integer!")
        if use_bias is None:
            use_bias = not use_bn
        # ``Conv2D.__init__`` runs ``nn.Module.__init__``, which resets the
        # registered parameters, so the ensemble parameters are created only
        # after it returns. The shared conv has no bias of its own: the
        # per-member ensemble bias takes its place.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            activation=activation,
            strides=strides,
            padding=padding,
            use_bias=False,
            use_bn=use_bn,
            kernel_initializer=kernel_initializer,
            kernel_init_gain=kernel_init_gain)
        self._use_ensemble_bias = use_bias
        self._ensemble_size = ensemble_size
        self._output_ensemble_ids = output_ensemble_ids
        self._bias_init_range = bias_init_range
        self._r = nn.Parameter(torch.empty(ensemble_size, in_channels))
        self._s = nn.Parameter(torch.empty(ensemble_size, out_channels))
        # Zero-filled so an unused ensemble bias never holds garbage memory.
        self._ensemble_bias = nn.Parameter(
            torch.zeros(ensemble_size, out_channels))
        for param in (self._r, self._s, self._ensemble_bias):
            param.ensemble_group = ensemble_group
        self._reset_ensemble_parameters()

    def reset_parameters(self):
        """Reinitialize the parameters."""
        super().reset_parameters()
        # ``Conv2D.__init__`` calls this before the ensemble parameters exist;
        # they are initialized separately once created.
        if '_r' in self._parameters:
            self._reset_ensemble_parameters()

    def _reset_ensemble_parameters(self):
        """Initialize ``r``/``s`` to random +1/-1 and the ensemble bias."""
        # Both r and s are initialized to +1/-1 according to Appendix B
        for param in (self._r, self._s):
            torch.randint(
                2, size=param.shape, dtype=torch.float32, out=param.data)
            param.data.mul_(2).sub_(1)
        if self._use_ensemble_bias:
            nn.init.uniform_(
                self._ensemble_bias.data,
                a=-self._bias_init_range,
                b=self._bias_init_range)

    def forward(self, inputs):
        """Forward computation.

        Args:
            inputs (Tensor|tuple): if a Tensor, its shape should be ``[B, C, H, W]``.
                And a random ensemble id will be generated for each sample in the batch.
                If a tuple, it should contain two tensors. The first one is the data
                tensor with shape ``[B, C, H, W]``. The second one is ensemble_ids
                indicating which ensemble member each sample should use. Its shape
                should be [batch_size], and all elements should be in [0, ensemble_size).
        Returns:
            tuple if ``output_ensemble_ids`` is True,
            - Tensor: with shape ``[B, C_out, H_out, W_out]``
            - LongTensor: if ensemble_ids is provided, this is same as ``ensemble_ids``,
              otherwise a randomly generated ensemble_ids is returned
            Tensor if ``output_ensemble_ids`` is False. The result of Conv2D.
        """
        if isinstance(inputs, tuple):
            inputs, ensemble_ids = inputs
        else:
            ensemble_ids = torch.randint(
                self._ensemble_size, size=(inputs.shape[0], ))
        # Per-sample vectors broadcast over the spatial dims.
        r = self._r[ensemble_ids][..., None, None]  # [B, in_channels, 1, 1]
        s = self._s[ensemble_ids][..., None, None]  # [B, out_channels, 1, 1]
        y = self._conv2d(inputs * r) * s
        if self._use_ensemble_bias:
            y = y + self._ensemble_bias[ensemble_ids][..., None, None]
        if self._bn is not None:
            y = self._bn(y)
        y = self._activation(y)
        if self._output_ensemble_ids:
            return y, ensemble_ids
        return y
@alf.configurable
class ParallelConv2D(nn.Module):
    """``n`` independent 2D convolutions computed as one grouped convolution.

    The ``n`` replicas are packed into a single ``nn.Conv2d`` with
    ``groups=n``, ``n * in_channels`` input channels and ``n * out_channels``
    output channels.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 n,
                 activation=torch.relu_,
                 strides=1,
                 padding=0,
                 use_bias=None,
                 use_bn=False,
                 use_ln=False,
                 weight_opt_args: Optional[Dict] = None,
                 bn_ctor=nn.BatchNorm2d,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_value=0.0):
        """A parallel 2D Conv layer that can be used to perform n independent
        2D convolutions in parallel.

        It is equivalent to ``n`` separate ``Conv2D`` layers with the same
        ``in_channels`` and ``out_channels``.

        Args:
            in_channels (int): channels of the input image
            out_channels (int): channels of the output image
            kernel_size (int or tuple): size of the conv kernel
            n (int): n independent ``Conv2D`` layers
            activation (Callable): applied to the output
            strides (int or tuple): stride of every replica
            padding (int or tuple): padding of every replica
            use_bias (bool|None): whether use bias. If None, will use ``not use_bn``
            use_bn (bool): whether use batch normalization
            use_ln (bool): whether use layer normalization (one group per
                replica)
            weight_opt_args: optimizer arguments for weight (not for bias)
            bn_ctor (Callable): will be called as ``bn_ctor(num_features)`` to
                create the BN layer.
            kernel_initializer (Callable): initializer for the conv layer kernel.
                If None is provided a ``variance_scaling_initializer`` with gain
                as ``kernel_init_gain`` will be used.
            kernel_init_gain (float): a scaling factor (gain) applied to the
                std of kernel init distribution. It will be ignored if
                ``kernel_initializer`` is not None.
            bias_init_value (float): a constant
        """
        super(ParallelConv2D, self).__init__()
        if use_bias is None:
            use_bias = not use_bn
        self._activation = activation
        self._n = n
        self._use_bias = use_bias
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._kernel_initializer = kernel_initializer
        self._kernel_init_gain = kernel_init_gain
        self._bias_init_value = bias_init_value
        self._kernel_size = common.tuplify2d(kernel_size)
        self._conv2d = nn.Conv2d(
            in_channels * n,
            out_channels * n,
            kernel_size,
            groups=n,
            stride=strides,
            padding=padding,
            bias=use_bias)

        if use_bn:
            self._bn = bn_ctor(n * out_channels)
        else:
            self._bn = None
        if use_ln:
            self._ln = nn.GroupNorm(n, n * out_channels)
        else:
            self._ln = None
        if weight_opt_args is not None:
            self._conv2d.weight.opt_args = weight_opt_args
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the parameters of every replica."""
        c_out = self._out_channels
        for i in range(self._n):
            # Replica ``i`` owns output channels [i * C', (i + 1) * C') of the
            # grouped kernel, so each slice is initialized on its own.
            replica_weight = self._conv2d.weight.data[i * c_out:(i + 1) *
                                                      c_out]
            if self._kernel_initializer is None:
                variance_scaling_init(
                    replica_weight,
                    gain=self._kernel_init_gain,
                    nonlinearity=self._activation)
            else:
                self._kernel_initializer(replica_weight)
        if self._use_bias:
            nn.init.constant_(self._conv2d.bias.data, self._bias_init_value)
        if self._bn is not None:
            self._bn.reset_parameters()
        if self._ln is not None:
            self._ln.reset_parameters()

    def forward(self, img):
        """Forward

        Args:
            img (torch.Tensor): with shape ``[B, C, H, W]``
                or ``[B, n, C, H, W]``
                where the meaning of the symbols are:

                - ``B``: batch size
                - ``n``: number of replicas
                - ``C``: number of channels
                - ``H``: image height
                - ``W``: image width.

                When the shape of img is ``[B, C, H, W]``, all the n 2D Conv
                operations will take img as the same shared input.
                When the shape of img is ``[B, n, C, H, W]``, each 2D Conv operator
                will have its own input data by slicing img.

        Returns:
            torch.Tensor with shape ``[B, n, C', H', W']``
                where the meaning of the symbols are:

                - ``B``: batch
                - ``n``: number of replicas
                - ``C'``: number of output channels
                - ``H'``: output height
                - ``W'``: output width
        """
        if img.ndim == 4:
            # the shared input case
            assert img.shape[1] == self._in_channels, (
                "Input img has wrong shape %s. Expecting (B, %d, H, W)" %
                (img.shape, self._in_channels))
            img = img.unsqueeze(1).expand(img.shape[0], self._n,
                                          *img.shape[1:])
        elif img.ndim == 5:
            # the non-shared case
            assert (
                img.shape[1] == self._n
                and img.shape[2] == self._in_channels), (
                    "Input img has wrong shape %s. Expecting (B, %d, %d, H, W)"
                    % (img.shape, self._n, self._in_channels))
        else:
            raise ValueError("Wrong img.ndim=%d" % img.ndim)

        # merge replica and channels
        img = img.reshape(img.shape[0], img.shape[1] * img.shape[2],
                          *img.shape[3:])

        res = self._conv2d(img)
        if self._ln is not None:
            res = self._ln(res)
        if self._bn is not None:
            res = self._bn(res)

        # reshape back: [B, n*C', H', W'] -> [B, n, C', H', W']
        res = res.reshape(res.shape[0], self._n, self._out_channels,
                          *res.shape[2:])
        return self._activation(res)

    @property
    def weight(self):
        # The weight view cannot be pre-computed at __init__ because deepcopy
        # would fail. deepcopy is needed to implement the copy for the
        # container networks.
        # [n*C', C, kernel_size, kernel_size]->[n, C', C, kernel_size, kernel_size]
        return self._conv2d.weight.view(
            self._n, self._out_channels, self._in_channels,
            self._kernel_size[0], self._kernel_size[1])

    @property
    def bias(self):
        if self._use_bias:
            # The bias view cannot be pre-computed at __init__ because deepcopy
            # would fail. deepcopy is needed to implement the copy for the
            # container networks.
            # [n*C']->[n, C']
            return self._conv2d.bias.view(self._n, self._out_channels)
        else:
            return None
@alf.configurable
class ConvTranspose2D(nn.Module):
    """2D transposed convolution layer with optional BN and an activation."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 activation=torch.relu_,
                 strides=1,
                 padding=0,
                 output_padding=0,
                 use_bias=None,
                 use_bn=False,
                 bn_ctor=nn.BatchNorm2d,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_value=0.0):
        """A 2D ConvTranspose layer that's also responsible for activation and
        customized weights initialization. An auto gain calculation might depend
        on the activation following the conv layer. Suggest using this wrapper
        module instead of ``nn.ConvTranspose2d`` if you really care about weight std
        after init.

        Args:
            in_channels (int): channels of the input image
            out_channels (int): channels of the output image
            kernel_size (int or tuple): size of the kernel
            activation (Callable): applied to the output
            strides (int or tuple): passed as ``stride`` to
                ``nn.ConvTranspose2d``
            padding (int or tuple): passed as ``padding`` to
                ``nn.ConvTranspose2d``
            output_padding (int or tuple): Additional size added to one side of
                each dimension in the output shape. Default: 0. See pytorch
                documentation for more detail.
            use_bias (bool|None): If None, will use ``not use_bn``
            use_bn (bool): whether use batch normalization
            bn_ctor (Callable): will be called as ``bn_ctor(num_features)`` to
                create the BN layer.
            kernel_initializer (Callable): initializer for the conv_trans layer.
                If None is provided a variance_scaling_initializer with gain as
                ``kernel_init_gain`` will be used.
            kernel_init_gain (float): a scaling factor (gain) applied to the
                std of kernel init distribution. It will be ignored if
                ``kernel_initializer`` is not None.
            bias_init_value (float): a constant
        """
        # get the argument list with vals
        self._kwargs = copy.deepcopy(locals())
        self._kwargs.pop('self')
        self._kwargs.pop('__class__')

        super(ConvTranspose2D, self).__init__()
        if use_bias is None:
            use_bias = not use_bn
        self._activation = activation
        self._conv_trans2d = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=strides,
            padding=padding,
            output_padding=output_padding,
            bias=use_bias)
        self._kernel_initializer = kernel_initializer
        self._kernel_init_gain = kernel_init_gain
        self._bias_init_value = bias_init_value
        self._use_bias = use_bias
        if use_bn:
            self._bn = bn_ctor(out_channels)
        else:
            self._bn = None
        # Apply the customized initialization; without this the kernel and
        # bias would keep PyTorch's default init.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the parameters."""
        if self._kernel_initializer is None:
            variance_scaling_init(
                self._conv_trans2d.weight.data,
                gain=self._kernel_init_gain,
                nonlinearity=self._activation,
                transposed=True)
        else:
            self._kernel_initializer(self._conv_trans2d.weight.data)
        if self._use_bias:
            nn.init.constant_(self._conv_trans2d.bias.data,
                              self._bias_init_value)
        if self._bn is not None:
            self._bn.reset_parameters()

    def forward(self, img):
        """Transposed conv, then BN (if any), then the activation."""
        y = self._conv_trans2d(img)
        if self._bn is not None:
            y = self._bn(y)
        return self._activation(y)

    @property
    def weight(self):
        return self._conv_trans2d.weight

    @property
    def bias(self):
        return self._conv_trans2d.bias

    def make_parallel(self, n: int):
        """Build a ``ParallelConvTranspose2D`` of ``n`` replicas with this
        layer's constructor arguments."""
        return ParallelConvTranspose2D(n=n, **self._kwargs)
@alf.configurable
class ParallelConvTranspose2D(nn.Module):
    """``n`` independent transposed 2D convolutions computed as one grouped
    ``nn.ConvTranspose2d`` with ``groups=n``."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 n,
                 activation=torch.relu_,
                 strides=1,
                 padding=0,
                 output_padding=0,
                 use_bias=None,
                 use_bn=False,
                 bn_ctor=nn.BatchNorm2d,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_value=0.0):
        """A parallel ConvTranspose2D layer that can be used to perform n
        independent 2D transposed convolutions in parallel.

        Args:
            in_channels (int): channels of the input image
            out_channels (int): channels of the output image
            kernel_size (int or tuple): size of the kernel
            n (int): n independent ``ConvTranspose2D`` layers
            activation (Callable): applied to the output
            strides (int or tuple): stride of every replica
            padding (int or tuple): padding of every replica
            output_padding (int or tuple): Additional size added to one side of
                each dimension in the output shape. Default: 0. See pytorch
                documentation for more detail.
            use_bias (bool|None): If None, will use ``not use_bn``
            use_bn (bool): whether use batch normalization
            bn_ctor (Callable): will be called as ``bn_ctor(num_features)`` to
                create the BN layer.
            kernel_initializer (Callable): initializer for the conv_trans layer.
                If None is provided a ``variance_scaling_initializer`` with gain
                as ``kernel_init_gain`` will be used.
            kernel_init_gain (float): a scaling factor (gain) applied to the
                std of kernel init distribution. It will be ignored if
                ``kernel_initializer`` is not None.
            bias_init_value (float): a constant
        """
        super(ParallelConvTranspose2D, self).__init__()
        if use_bias is None:
            use_bias = not use_bn
        self._activation = activation
        self._n = n
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._kernel_size = common.tuplify2d(kernel_size)
        self._use_bias = use_bias
        self._kernel_initializer = kernel_initializer
        self._kernel_init_gain = kernel_init_gain
        self._bias_init_value = bias_init_value
        self._conv_trans2d = nn.ConvTranspose2d(
            in_channels * n,
            out_channels * n,
            kernel_size,
            groups=n,
            stride=strides,
            padding=padding,
            output_padding=output_padding,
            bias=use_bias)
        if use_bn:
            self._bn = bn_ctor(n * out_channels)
        else:
            self._bn = None
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the parameters of every replica."""
        c_in = self._in_channels
        for i in range(self._n):
            # The kernel is laid out as [n*C, C', kH, kW]; replica ``i`` owns
            # rows [i * C, (i + 1) * C). Like ``ConvTranspose2D``, the fans
            # are computed for a transposed kernel.
            replica_weight = self._conv_trans2d.weight.data[i * c_in:(i + 1) *
                                                            c_in]
            if self._kernel_initializer is None:
                variance_scaling_init(
                    replica_weight,
                    gain=self._kernel_init_gain,
                    nonlinearity=self._activation,
                    transposed=True)
            else:
                self._kernel_initializer(replica_weight)
        if self._use_bias:
            nn.init.constant_(self._conv_trans2d.bias.data,
                              self._bias_init_value)
        if self._bn is not None:
            self._bn.reset_parameters()

    def forward(self, img):
        """Forward

        Args:
            img (torch.Tensor): with shape ``[B, C, H, W]``
                or ``[B, n, C, H, W]``
                where the meaning of the symbols are:

                - ``B``: batch size
                - ``n``: number of replicas
                - ``C``: number of channels
                - ``H``: image height
                - ``W``: image width.

                When the shape of img is ``[B, C, H, W]``, all the n transposed 2D
                Conv operations will take img as the same shared input.
                When the shape of img is ``[B, n, C, H, W]``, each transposed 2D
                Conv operator will have its own input data by slicing img.

        Returns:
            torch.Tensor with shape ``[B, n, C', H', W']``
                where the meaning of the symbols are:

                - ``B``: batch
                - ``n``: number of replicas
                - ``C'``: number of output channels
                - ``H'``: output height
                - ``W'``: output width
        """
        if img.ndim == 4:
            # the shared input case
            assert img.shape[1] == self._in_channels, (
                "Input img has wrong shape %s. Expecting (B, %d, H, W)" %
                (img.shape, self._in_channels))
            img = img.unsqueeze(1).expand(img.shape[0], self._n,
                                          *img.shape[1:])
        elif img.ndim == 5:
            # the non-shared case
            assert (
                img.shape[1] == self._n
                and img.shape[2] == self._in_channels), (
                    "Input img has wrong shape %s. Expecting (B, %d, %d, H, W)"
                    % (img.shape, self._n, self._in_channels))
        else:
            raise ValueError("Wrong img.ndim=%d" % img.ndim)

        # merge replica and channels
        img = img.reshape(img.shape[0], img.shape[1] * img.shape[2],
                          *img.shape[3:])

        res = self._conv_trans2d(img)
        if self._bn is not None:
            res = self._bn(res)
        # reshape back: [B, n*C', H', W'] -> [B, n, C', H', W']
        res = res.reshape(res.shape[0], self._n, self._out_channels,
                          res.shape[2], res.shape[3])
        return self._activation(res)

    @property
    def weight(self):
        # Computed on access rather than cached at __init__: a cached view is
        # a non-leaf tensor (breaks deepcopy) and goes stale if the parameter
        # is replaced.
        # [n*C, C', kernel_size, kernel_size]->[n, C, C', kernel_size, kernel_size]
        return self._conv_trans2d.weight.view(
            self._n, self._in_channels, self._out_channels,
            self._kernel_size[0], self._kernel_size[1])

    @property
    def bias(self):
        if self._use_bias:
            # [n*C']->[n, C']
            return self._conv_trans2d.bias.view(self._n, self._out_channels)
        return None
@alf.configurable
class ParamFC(nn.Module):
    """Fully connected layer whose weight and bias are supplied externally."""

    def __init__(self,
                 input_size,
                 output_size,
                 activation=torch.relu_,
                 use_bias=True,
                 use_ln=False,
                 n_groups=None,
                 kernel_initializer=None,
                 kernel_init_gain=1.0,
                 bias_init_value=0.0):
        """A fully connected layer that does not own its weight and bias but
        receives both from the user through ``set_parameters``. When the given
        parameter tensor carries an extra leading (group) dimension, ``n``
        FC operations are performed in parallel.

        Args:
            input_size (int): input size
            output_size (int): output size
            activation (Callable): applied to the output
            use_bias (bool): whether use bias
            use_ln (bool): whether use layer normalization
            n_groups (int): number of parallel groups. If ``use_ln`` is False
                it is taken from the first dimension of the parameters passed
                to ``set_parameters``. If ``use_ln`` is True it must be given
                here, stays fixed, and every parameter tensor must agree
                with it.
            kernel_initializer (Callable): initializer for the FC kernel used
                when reinitializing. If None, a variance-scaling initializer
                with gain ``kernel_init_gain`` is used.
            kernel_init_gain (float): scaling factor (gain) applied to the
                std of the kernel init distribution. Ignored when
                ``kernel_initializer`` is not None.
            bias_init_value (float): constant used when reinitializing the bias
        """
        super().__init__()
        self._input_size = input_size
        self._output_size = output_size
        self._activation = activation
        self._use_bias = use_bias
        self._use_ln = use_ln
        self._kernel_initializer = kernel_initializer
        self._kernel_init_gain = kernel_init_gain
        self._bias_init_value = bias_init_value

        self._weight_length = output_size * input_size
        self._bias_length = output_size if use_bias else 0
        if not use_bias:
            self._bias = None

        if use_ln:
            assert n_groups is not None, (
                "n_groups has to be specified if use_ln")
            self._ln = ParamLayerNorm1d(n_groups, output_size)
            self._n_groups = n_groups
            initial_groups = n_groups
        else:
            initial_groups = 1

        self._param_length = None
        # Start from random parameters until the user provides real ones.
        self.set_parameters(torch.randn(initial_groups, self.param_length))

    @property
    def weight(self):
        """The stored weight tensor (one per group)."""
        return self._weight

    @property
    def bias(self):
        """The stored bias tensor (one per group), or None."""
        return self._bias

    @property
    def weight_length(self):
        """Number of elements of one group's weight."""
        return self._weight_length

    @property
    def bias_length(self):
        """Number of elements of one group's bias."""
        return self._bias_length

    @property
    def param_length(self):
        """Total number of parameters per group (weight + bias + LN)."""
        if self._param_length is None:
            total = self.weight_length
            total += self.bias_length if self._use_bias else 0
            total += self._ln.param_length if self._use_ln else 0
            self._param_length = total
        return self._param_length

    def set_parameters(self, theta, reinitialize=False):
        """Split ``theta`` into weight, bias and LN parameters.

        Args:
            theta (torch.Tensor): shape ``[D]`` (one group) or ``[B, D]``
                (``B`` groups), where ``D`` must equal ``self.param_length``.
                A 1-D ``theta`` is treated as ``[1, D]``.
            reinitialize (bool): whether to reinitialize the parameters of
                each part in place.
        """
        if theta.ndim == 1:
            theta = theta.unsqueeze(0)
        assert (theta.ndim == 2 and theta.shape[1] == self.param_length), (
            "Input theta has wrong shape %s. Expecting shape (, %d)" %
            (theta.shape, self.param_length))
        if not self._use_ln:
            self._n_groups = theta.shape[0]
        else:
            assert theta.shape[0] == self._n_groups, (
                "the input has wrong n_groups. Expecting n_groups %d" %
                self._n_groups)

        offset = self.weight_length
        self._set_weight(theta[:, :offset], reinitialize=reinitialize)
        if self._use_bias:
            end = offset + self.bias_length
            self._set_bias(theta[:, offset:end], reinitialize=reinitialize)
            offset = end
        if self._use_ln:
            self._ln.set_parameters(
                theta[:, offset:], reinitialize=reinitialize)

    def _set_weight(self, weight, reinitialize=False):
        """Store the weights, reshaped to ``[n, output_size, input_size]``.

        Args:
            weight (torch.Tensor): shape ``[B, D]`` where ``B`` is the number
                of groups and ``D`` is ``self._weight_length``.
            reinitialize (bool): whether to reinitialize the weights in place.
        """
        weight = weight.view(self._n_groups, self._output_size,
                             self._input_size)
        if reinitialize:
            for group in range(self._n_groups):
                group_weight = weight[group]
                if self._kernel_initializer is not None:
                    self._kernel_initializer(group_weight)
                else:
                    variance_scaling_init(
                        group_weight,
                        gain=self._kernel_init_gain,
                        nonlinearity=self._activation)
        self._weight = weight

    def _set_bias(self, bias, reinitialize=False):
        """Store the biases.

        Args:
            bias (torch.Tensor): shape ``[B, D]`` where ``B`` is the number of
                groups and ``D`` is ``self._bias_length``.
            reinitialize (bool): whether to reset the biases in place to
                ``self._bias_init_value``.
        """
        if reinitialize and self._use_bias:
            nn.init.constant_(bias, self._bias_init_value)
        self._bias = bias  # [n, bias_length]

    def forward(self, inputs):
        """Forward

        Args:
            inputs (torch.Tensor): shape ``[B, D]`` (shared by every group) or
                ``[B, n, D]`` (group ``i`` uses ``inputs[:, i]``), where ``B``
                is the batch size, ``n`` the number of replicas and ``D`` the
                input dimension.

        Returns:
            torch.Tensor: shape ``[B, n, D]``, or ``[B, D]`` when ``n`` is 1
            and layer normalization is not used; ``D`` is the output
            dimension.
        """
        if inputs.ndim == 3:
            # parallel inputs
            assert (
                inputs.shape[1] == self._n_groups
                and inputs.shape[2] == self._input_size), (
                    "Input inputs has wrong shape %s. Expecting (B, %d, %d)" %
                    (inputs.shape, self._n_groups, self._input_size))
            inputs = inputs.transpose(0, 1)  # [n, B, D]
        elif inputs.ndim == 2:
            # shared inputs
            assert inputs.shape[1] == self._input_size, (
                "Input inputs has wrong shape %s. Expecting (B, %d)" %
                (inputs.shape, self._input_size))
            inputs = inputs.unsqueeze(0).expand(self._n_groups, *inputs.shape)
        else:
            raise ValueError("Wrong inputs.ndim=%d" % inputs.ndim)

        weight_t = self._weight.transpose(1, 2)
        if self._bias is None:
            out = torch.bmm(inputs, weight_t)
        else:
            out = torch.baddbmm(self._bias.unsqueeze(1), inputs, weight_t)
        out = out.transpose(0, 1)  # [B, n, D]
        # The LN layer handles the squeeze itself.
        out = self._ln(out) if self._use_ln else out.squeeze(1)
        return self._activation(out)
[docs]@alf.configurable
class ParamConv2D(nn.Module):
def __init__(self,
             in_channels,
             out_channels,
             kernel_size,
             activation=torch.relu_,
             strides=1,
             pooling_kernel=None,
             padding=0,
             use_bias=False,
             use_ln=False,
             n_groups=None,
             kernel_initializer=None,
             kernel_init_gain=1.0,
             bias_init_value=0.0):
    """A 2D conv layer that does not maintain its own weight and bias,
    but accepts both from users. If the given parameter (weight and bias)
    tensor has an extra batch dimension (first dimension), it performs
    parallel conv operations (one per group).

    Args:
        in_channels (int): channels of the input image
        out_channels (int): channels of the output image
        kernel_size (int or tuple): size of the conv kernel
        activation (Callable): applied to the output
        strides (int or tuple): stored as ``self._strides``
        pooling_kernel (int or tuple): stored as ``self._pooling_kernel``
        padding (int or tuple): stored as ``self._padding``
        use_bias (bool): whether use bias.
        use_ln (bool): whether use layer normalization
        n_groups (int): number of parallel groups, it is determined by the first
            dimension of the input parameters when calling ``set_parameters`` if
            ``use_ln`` is False. If ``use_ln`` is True, ``n_groups`` must
            be specified at initialization and will be fixed, all input parameters
            will have to be consistent with it.
        kernel_initializer (Callable): initializer for the conv layer kernel.
            If None is provided a variance_scaling_initializer with gain as
            ``kernel_init_gain`` will be used.
        kernel_init_gain (float): a scaling factor (gain) applied to the
            std of kernel init distribution. It will be ignored if
            ``kernel_initializer`` is not None.
        bias_init_value (float): a constant
    """
    super(ParamConv2D, self).__init__()
    self._in_channels = in_channels
    self._out_channels = out_channels
    self._activation = activation
    self._kernel_size = common.tuplify2d(kernel_size)
    self._kH, self._kW = self._kernel_size
    self._strides = strides
    self._pooling_kernel = pooling_kernel
    self._padding = padding
    self._use_bias = use_bias
    self._use_ln = use_ln
    # Without LN this is overwritten by ``set_parameters`` from theta's
    # leading dimension.
    self._n_groups = n_groups
    self._kernel_initializer = kernel_initializer
    self._kernel_init_gain = kernel_init_gain
    self._bias_init_value = bias_init_value

    self._weight_length = out_channels * in_channels * self._kH * self._kW
    if use_bias:
        self._bias_length = out_channels
    else:
        self._bias_length = 0
        self._bias = None
    if use_ln:
        assert n_groups is not None, (
            "n_groups has to be specified if use_ln")
        self._ln = ParamLayerNorm2d(n_groups, out_channels)
    else:
        n_groups = 1
    self._param_length = None
    # Start from random parameters until the user provides real ones.
    self.set_parameters(torch.randn(n_groups, self.param_length))
@property
def weight(self):
    """The stored weight tensor (groups merged along the first axis when
    there are several)."""
    stored = self._weight
    return stored
@property
def bias(self):
    """The stored (flattened) bias tensor, or None without bias."""
    stored = self._bias
    return stored
@property
def weight_length(self):
    """Number of elements of one group's conv kernel."""
    length = self._weight_length
    return length
@property
def bias_length(self):
    """Number of elements of one group's bias (0 without bias)."""
    length = self._bias_length
    return length
@property
def param_length(self):
    """Total number of parameters per group (kernel + bias + LN).

    Computed on first access and cached in ``self._param_length``.
    """
    if self._param_length is None:
        total = self.weight_length
        total += self.bias_length if self._use_bias else 0
        total += self._ln.param_length if self._use_ln else 0
        self._param_length = total
    return self._param_length
def set_parameters(self, theta, reinitialize=False):
    """Split ``theta`` into kernel, bias and LN parameters.

    Args:
        theta (torch.Tensor): shape ``[D]`` (one group) or ``[B, D]``
            (``B`` groups), where ``D`` must equal ``self.param_length``.
            A 1-D ``theta`` is treated as ``[1, D]``.
        reinitialize (bool): whether to reinitialize the parameters of each
            part in place.
    """
    if theta.ndim == 1:
        theta = theta.unsqueeze(0)
    assert (theta.ndim == 2 and theta.shape[1] == self.param_length), (
        "Input theta has wrong shape %s. Expecting shape (, %d)" %
        (theta.shape, self.param_length))
    if not self._use_ln:
        self._n_groups = theta.shape[0]
    else:
        assert theta.shape[0] == self._n_groups, (
            "the input has wrong n_groups. Expecting n_groups %d" %
            self._n_groups)

    offset = self.weight_length
    self._set_weight(theta[:, :offset], reinitialize=reinitialize)
    if self._use_bias:
        end = offset + self.bias_length
        self._set_bias(theta[:, offset:end], reinitialize=reinitialize)
        offset = end
    if self._use_ln:
        self._ln.set_parameters(theta[:, offset:], reinitialize=reinitialize)
def _set_weight(self, weight, reinitialize=False):
    """Reshape and store the weight of one group or of several groups.

    Args:
        weight (torch.Tensor): shape ``[B, D]``, where the symbols mean:

            - ``B``: batch size (number of groups)
            - ``D``: length of one weight vector, i.e. ``self._weight_length``

        reinitialize (bool): if True, re-initialize each group's kernel in
            place before storing it.
    """
    out_c = self._out_channels
    in_c = self._in_channels
    kernel_hw = (self._kH, self._kW)
    if weight.shape[0] == 1:
        # a single (non-parallel) kernel: [C_out, C_in, kH, kW]
        weight = weight.view(out_c, in_c, *kernel_hw)
    else:
        # parallel kernels, stacked along the output-channel dim so that
        # ``F.conv2d(..., groups=n_groups)`` can consume them directly
        weight = weight.view(self._n_groups, out_c, in_c, *kernel_hw).reshape(
            self._n_groups * out_c, in_c, *kernel_hw)
    if reinitialize:
        for group in range(self._n_groups):
            lo = group * out_c
            group_weight = weight[lo:lo + out_c]
            if self._kernel_initializer is not None:
                self._kernel_initializer(group_weight)
            else:
                variance_scaling_init(
                    group_weight,
                    gain=self._kernel_init_gain,
                    nonlinearity=self._activation)
    self._weight = weight
def _set_bias(self, bias, reinitialize=False):
    """Store the bias of one group or of several groups as a flat tensor.

    Args:
        bias (torch.Tensor): shape ``[B, D]``, where the symbols mean:

            - ``B``: batch size (number of groups)
            - ``D``: length of one bias vector, i.e. ``self._bias_length``

        reinitialize (bool): if True (and bias is enabled), fill ``bias``
            in place with ``self._bias_init_value`` before storing it.
    """
    if reinitialize and self._use_bias:
        nn.init.constant_(bias, self._bias_init_value)
    # flatten so that F.conv2d receives one bias entry per output channel
    self._bias = bias.reshape(-1)
def forward(self, img, keep_group_dim=True):
    """Apply the (possibly grouped) 2D convolution, optional layer norm,
    activation and optional max pooling.

    Args:
        img (torch.Tensor): shape ``[B, C, H, W]`` (groups=1) or
            ``[B, n, C, H, W]`` (groups=n). With more than one group, a
            ``[B, C, H, W]`` input is shared by all ``n`` convolutions, while a
            ``[B, n*C, H, W]`` or ``[B, n, C, H, W]`` input gives each
            convolution its own slice. ``B`` is the batch size, ``n`` the
            number of replicas, ``C``/``H``/``W`` channels/height/width.
        keep_group_dim (bool): only relevant when there is more than one
            group; if True the group dim is split back out of the channels.

    Returns:
        torch.Tensor: ``[B, n, C', H', W']`` if ``keep_group_dim`` (and
        ``n > 1``), otherwise ``[B, n*C', H', W']``. With a single group the
        result is always ``[B, C', H', W']``. ``C'``, ``H'`` and ``W'`` are
        the output channels, height and width.
    """
    n_groups = self._n_groups
    in_c = self._in_channels
    if n_groups == 1:
        assert (img.ndim == 4 and img.shape[1] == in_c), (
            "Input img has wrong shape %s. Expecting (B, %d, H, W)" %
            (img.shape, in_c))
    elif img.ndim == 4:
        if img.shape[1] == in_c:
            # one shared input: replicate it for every group
            img = img.repeat(1, n_groups, 1, 1)
        else:
            # groups already merged into the channel dim
            assert img.shape[1] == n_groups * in_c, (
                "Input img has wrong shape %s. Expecting (B, %d, H, W) or (B, %d, H, W)"
                % (img.shape, in_c, n_groups * in_c))
    elif img.ndim == 5:
        assert (img.shape[1] == n_groups and img.shape[2] == in_c), (
            "Input img has wrong shape %s. Expecting (B, %d, %d, H, W)" %
            (img.shape, n_groups, in_c))
        # fold the group dim into the channel dim for grouped conv2d
        img = img.reshape(img.shape[0], img.shape[1] * img.shape[2],
                          *img.shape[3:])
    else:
        raise ValueError("Wrong img.ndim=%d" % img.ndim)

    out = F.conv2d(
        img,
        self._weight,
        bias=self._bias,
        stride=self._strides,
        padding=self._padding,
        groups=n_groups)
    if self._use_ln:
        out = self._ln(out, keep_group_dim=False)
    out = self._activation(out)
    if self._pooling_kernel is not None:
        out = F.max_pool2d(out, self._pooling_kernel)
    if keep_group_dim and n_groups > 1:
        # [B, n*C', H', W'] -> [B, n, C', H', W']
        batch, _, height, width = out.shape
        out = out.reshape(batch, n_groups, self._out_channels, height, width)
    return out
@alf.configurable
class Reshape(nn.Module):
    """Reshape every sample of a batch.

    The output has shape ``(B, *shape)``, where ``B = x.shape[0]``.
    """

    def __init__(self, *shape):
        """
        Args:
            shape (tuple of ints|int...): target shape without the batch
                dimension, given either as separate ints (``Reshape(3, 4)``)
                or as a single iterable (``Reshape((3, 4))``).
        """
        super().__init__()
        if len(shape) == 1 and isinstance(shape[0], Iterable):
            shape = tuple(shape[0])
        self._shape = shape

    def forward(self, x):
        batch_size = x.shape[0]
        return x.reshape(batch_size, *self._shape)

    def make_parallel(self, n: int):
        """Return a ``Reshape`` that keeps the replica dim ``n`` after batch."""
        return Reshape((n, *self._shape))
def _conv_transpose_2d(in_channels,
                       out_channels,
                       kernel_size,
                       stride=1,
                       padding=0,
                       bias=True):
    """Build an ``nn.ConvTranspose2d`` whose output size is ``stride`` times
    its input size (given the padding used by the callers in this file).

    See https://pytorch.org/docs/stable/nn.html#torch.nn.ConvTranspose2d for
    the output-size formula that ``output_padding`` compensates for.
    """
    extra = stride + 2 * padding - kernel_size
    layer_kwargs = dict(
        stride=stride, padding=padding, output_padding=extra, bias=bias)
    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                              **layer_kwargs)
@alf.configurable(whitelist=[
    'with_batch_normalization', 'bn_ctor', 'weight_opt_args', 'activation'
])
class ResidueBlock(nn.Module):
    """The ResidueBlock for ResNet.

    This is the residual block used in ResNet-18 and ResNet-34 of the original
    ResNet paper `Deep residual learning for image recognition
    <https://arxiv.org/abs/1512.03385>`_.

    Compared to BottleneckBlock, it has one less conv layer.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]],
                 transpose: bool = False,
                 activation: nn.Module = nn.ReLU(inplace=True),
                 with_batch_normalization: bool = True,
                 weight_opt_args: Optional[Dict] = None,
                 bn_ctor: Callable[[int], nn.Module] = nn.BatchNorm2d):
        """
        Args:
            in_channels: the number of channels of input
            channels: the number of output channels of both conv layers at
                main path, and hence of the block.
            kernel_size: the kernel size of both conv layers at main path.
                The padding is computed as ``(kernel_size - 1) // 2``, which
                requires an int.
            stride: stride for this block (applied by the first conv layer
                and by the shortcut projection).
            transpose: whether use ``Conv2D`` or ``Conv2DTranspose``.
                If two ``ResidueBlock`` layers ``L`` and ``LT`` are constructed
                with the same arguments except ``transpose``, it is guaranteed that
                ``LT(L(x)).shape == x.shape`` if ``x.shape[-2:]`` can be divided
                by ``stride``.
            activation: activation function, used after the first conv layer
                and after the residual sum. The default instance is shared by
                all blocks; it is stateless, so sharing is harmless.
            with_batch_normalization: whether to include batch normalization.
                Note that standard ResNet uses batch normalization.
            weight_opt_args: optimizer arguments for weights (not for bias)
            bn_ctor: will be called as ``bn_ctor(num_features)`` to
                create the BN layer.
        """
        super().__init__()
        conv_fn = _conv_transpose_2d if transpose else nn.Conv2d
        # BN already provides a learnable shift, so conv biases are redundant
        bias = not with_batch_normalization
        self._activation = activation
        padding = (kernel_size - 1) // 2

        conv1 = conv_fn(
            in_channels,
            channels,
            kernel_size,
            stride,
            padding=padding,
            bias=bias)
        conv2 = conv_fn(
            channels, channels, kernel_size, padding=padding, bias=bias)
        nn.init.kaiming_normal_(conv1.weight.data)
        nn.init.kaiming_normal_(conv2.weight.data)
        if weight_opt_args is not None:
            conv1.weight.opt_args = weight_opt_args
            conv2.weight.opt_args = weight_opt_args

        if stride != 1 or in_channels != channels:
            # projection shortcut so that the residual sum has matching shape
            s = conv_fn(in_channels, channels, 1, stride, bias=bias)
            nn.init.kaiming_normal_(s.weight.data)
            if bias:
                nn.init.zeros_(s.bias.data)
            if with_batch_normalization:
                shortcut_layers = nn.Sequential(s, bn_ctor(channels))
            else:
                shortcut_layers = s
            if weight_opt_args is not None:
                s.weight.opt_args = weight_opt_args
        else:
            shortcut_layers = None

        if with_batch_normalization:
            bn1 = bn_ctor(channels)
            if isinstance(bn1, BatchNorm2d):
                # When alf.layers.BatchNorm2d is used, it may be configured
                # as fixed_weight_norm=True. That is reasonable if it is followed
                # by conv+bn, since the result is invariant to its overall scale
                # However, bn2 is followed by a sum with shortcut. The result
                # is not invariant to its scale. So we explicitly set
                # fixed_weight_norm=False
                bn2 = bn_ctor(channels, fixed_weight_norm=False)
            else:
                bn2 = bn_ctor(channels)
            core_layers = nn.Sequential(conv1, bn1, activation, conv2, bn2)
        else:
            core_layers = nn.Sequential(conv1, activation, conv2)

        self._core_layers = core_layers
        self._shortcut_layers = shortcut_layers

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``activation(core(inputs) + shortcut(inputs))``.

        The shortcut is the identity when no projection layer was created.
        """
        core = self._core_layers(inputs)
        # explicit None check: the truthiness of an nn.Module depends on
        # whether it defines ``__len__``
        if self._shortcut_layers is not None:
            shortcut = self._shortcut_layers(inputs)
        else:
            shortcut = inputs
        return self._activation(core + shortcut)
@alf.configurable(whitelist=['v1_5', 'with_batch_normalization', 'bn_ctor'])
class BottleneckBlock(nn.Module):
    """Bottleneck block for ResNet.

    We allow two slightly different architectures:

    * v1: Placing the stride at the first 1x1 convolution as described in the
      original ResNet paper `Deep residual learning for image recognition
      <https://arxiv.org/abs/1512.03385>`_.
    * v1.5: Placing the stride for downsampling at 3x3 convolution. This variant
      is also known as ResNet V1.5 and improves accuracy according to
      `<https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.

    TODO:

    1. ResNet-D in `Bag of Tricks for Image Classification with Convolutional Neural Networks
       <https://openaccess.thecvf.com/content_CVPR_2019/papers/He_Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks_CVPR_2019_paper.pdf>`_
       Note: v1_5 is the ResNet-B in the above paper.
    2. Squeeze-and-Excitation (SE) in `Squeeze-and-Excitation Networks
       <https://openaccess.thecvf.com/content_cvpr_2018/papers/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.pdf>`_
       SE is also shown to be useful in
       `Revisiting ResNets: Improved Training and Scaling Strategies <https://arxiv.org/abs/2103.07579>`_
    """

    def __init__(self,
                 in_channels,
                 kernel_size,
                 filters,
                 stride,
                 transpose=False,
                 v1_5=True,
                 with_batch_normalization=True,
                 bn_ctor=nn.BatchNorm2d):
        """
        Args:
            in_channels (int): the number of channels of input
            kernel_size (int): the kernel size of middle layer at main path
            filters (tuple[int, int, int]): the number of filters of the 3
                conv layers at main path; ``filters[2]`` is the number of
                output channels of the block.
            stride (int): stride for this block.
            transpose (bool): a bool indicate using ``Conv2D`` or ``Conv2DTranspose``.
                If two BottleneckBlock layers ``L`` and ``LT`` are constructed
                with the same arguments except ``transpose``, it is guaranteed that
                ``LT(L(x)).shape == x.shape`` if ``x.shape[-2:]`` can be divided
                by ``stride``.
            v1_5 (bool): whether to use the ResNet V1.5 structure
            with_batch_normalization (bool): whether to include batch normalization.
                Note that standard ResNet uses batch normalization.
            bn_ctor (Callable): will be called as ``bn_ctor(num_features)`` to
                create the BN layer.
        """
        super().__init__()
        filters1, filters2, filters3 = filters
        conv_fn = _conv_transpose_2d if transpose else nn.Conv2d
        # BN already provides a learnable shift, so conv biases are redundant
        bias = not with_batch_normalization
        padding = (kernel_size - 1) // 2

        if v1_5:
            # stride on the kxk conv (ResNet-B)
            a = conv_fn(in_channels, filters1, 1, bias=bias)
            b = conv_fn(
                filters1, filters2, kernel_size, stride, padding, bias=bias)
        else:
            # stride on the first 1x1 conv (original ResNet)
            a = conv_fn(in_channels, filters1, 1, stride, bias=bias)
            b = conv_fn(filters1, filters2, kernel_size, 1, padding, bias=bias)
        c = conv_fn(filters2, filters3, 1, bias=bias)

        nn.init.kaiming_normal_(a.weight.data)
        nn.init.kaiming_normal_(b.weight.data)
        nn.init.kaiming_normal_(c.weight.data)
        if bias:
            nn.init.zeros_(a.bias.data)
            nn.init.zeros_(b.bias.data)
            nn.init.zeros_(c.bias.data)

        if stride != 1 or in_channels != filters3:
            # projection shortcut so that the residual sum has matching shape
            s = conv_fn(in_channels, filters3, 1, stride, bias=bias)
            nn.init.kaiming_normal_(s.weight.data)
            if bias:
                nn.init.zeros_(s.bias.data)
            if with_batch_normalization:
                shortcut_layers = nn.Sequential(s, bn_ctor(filters3))
            else:
                shortcut_layers = s
        else:
            shortcut_layers = None

        relu = nn.ReLU(inplace=True)
        if with_batch_normalization:
            core_layers = nn.Sequential(a, bn_ctor(filters1), relu, b,
                                        bn_ctor(filters2), relu, c,
                                        bn_ctor(filters3))
        else:
            core_layers = nn.Sequential(a, relu, b, relu, c)

        self._core_layers = core_layers
        self._shortcut_layers = shortcut_layers

    def forward(self, inputs):
        """Return ``relu(core(inputs) + shortcut(inputs))``.

        The shortcut is the identity when no projection layer was created.
        """
        core = self._core_layers(inputs)
        # explicit None check: the truthiness of an nn.Module depends on
        # whether it defines ``__len__``
        if self._shortcut_layers is not None:
            shortcut = self._shortcut_layers(inputs)
        else:
            shortcut = inputs
        return torch.relu_(core + shortcut)

    def calc_output_shape(self, input_shape):
        """Compute the output shape (without batch dim) for ``input_shape``.

        A dummy batch of zeros is pushed through the block. The block is put
        in eval mode under ``torch.no_grad()`` while doing so, so that BN
        layers do not fold the dummy batch into their running statistics;
        every submodule's original mode is restored afterwards.

        Args:
            input_shape (tuple[int]): shape of one sample, e.g. ``(C, H, W)``.
        Returns:
            torch.Size: shape of one output sample.
        """
        modes = [(m, m.training) for m in self.modules()]
        param = next(self.parameters(), None)
        device = None if param is None else param.device
        self.eval()
        try:
            with torch.no_grad():
                y = self.forward(torch.zeros(1, *input_shape, device=device))
        finally:
            for m, mode in modes:
                m.training = mode
        return y.shape[1:]
def _masked_softmax(logits, mask, dim=-1):
    """Softmax over ``dim`` with masked positions given zero probability.

    Args:
        logits (Tensor): unnormalized scores.
        mask (Tensor|None): boolean tensor broadcastable to ``logits``; True
            marks positions to exclude. ``None`` means no masking.
        dim (int): dimension along which softmax is computed.
    Returns:
        Tensor: probabilities with the same shape as ``logits``.
    """
    if mask is not None:
        # Out-of-place: the caller's ``logits`` must not be modified, and an
        # in-place write would also break autograd when ``logits`` is needed
        # for the backward pass (or is a leaf requiring grad).
        logits = logits.masked_fill(mask, -float('inf'))
    return nn.functional.softmax(logits, dim=dim)
class Lambda(nn.Module):
    """Expose a plain callable as an ``nn.Module``."""

    def __init__(self, func):
        """
        Args:
            func (Callable): computes the output from the input. It must not
                own any parameters (they would not be tracked by the module).
        """
        super().__init__()
        self._func = func

    def forward(self, x):
        fn = self._func
        return fn(x)
class GFT(nn.Module):
    """Guided Feature Transformation.

    This class implements the GFT model proposed in the following paper:
    `Yu et al. Guided Feature Transformation (GFT): A Neural Language Grounding
    Module for Embodied Agents, CoRL 2018 <https://arxiv.org/pdf/1805.08329.pdf>`_
    """

    def __init__(self, num_transformations, image_channels, language_dim):
        """
        Args:
            num_transformations (int): number of successive transformations.
            image_channels (int): channels ``C`` of the image features.
            language_dim (int): size of the sentence embedding.
        """
        super().__init__()
        # each FC produces one [C, C+1] affine transformation matrix
        self._t_layers = nn.ModuleList([
            FC(language_dim, (1 + image_channels) * image_channels)
            for _ in range(num_transformations)
        ])

    def forward(self, input):
        """
        Args:
            input (tuple): the tuple of image features and sentence embedding.
        Returns:
            Tensor: same shape as input[0]
        """
        image, sentence = input
        batch_size, channels = image.shape[:2]
        # [B, C, W*H]; reshape (not view) also accepts non-contiguous input
        cnn_out = image.reshape(batch_size, channels, -1)

        ## compute K transformation matrices
        ts = [
            l(sentence).reshape(batch_size, channels, channels + 1)
            for l in self._t_layers
        ]

        # Created per call so it always matches the device and dtype of the
        # features; a plain tensor attribute would stay on the CPU when the
        # module is moved with ``.to()`` / ``.cuda()``.
        ones = cnn_out.new_ones(batch_size, 1, cnn_out.shape[-1])
        for t in ts:
            # [B, C+1, W*H]: append a constant row for the affine bias
            cnn_out = torch.cat((cnn_out, ones), dim=1)
            # [B, C, W*H] <= [B, C, C+1] * [B, C+1, W*H]
            cnn_out = torch.relu_(torch.matmul(t, cnn_out))

        return cnn_out.reshape(*image.shape)

    def reset_parameters(self):
        for l in self._t_layers:
            l.reset_parameters()
class GetFields(ElementwiseLayerBase):
    """Get the fields from a nested input."""

    def __init__(self, field_nest=None, **fields):
        """
        Args:
            field_nest (nested str): the paths of the fields to be retrieved.
                Each str in ``field_nest`` is a path to a field, with '.'
                separating the field names at different levels.
            fields (str): a simpler way of specifying ``field_nest`` when it is
                a dict. ``GetFields(a="field_a", b="field_b")`` is equivalent to
                ``GetFields(dict(a="field_a", b="field_b"))``.
        """
        super().__init__()
        if field_nest is None:
            self._fields = fields
        else:
            assert not fields
            self._fields = field_nest

    def forward(self, input):
        def _lookup(path):
            return alf.nest.get_field(input, path)

        return alf.nest.map_structure(_lookup, self._fields)
class ReplicationPad2d(nn.Module):
    r"""Pad the input tensor using replication of the input boundary.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    This is same as torch.nn.ReplicationPad2d except that this implementation
    can handle input of any dtype, while torch.nn one can only handle float dtype.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`).
            Any of the four sizes may be 0.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
    """

    def __init__(self, padding):
        super().__init__()
        if isinstance(padding, int):
            padding = (padding, padding, padding, padding)
        self._padding = padding
        # Gather indices are cached per input height / width.
        self._h = -1
        self._w = -1

    def forward(self, input):
        h, w = input.shape[-2:]
        left, right, top, bottom = self._padding
        if h != self._h:
            # Rows -top .. h+bottom-1, clamped onto the valid range, so padded
            # rows copy the nearest boundary row. Clamping (instead of
            # assigning to ``index[-bottom:]``) stays correct for a zero-sized
            # pad, where ``index[-0:]`` would select the whole index.
            self._yindex = torch.arange(-top, h + bottom).clamp_(
                min=0, max=h - 1).unsqueeze(-1)
            self._h = h
        if w != self._w:
            self._xindex = torch.arange(-left, w + right).clamp_(
                min=0, max=w - 1)
            self._w = w
        return input[..., self._yindex, self._xindex]
class RandomCrop(nn.Module):
    r"""Perform random crop independently for each image in the batch.

    Note that ``torchvision.transforms.RandomCrop`` is different in that it
    applies the same random crop for all the images in the batch.

    Each result image is a random crop of the padded input image. The padded
    pixels copy the nearest pixel on the boundary.

    Args:
        size: a tuple of desired height and width. If is `int`, uses the same
            height and width.
        padding: the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`).
    """

    def __init__(self,
                 size: Union[int, Tuple[int]],
                 padding: Union[int, Tuple[int]] = 0):
        super().__init__()
        self._size = (size, size) if type(size) == int else size
        self._padding = (padding, ) * 4 if type(padding) == int else padding

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input: shape is [B, C, H, W]
        Returns:
            a tensor of shape [B, C, h, w], where ``h, w=size``
        """
        assert input.ndim == 4, "input.ndim should be 4"
        crop_h, crop_w = self._size
        pad_l, pad_r, pad_t, pad_b = self._padding
        B, C, H, W = input.shape
        assert crop_h <= H + pad_t + pad_b and crop_w <= W + pad_l + pad_r, (
            "input size is too small: %s vs %s" % ((H, W), (crop_h, crop_w)))

        # top-left corner of every crop, in padded coordinates
        top0 = torch.randint(-pad_t, H + pad_b - crop_h + 1,
                             (B, )).reshape(B, 1)
        left0 = torch.randint(-pad_l, W + pad_r - crop_w + 1,
                              (B, )).reshape(B, 1, 1)
        # clamping maps padded positions to the nearest boundary pixel
        rows = (top0 + torch.arange(crop_h)).clamp_(
            min=0, max=H - 1).unsqueeze(-1)  # [B, h, 1]
        cols = (left0 + torch.arange(crop_w)).clamp_(min=0,
                                                     max=W - 1)  # [B, 1, w]
        batch_idx = torch.arange(B).reshape(B, 1, 1)  # [B, 1, 1]
        # Indexing with the channel dim moved first uses far less memory than
        # the equivalent input[b, c, y, x] with a 4-way broadcast index.
        channels_first = input.transpose(0, 1)
        return channels_first[:, batch_idx, rows, cols].transpose(0, 1)
class Sum(nn.Module):
    """Sum over given dimension(s).

    The batch dimension is not counted for ``dim``: ``dim=0`` refers to the
    first dimension after the batch dimension.
    """

    def __init__(self, dim):
        """
        Args:
            dim (int|tuple[int]): the dimension(s) to be summed.
        """
        super().__init__()

        def _skip_batch_dim(d):
            # negative dims count from the end and need no shift
            return d if d < 0 else d + 1

        self._dim = alf.nest.map_structure(_skip_batch_dim, dim)

    def forward(self, input):
        return input.sum(dim=self._dim)

    def make_parallel(self, n: int):
        """Create a Sum layer to handle parallel batch.

        A parallel batch has shape [B, n, ...]; neither the batch dimension
        nor the replica dimension is counted for ``dim``. Passing the already
        shifted ``self._dim`` into a new ``Sum`` shifts non-negative dims once
        more, which skips the replica dimension.

        Args:
            n (int): the number of replicas.
        Returns:
            a ``Sum`` layer to handle parallel batch.
        """
        return Sum(self._dim)
class AddN(ElementwiseLayerBase):
    """Add several tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, input: Iterable[torch.Tensor]):
        """
        Args:
            input (Iterable[Tensor]): a sequence of tensors to be summed
        Returns:
            Tensor: the sum of all the tensors (``0`` for an empty sequence)
        """
        total = 0
        for tensor in input:
            total = total + tensor
        return total
def reset_parameters(module):
    """Reset the parameters for ``module``.

    ``module.reset_parameters()`` is used when available; an
    ``nn.Sequential`` is handled by recursing into its children. Other
    modules are accepted only if they have no parameters, and plain callables
    are ignored.

    Args:
        module (nn.Module):
    Returns:
        None
    Raises:
        ValueError: fail to reset the parameters for ``module``
    """
    if hasattr(module, 'reset_parameters'):
        module.reset_parameters()
        return
    if isinstance(module, nn.Sequential):
        for child in module:
            reset_parameters(child)
        return
    if isinstance(module, nn.Module):
        has_params = next(iter(module.parameters()), None) is not None
        if has_params:
            raise ValueError(
                "Cannot reset_parameter for layer type %s." % type(module))
class Detach(ElementwiseLayerBase):
    """Detach nested Tensors by delegating to ``common.detach``."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        detached = common.detach(input)
        return detached
class Scale(ElementwiseLayerBase):
    """Multiply the input by a constant factor."""

    def __init__(self, scale):
        """
        Args:
            scale: the factor computed as ``scale * input`` in ``forward``.
        """
        super().__init__()
        self._scale = scale

    def forward(self, input):
        factor = self._scale
        return factor * input
class ScaleGradient(ElementwiseLayerBase):
    """Scales the gradient of input for the backward pass.

    The forward value is unchanged; only the gradient flowing back to the
    input is multiplied by ``scale``.

    Args:
        scale (float): a scalar factor to be multiplied to the gradient
            of `tensor`.
    """

    def __init__(self, scale: float):
        super().__init__()
        self._scale = scale

    def forward(self, input):
        frozen = input.detach()
        # lerp(a, b, w) = a + w * (b - a) = (1 - w) * a + w * b, and only the
        # ``b = input`` term carries gradient, scaled by ``w``.
        return torch.lerp(frozen, input, self._scale)
@alf.configurable
class SummarizeGradient(ElementwiseLayerBase):
    def __init__(self, name):
        """A layer for summarizing the gradient of the input tensor.

        Summarize the gradient of the input tensor. The input is always cloned
        first, and gradient tracking is then enabled on the clone so that its
        gradient can be summarized.

        Args:
            name (str): used to describe the name of the summary, after the
                tag 'tensor_gradient'.
        Returns:
            cloned ``tensor``: with ``requires_grad`` set to True and gradient
            summarization hook registered.
        """
        super().__init__()
        self._name = name

    def forward(self, x):
        # clone the input tensor ``x`` to avoid impacts on training in the case
        # where ``x`` does not require gradient
        y = x.clone()
        # ``requires_grad_(True)`` is used instead of assigning
        # ``y.requires_grad = True``: when ``x`` already requires grad, ``y``
        # is a non-leaf and the assignment raises "you can only change
        # requires_grad flags of leaf variables", while this call is a no-op.
        y.requires_grad_(True)
        y = summarize_tensor_gradients(
            "tensor_gradient/{}".format(self._name), y, clone=False)
        return y
class Branch(nn.Module):
    """Apply multiple modules on the same input.

    Example:

    .. code-block:: python

        net = Branch((module1, module2))
        y = net(x)

    is equivalent to the following:

    .. code-block:: python

        y = module1(x), module2(x)

    """

    def __init__(self, *modules, **named_modules):
        """
        Args:
            modules (nested nn.Module): a nest of ``torch.nn.Module``. Note that
                ``Branch(module_a, module_b)`` is equivalent to
                ``Branch((module_a, module_b))``
            named_modules (nn.Module | Callable): a simpler way of specifying
                a dict of modules. ``Branch(a=module_a, b=module_b)``
                is equivalent to ``Branch(dict(a=module_a, b=module_b))``
        """
        super().__init__()
        if not modules:
            nets = named_modules
        else:
            assert not named_modules
            nets = modules[0] if len(modules) == 1 else modules

        is_network = alf.nest.map_structure(
            lambda m: isinstance(m, alf.networks.Network), nets)
        assert not any(alf.nest.flatten(is_network)), (
            "modules should not contain alf.networks.Network. "
            "Try alf.networks.Branch instead.")

        self._networks = nets
        if alf.nest.is_nested(nets):
            # register the nest as an nn.Module so that its parameters are
            # visible to the framework
            self._nets = alf.nest.utils.make_nested_module(nets)

    def forward(self, inputs):
        def _apply(net):
            return net(inputs)

        return alf.nest.map_structure(_apply, self._networks)

    def reset_parameters(self):
        alf.nest.map_structure(reset_parameters, self._networks)

    def make_parallel(self, n: int):
        """Create a parallelized version of this network.

        Args:
            n (int): the number of copies
        Returns:
            the parallelized version of this network
        """
        parallel_nets = alf.nest.map_structure(
            lambda net: make_parallel_net(net, n), self._networks)
        return Branch(parallel_nets)
class Sequential(nn.Module):
    """A more flexible Sequential than torch.nn.Sequential.

    ``alf.layers.Sequential`` is similar to ``alf.nn.Sequential``, but does not
    accept stateful ``alf.nn.Network`` as its elements.

    All the modules provided through ``modules`` and ``named_modules`` are
    evaluated sequentially, in the order in which they appear in the call to
    ``Sequential``. Typically, each module takes the result of the previous
    module as its input (or the input to the Sequential if it is the first
    module), and the result of the last module is the output of the
    Sequential. More flexible wiring is also possible, as shown in example 2.

    Example 1:

    .. code-block:: python

        net = Sequential(module1, module2)
        y = net(x)

    is equivalent to the following:

    .. code-block:: python

        z = module1(x)
        y = module2(z)

    Example 2:

    .. code-block:: python

        net = Sequential(
            module1, a=module2, b=(('input', 'a'), module3), output=('a', 'b'))
        output = net(input, state)

    is equivalent to the following:

    .. code-block:: python

        _ = module1(input)
        a = module2(_)
        b = module3((input, a))
        output = (a, b)

    """

    def __init__(self, *modules, output='', **named_modules):
        """
        Args:
            modules (Callable | (nested str, Callable)):
                The ``Callable`` can be a ``torch.nn.Module``, stateless
                ``alf.nn.Network`` or plain ``Callable``. Optionally, its input
                can be given as the first element of a tuple. Without it, the
                input is the result of the previous module (or the input to
                this ``Sequential`` for the first module). A given input must
                be a nested str used to look up results in the current
                ``named_results``. For modules passed via ``modules`` no named
                module has run yet, so ``named_results`` is
                ``{'input': input}``.
            named_modules (Callable | (nested str, Callable)):
                Same form as ``modules``. In addition, the result of each named
                module is stored in ``named_results`` under its name once it
                has been computed.
            output (nested str): if not provided, the result from the last
                module is the output. Otherwise it is used to look up results
                in ``named_results`` after all modules have been evaluated.
        """
        super().__init__()
        elements = [('', m) for m in modules]
        elements.extend(named_modules.items())

        def _is_nested_str(s):
            return all(type(x) == str for x in alf.nest.flatten(s))

        input_paths = []
        output_names = []
        simple = True
        self._networks = []
        # nn.ModuleList registers the modules so pytorch tracks their parameters
        self._nets = nn.ModuleList()
        for idx, (name, element) in enumerate(elements):
            if isinstance(element, tuple) and len(element) == 2:
                path, module = element
            else:
                path, module = '', element
            if not (isinstance(module, Callable) and _is_nested_str(path)):
                raise ValueError(
                    "Argument %s is not in the form of Callable "
                    "or (nested str, Callable): %s" % (name or str(idx),
                                                         element))
            if isinstance(module, alf.networks.Network):
                assert not alf.nest.flatten(module.state_spec), (
                    "Network element of layers.Sequential should be stateless. "
                    "Use networks.Sequential instead")
            input_paths.append(path)
            output_names.append(name)
            self._networks.append(module)
            if isinstance(module, nn.Module):
                self._nets.append(module)
            if name or path:
                # any explicit wiring requires the dictionary-based forward
                simple = False

        self.forward = self._forward_simple if simple else self._forward_complex
        self._output = output
        self._inputs = input_paths
        self._outputs = output_names

    def _forward_simple(self, input):
        # plain chain: each module consumes the previous result
        for net in self._networks:
            result = net(input)
            # alf Networks return (output, state); keep only the output
            input = result[0] if isinstance(net,
                                            alf.networks.Network) else result
        return input

    def _forward_complex(self, input):
        results = {'input': input}
        for net, path, name in zip(self._networks, self._inputs,
                                   self._outputs):
            if path:
                input = get_nested_field(results, path)
            result = net(input)
            input = result[0] if isinstance(net,
                                            alf.networks.Network) else result
            if name:
                results[name] = input
        if self._output:
            return get_nested_field(results, self._output)
        return input

    def reset_parameters(self):
        alf.nest.map_structure(reset_parameters, self._networks)

    def __getitem__(self, i):
        return self._networks[i]

    def make_parallel(self, n: int):
        """Create a parallelized version of this network.

        Args:
            n (int): the number of copies
        Returns:
            the parallelized version of this network
        """
        positional = []
        named = {}
        for net, path, name in zip(self._networks, self._inputs,
                                   self._outputs):
            parallel_net = alf.layers.make_parallel_net(net, n)
            if name:
                named[name] = (path, parallel_net)
            else:
                positional.append((path, parallel_net))
        return Sequential(*positional, output=self._output, **named)
def make_parallel_net(module, n: int):
    """Make a parallelized version of ``module``.

    A parallel network has ``n`` copies of network with the same structure but
    different independently initialized parameters. The parallel network can
    process a batch of the data with shape [batch_size, n, ...] using ``n``
    networks with same structure.

    If ``module`` has a member function ``make_parallel``, it is called to make
    the parallel network. Otherwise a ``NaiveParallelLayer`` is created, which
    simply makes ``n`` copies of ``module`` and calls them in a loop in
    ``forward()``.

    Examples:

        Applying parallel net on same input:

        .. code-block:: python

            pnet = make_parallel_net(net, n)
            # replicate input.
            # pinput will have shape [batch_size, n, ...], if input has shape [batch_size, ...]
            pinput = make_parallel_input(input, n)
            poutput = pnet(pinput)

        If you already have parallel input with shape [batch_size, n, ...], you can
        omit the call to ``make_parallel_input`` in the above code.

    Args:
        module (Network | nn.Module | Callable): the network to be parallelized.
        n (int): the number of copies
    Returns:
        the parallelized network.
    """
    if hasattr(module, 'make_parallel'):
        return module.make_parallel(n)
    # Pass ``module`` as a logging argument: formatting (and the potentially
    # long ``str(module)``) is left to the logger instead of done eagerly.
    logging.warning(
        "%s does not have make_parallel. A naive parallel layer "
        "will be created.", module)
    return NaiveParallelLayer(module, n)
class NaiveParallelLayer(nn.Module):
    def __init__(self, module: Union[nn.Module, Callable], n: int):
        """
        A parallel network has ``n`` copies of network with the same structure
        but different independently initialized parameters.

        ``NaiveParallelLayer`` creates ``n`` independent networks with the same
        structure as ``module`` and evaluates them one by one in a loop during
        ``forward()``.

        Args:
            module (nn.Module | Callable): the parallel network will have ``n``
                copies of ``module``. An ``nn.Module`` is deep-copied and each
                copy is re-initialized; any other callable is shared.
            n (int): ``n`` copies of ``module``
        """
        super().__init__()
        if isinstance(module, nn.Module):
            replicas = [copy.deepcopy(module) for _ in range(n)]
            self._networks = nn.ModuleList(replicas)
            for replica in self._networks:
                reset_parameters(replica)
        else:
            self._networks = [module] * n
        self._n = n

    def forward(self, inputs):
        """Compute the output.

        Args:
            inputs (nested torch.Tensor): its shape is ``[B, n, ...]``
        Returns:
            output (nested torch.Tensor): its shape is ``[B, n, ...]``
        """
        outputs = []
        for i in range(self._n):
            # replica i only sees slice i of the replica dimension
            sliced = alf.nest.map_structure(lambda x: x[:, i, ...], inputs)
            outputs.append(self._networks[i](sliced))
        if self._n > 1:
            return alf.nest.map_structure(
                lambda *tensors: torch.stack(tensors, dim=1), *outputs)
        return alf.nest.map_structure(lambda tensor: tensor.unsqueeze(1),
                                      outputs[0])

    def reset_parameters(self):
        for i in range(self._n):
            reset_parameters(self._networks[i])
def make_parallel_spec(specs, n: int):
    """Make the spec for parallel network.

    Every spec gets a leading dimension of size ``n``.

    Args:
        specs (nested TensorSpec): the input spec for the non-parallelized network
        n (int): the number of copies of the parallelized network
    Returns:
        input tensor spec for the parallelized network
    """

    def _expand(spec):
        new_shape = (n, ) + spec.shape
        if type(spec) == alf.TensorSpec:
            return alf.TensorSpec(new_shape, spec.dtype)
        # anything else is treated as a BoundedTensorSpec
        return alf.BoundedTensorSpec(new_shape, spec.dtype, spec.minimum,
                                     spec.maximum)

    return map_structure(_expand, specs)
class AMPWrapper(nn.Module):
    """Wrap a layer to run in a given AMP context.

    Args:
        enabled: whether to enable AMP autocast
        net: the wrapped network
    """

    def __init__(self, enabled: bool, net: nn.Module):
        super().__init__()
        self._net = net
        self._enabled = enabled

    def forward(self, input):
        def _to_float32(x):
            return x.float() if x.dtype.is_floating_point else x

        if torch.is_autocast_enabled() and not self._enabled:
            # leaving an autocast region: undo reduced-precision inputs
            input = alf.nest.map_structure(_to_float32, input)
        with torch.cuda.amp.autocast(self._enabled):
            return self._net(input)
class SimpleAttention(nn.Module):
    """Simple Attention Module."""

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        """Simple attention computation based on the inputs.

        Args:
            query (Q): shape [B, head, M, d]
            key (K): shape [B, head, N, d]
            value (V): shape [B, head, N, d]

            where B denotes the batch size, head denotes the number of heads,
            M the number of queries, N the number of entities, and d the
            feature dimension.

        Return:
            - the attended results computed as: softmax(QK^T/sqrt(d))V,
              with the shape [B, head, M, d]
            - the attention weight, with the shape [B, head, M, N]
        """
        import math

        d_k = query.size(-1)
        # A Python scalar avoids allocating an int64 CPU tensor on every call
        # and relying on integer->float promotion inside ``torch.sqrt``.
        # [B, head, M, N]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        attention_weight = F.softmax(scores, dim=-1)
        # [B, head, M, d]
        output = torch.matmul(attention_weight, value)
        return output, attention_weight