Source code for alf.norm_layers

# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Union
import types

import alf
from alf.utils.common import warning_once


class _NormBase(nn.Module):
    """Base of BatchNorm supporting RNN."""

    def __init__(self,
                 num_features: int,
                 eps: float = 1e-5,
                 momentum: float = 0.1,
                 affine: bool = True,
                 fixed_weight_norm=False,
                 use_bias: bool = True,
                 track_running_stats: bool = True):
        super().__init__()
        self._num_features = num_features
        self._eps = eps
        self._momentum = momentum
        self._affine = affine
        self._track_running_stats = track_running_stats
        if affine:
            self._weight = nn.Parameter(torch.Tensor(num_features))
            use_bias = True
            if fixed_weight_norm:
                self._weight.opt_args = dict(
                    max_norm=math.sqrt(num_features),
                    fixed_norm=fixed_weight_norm)
        else:
            self._weight = None
        if use_bias:
            if self._weight is None:
                # pytorch has a bug which cannot handle the case that weight is
                # None but bias is not. So we have to provide a fixed weight.
                self._weight = nn.Parameter(
                    torch.Tensor(num_features), requires_grad=False)
            self._bias = nn.Parameter(torch.Tensor(num_features))
        else:
            self._bias = None
        self._use_bias = use_bias
        self._running_means = []
        self._running_vars = []
        self._num_batches_tracked = []
        self.set_max_steps(1)
        self._current_step = 0
        self.reset_parameters()
        self._clamped = False

    def set_max_steps(self, max_steps: int):
        """Set max steps to keeping running statistics.

        Args:
            max_steps: the maximum steps for which the batch norm running statistics
                are maintained.
        """
        self._max_steps = max_steps
        if not self._track_running_stats:
            return
        for i in range(len(self._running_means), max_steps):
            self._running_means.append(torch.zeros(self._num_features))
            self.register_buffer('_running_means%s' % i,
                                 self._running_means[i])
            self._running_vars.append(torch.ones(self._num_features))
            self.register_buffer('_running_vars%s' % i, self._running_vars[i])
            self._num_batches_tracked.append(torch.zeros((), dtype=torch.long))
            self.register_buffer('_num_batches_tracked%s' % i,
                                 self._num_batches_tracked[i])

    def set_current_step(self, current_step: Union[torch.Tensor, int]):
        """Use and/or update the running statistics at current_step for normalization.

        Args:
            current_step: the current step. If it is a Tensor, it should be a 1D
                int64 Tensor of shape [batch_size,]. And each of its element
                means the current step for the corresponding sample in a batch.
        """
        if not self._track_running_stats:
            return
        self._clamped = False
        if type(current_step) == int:
            if current_step >= self._max_steps:
                warning_once("current_step should be smaller than "
                             "max_steps. Got %s. Will be clamped to %s" %
                             (current_step, self._max_steps - 1))
                current_step = min(current_step, self._max_steps - 1)
                self._clamped = True
        elif isinstance(current_step, torch.Tensor):
            assert 0 <= current_step.ndim <= 1
            if torch.any(current_step >= self._max_steps):
                warning_once("current_step should be smaller than "
                             "max_steps. Got %s. Will be clamped to %s" %
                             (current_step.max(), self._max_steps - 1))
                current_step = current_step.clamp(max=self._max_steps - 1)
                self._clamped = True
        self._current_step = current_step

    def reset_parameters(self):
        """Reset the parameters."""
        if self._track_running_stats:
            for i in range(self._max_steps):
                self._running_means[i].zero_()
                self._running_vars[i].fill_(1)
                self._num_batches_tracked[i].zero_()
        if self._weight is not None:
            nn.init.ones_(self._weight)
        if self._bias is not None:
            nn.init.zeros_(self._bias)

    def forward(self, input: torch.Tensor):
        self._check_input_dim(input)

        if self.training or not self._track_running_stats:
            if self._track_running_stats:
                current_step = self._current_step
                if isinstance(current_step,
                              torch.Tensor) and current_step.ndim != 0:
                    assert torch.all(current_step == current_step[0]), (
                        "all current_steps must be same for training.")
                    current_step = current_step[0]
                current_step = int(current_step)
                running_mean = self._running_means[current_step]
                running_var = self._running_vars[current_step]
                if not self._clamped:
                    num_batches_tracked = self._num_batches_tracked[
                        current_step]
                    num_batches_tracked.add_(1)
                    if self._momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(
                            num_batches_tracked)
                    else:  # use exponential moving average
                        exponential_average_factor = self._momentum
                else:
                    exponential_average_factor = 0.0
            else:
                running_mean = None
                running_var = None
                exponential_average_factor = 0.0
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self._weight,
                self._bias,
                # whether the mini-batch stats should be used for normalization
                # rather than the running stats.
                # If current_step is out of limit, we will use the running stats
                # for max_steps - 1 to normalize the batch so that we can keep
                # training and eval consistent.
                not self._clamped,
                exponential_average_factor,
                self._eps)
        else:  # not training and tracking running stats
            running_means = torch.stack(
                self._running_means, dim=0)[self._current_step]
            running_vars = torch.stack(
                self._running_vars, dim=0)[self._current_step]
            if running_means.ndim == 1:
                running_means = running_means[None, :].expand(input.shape[:2])
                running_vars = running_vars[None, :].expand(input.shape[:2])
            running_means = running_means.reshape(-1)
            running_vars = running_vars.reshape(-1)
            weight = self._weight
            bias = self._bias
            if weight is not None:
                weight = weight[None, :].expand(input.shape[:2]).reshape(-1)
            if bias is not None:
                bias = bias[None, :].expand(input.shape[:2]).reshape(-1)
            batch_size = input.shape[0]
            input = input.reshape(1, -1, *input.shape[2:])
            y = F.batch_norm(
                input,
                running_means,
                running_vars,
                weight,
                bias,
                # whether the mini-batch stats should be used for normalization
                # rather than the running stats.
                False,
                0.0,  # exponential_average_factor
                self._eps)
            y = y.reshape(batch_size, -1, *y.shape[2:])
            return y


@alf.configurable
class BatchNorm1d(_NormBase):
    r"""Batch Normalization over a 2D or 3D input.

    For detail about Batch Normalization, see
    https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html

    The main difference is that this implementation supports using BN for RNN.
    The reason is that for RNN, the normalization statistics can be dramatically
    different for different steps of the RNN. Hence we need to maintain different
    running statistics for different steps of the RNN.

    The following example shows how to use it, assuming ``rnn`` is a ``Network``
    which contains some alf.layers.BatchNorm layers.

    .. code-block:: python

        prepare_rnn_batch_norm(rnn)
        rnn.set_batch_norm_max_steps(5)
        for i in range(t):
            rnn.set_batch_norm_current_step(i)
            y, state = rnn(input[i], state)

    Note that ``set_batch_norm_current_step()`` also accepts a Tensor as its
    argument. In that case, it gives the current step for each sample in a
    batch.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        fixed_weight_norm: whether to fix the norm of the affine weight
            parameter. The norm will be fixed at ``sqrt(num_features)``.
        use_bias: whether to use bias. Note that if ``affine`` is True, this
            argument is ignored and bias will be used.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to
            ``False``, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2D or 3D."""
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'.format(
                input.dim()))
@alf.configurable
class BatchNorm2d(_NormBase):
    r"""Applies Batch Normalization over a 4D input.

    For detail about Batch Normalization, see
    https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html

    The main difference is that this implementation supports using BN for RNN.
    The reason is that for RNN, the normalization statistics can be dramatically
    different for different steps of the RNN. Hence we need to maintain different
    running statistics for different steps of the RNN.

    The following example shows how to use it, assuming ``rnn`` is a ``Network``
    which contains some alf.layers.BatchNorm layers.

    .. code-block:: python

        # Only need to call once in the lifetime of rnn
        prepare_rnn_batch_norm(rnn)
        # Only need to call once in the lifetime of rnn
        rnn.set_batch_norm_max_steps(5)
        for i in range(t):
            rnn.set_batch_norm_current_step(i)
            y, state = rnn(input[i], state)

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        fixed_weight_norm: whether to fix the norm of the affine weight
            parameter. The norm will be fixed at ``sqrt(num_features)``.
        use_bias: whether to use bias. Note that if ``affine`` is True, this
            argument is ignored and bias will be used.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to
            ``False``, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4D."""
        ndim = input.dim()
        if ndim == 4:
            return
        raise ValueError('expected 4D input (got {}D input)'.format(ndim))
def set_batch_norm_max_steps(module, max_steps: int):
    """Set max_steps for all batch norm layers in ``module``.

    Every layer stored in ``module._all_bns`` gets its ``set_max_steps()``
    called with ``max_steps``.

    Args:
        module: an object whose ``_all_bns`` attribute holds the batch norm
            layers to configure.
        max_steps: the maximum steps for which the batch norm running
            statistics are maintained.
    """
    for layer in module._all_bns:
        layer.set_max_steps(max_steps)
def set_batch_norm_current_step(module: nn.Module,
                                current_step: Union[torch.Tensor, int]):
    """Set current_step for all batch norm layers in ``module``.

    Every layer stored in ``module._all_bns`` gets its ``set_current_step()``
    called with ``current_step``.

    Args:
        module: an object whose ``_all_bns`` attribute holds the batch norm
            layers to update.
        current_step: the current step for RNN. If it is a Tensor, it gives
            the current step for each sample in a batch.
    """
    for layer in module._all_bns:
        layer.set_current_step(current_step)
def prepare_rnn_batch_norm(module: nn.Module) -> bool:
    """Prepare an RNN network ``module`` to use alf.layers.BatchNorm layers.

    The module tree is walked depth-first. Every ``BatchNorm1d``/``BatchNorm2d``
    of this file that is found is collected into ``module._all_bns``, and
    ``set_batch_norm_max_steps``/``set_batch_norm_current_step`` are bound to
    ``module`` as methods. A one-time warning is issued for each
    ``torch.nn.BatchNorm`` layer found within ``module``.

    Args:
        module: the network to prepare.
    Returns:
        True if alf.layers.BatchNorm layers have been found, False otherwise.
    """
    torch_bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    found = set()
    seen = set()
    stack = [("", module)]
    while stack:
        prefix, current = stack.pop()
        if isinstance(current, (BatchNorm1d, BatchNorm2d)):
            found.add(current)
            continue
        if isinstance(current, torch_bn_types):
            warning_once(
                "RNN may not perform well with torch.nn.BatchNorm layer "
                "(at %s). Consider using alf.layers.BatchNorm instead." %
                prefix)
            continue
        if not isinstance(current, nn.Module):
            continue
        for child_name, child in current.named_children():
            if child in seen:
                # shared submodules are only traversed once
                continue
            seen.add(child)
            stack.append((prefix + '.' + child_name, child))

    module._all_bns = found
    module.set_batch_norm_max_steps = types.MethodType(
        set_batch_norm_max_steps, module)
    module.set_batch_norm_current_step = types.MethodType(
        set_batch_norm_current_step, module)
    return bool(found)
class ParamLayerNorm(nn.Module):
    """ParamLayerNorm, adapted from ``torch.nn.modules.LayerNorm``.

    Subclasses must implement ``_preprocess_input``.
    """

    def __init__(self, n_groups: int, output_channels: int,
                 eps: float = 1e-5):
        """A general Layer Normalization layer that does not maintain learnable
        affine parameters (weight and bias), but accepts both from users. If
        ``n_groups`` is greater than 1, it performs parallel Layer
        Normalization operation.

        Args:
            n_groups: number of parallel groups
            output_channels: output size for FC layers, output channel size
                for conv layers.
            eps: refer to nn.GroupNorm
        """
        super().__init__()
        self._n_groups = n_groups
        self._output_channels = output_channels
        self._eps = eps
        self._set_weight(torch.ones(n_groups, self.weight_length))
        self._set_bias(torch.zeros(n_groups, self.bias_length))
        self._param_length = None

    @property
    def weight(self):
        """Get stored weight tensor or batch of weight tensors."""
        return self._weight

    @property
    def bias(self):
        """Get stored bias tensor or batch of bias tensors."""
        return self._bias

    @property
    def output_channels(self):
        """Get the number of output channels of a single group."""
        return self._output_channels

    @property
    def weight_length(self):
        """Get the n_element of a single weight tensor."""
        return self._output_channels

    @property
    def bias_length(self):
        """Get the n_element of a single bias tensor."""
        return self._output_channels

    @property
    def param_length(self):
        """Get the length of the parameter vector of a single group, i.e.
        ``weight_length + bias_length``.
        """
        if self._param_length is None:
            self._param_length = self.weight_length + self.bias_length
        return self._param_length

    def set_parameters(self, theta: torch.Tensor, reinitialize: bool = False):
        """Distribute parameters to corresponding parameters.

        Args:
            theta: with shape ``[D] (groups=1)`` or ``[B, D] (groups=B)``,
                where the meaning of the symbols are:

                - ``B``: batch size
                - ``D``: length of parameters, should be self.param_length

                When the shape of inputs is ``[D]``, it will be unsqueezed
                to ``[1, D]``.
            reinitialize: whether to reinitialize parameters of each layer.
                If True, the values in ``theta`` are ignored: weight is set to
                ones and bias to zeros, keeping the dtype and device of
                ``theta``.
        """
        if theta.ndim == 1:
            theta = theta.unsqueeze(0)
        assert (theta.ndim == 2 and theta.shape[0] == self._n_groups
                and (theta.shape[1] == self.param_length)), (
                    "Input theta has wrong shape %s. Expecting shape (%d, %d)"
                    % (theta.shape, self._n_groups, self.param_length))
        # The first weight_length entries are the weight, the rest the bias.
        weight = theta[:, :self.weight_length]
        self._set_weight(weight, reinitialize=reinitialize)
        bias = theta[:, self.weight_length:]
        self._set_bias(bias, reinitialize=reinitialize)

    def _set_weight(self, weight: torch.Tensor, reinitialize: bool = False):
        """Store a weight tensor or batch of weight tensors.

        Args:
            weight: with shape ``[B, D]`` where the meaning of the symbols are:

                - ``B``: batch size
                - ``D``: length of weight vector, should be self.weight_length

            reinitialize: whether to reinitialize self._weight (to ones)
        """
        assert (weight.ndim == 2 and weight.shape[0] == self._n_groups
                and (weight.shape[1] == self.weight_length)), (
                    "Input weight has wrong shape %s. Expecting shape (%d, %d)"
                    % (weight.shape, self._n_groups, self.weight_length))
        if reinitialize:
            # ones_like keeps the dtype/device of the supplied tensor so the
            # stored weight stays compatible with the rest of the parameters.
            weight = torch.ones_like(weight)
        self._weight = weight.reshape(-1)  # [n * weight_length]

    def _set_bias(self, bias: torch.Tensor, reinitialize: bool = False):
        """Store a bias tensor or batch of bias tensors.

        Args:
            bias: with shape ``[B, D]`` where the meaning of the symbols are:

                - ``B``: batch size
                - ``D``: length of bias vector, should be self.bias_length

            reinitialize: whether to reinitialize self._bias (to zeros)
        """
        assert (bias.ndim == 2 and bias.shape[0] == self._n_groups
                and (bias.shape[1] == self.bias_length)), (
                    "Input bias has wrong shape %s. Expecting shape (%d, %d)" %
                    (bias.shape, self._n_groups, self.bias_length))
        if reinitialize:
            # zeros_like keeps the dtype/device of the supplied tensor.
            bias = torch.zeros_like(bias)
        self._bias = bias.reshape(-1)  # [n * bias_length]

    def _preprocess_input(self, inputs):
        """Validate and reshape ``inputs``; to be implemented by subclasses."""
        raise NotImplementedError

    def forward(self, inputs: torch.Tensor, keep_group_dim: bool = True):
        """Forward

        Args:
            inputs: refer to ``_preprocess_input`` of subclass for detailed
                description.
            keep_group_dim: whether to keep group dimension or not. Only has
                an effect when ``n_groups`` is greater than 1.
        Returns:
            torch.Tensor: for ParamLayerNorm1d, with shape ``[B, n, D]`` or
                ``[B, n*D]``, for ParamLayerNorm2d, with shape
                ``[B, n, C, H, W]`` or ``[B, n*C, H, W]``.
        """
        inputs = self._preprocess_input(inputs)
        # Each group is normalized independently by group_norm.
        res = F.group_norm(inputs, self._n_groups, self.weight, self.bias,
                           self._eps)
        if self._n_groups > 1 and keep_group_dim:
            res = res.reshape(inputs.shape[0], self._n_groups, -1,
                              *inputs.shape[2:])  # [B, n, ...]
        return res
class ParamLayerNorm1d(ParamLayerNorm):
    """``ParamLayerNorm`` for 2D (shared) or 3D (per-group) inputs."""

    def _preprocess_input(self, inputs: torch.Tensor):
        """Check inputs shape and preprocess for LayerNorm1d.

        Args:
            inputs: with shape ``[B, D] (groups=1)`` or ``[B, n, D] (groups=n)``,
                where the meaning of the symbols are:

                - ``B``: batch size
                - ``n``: number of replicas
                - ``D``: input dimension

                When the shape of inputs is ``[B, D]``, the same inputs are
                shared by all ``n`` groups. When the shape of inputs is
                ``[B, n, D]``, each group has its own slice of inputs.

        Returns:
            torch.Tensor: with shape ``[B, n*D]``
        """
        n = self._n_groups
        d = self.output_channels
        rank = inputs.ndim

        if rank == 2:
            # shared input: tile it once per group -> [B, n*D]
            assert inputs.shape[1] == d, (
                "Input inputs has wrong shape %s. Expecting (B, %d)" %
                (inputs.shape, d))
            return inputs.repeat(1, n)

        if rank == 3:
            # per-group input: merge group and feature dims -> [B, n*D]
            assert (inputs.shape[1] == n and inputs.shape[2] == d), (
                "Input inputs has wrong shape %s. Expecting (B, %d, %d)" %
                (inputs.shape, n, d))
            return inputs.reshape(-1, n * d)

        raise ValueError("Wrong inputs.ndim=%d" % rank)
class ParamLayerNorm2d(ParamLayerNorm):
    """``ParamLayerNorm`` for image-like 4D or 5D inputs."""

    def _preprocess_input(self, inputs: torch.Tensor):
        """Check inputs shape and preprocess for LayerNorm2d.

        Args:
            inputs: with shape ``[B, C, H, W] (groups=1)`` or
                ``[B, n, C, H, W] (groups=n)``, where the meaning of the
                symbols are:

                - ``B``: batch size
                - ``n``: number of replicas
                - ``C``: number of channels
                - ``H``: image height
                - ``W``: image width.

                When the shape of inputs is ``[B, C, H, W]``, the same input
                is shared by all ``n`` groups. A ``[B, n*C, H, W]`` input is
                taken as already merged per-group input. When the shape of
                inputs is ``[B, n, C, H, W]``, each group has its own slice
                of inputs.

        Returns:
            torch.Tensor with shape ``[B, n*C, H, W]``
        """
        n = self._n_groups
        c = self.output_channels
        rank = inputs.ndim

        if n == 1:
            # non-parallel layer: only validate
            assert (rank == 4 and inputs.shape[1] == c), (
                "Input img has wrong shape %s. Expecting (B, %d, H, W)" %
                (inputs.shape, c))
            return inputs

        # parallel layer
        if rank == 4:
            if inputs.shape[1] == c:
                # case 1: non-parallel input, tile it once per group
                return inputs.repeat(1, n, 1, 1)
            # case 2: parallel input with merged group and channel dims
            assert inputs.shape[1] == n * c, (
                "Input img has wrong shape %s. Expecting (B, %d, H, W) or (B, %d, H, W)"
                % (inputs.shape, c, n * c))
            return inputs

        if rank == 5:
            # case 3: parallel input with unmerged group dim
            assert (inputs.shape[1] == n and inputs.shape[2] == c), (
                "Input img has wrong shape %s. Expecting (B, %d, %d, H, W)" %
                (inputs.shape, n, c))
            # merge group and channel dim
            batch, groups, channels = inputs.shape[:3]
            return inputs.reshape(batch, groups * channels, *inputs.shape[3:])

        raise ValueError("Wrong img.ndim=%d" % rank)