# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl import logging
import math
import torch
import torch.nn as nn
import alf
import alf.utils.math_ops as math_ops
def _is_elementwise_op(op):
    """Check whether ``op`` is an elementwise operation.

    The check is empirical: ``op`` is applied once to a random ``[10, 20]``
    tensor and once, piece by piece, to four ``[5, 10]`` chunks that hold the
    same values in the same flattened order. An elementwise op must produce
    the same values both ways (up to a small numerical tolerance).

    An ``op`` that changes the shape of its input, does not return a tensor,
    or raises ``RuntimeError`` on the differently shaped chunks is reported
    as not elementwise instead of crashing the check.

    Args:
        op (Callable): the operation to test. It is called with
            ``torch.Tensor`` inputs.
    Returns:
        bool: True if ``op`` behaved like an elementwise operation.
    """
    x = torch.randn(10, 20)
    # Clone before calling ``op`` so that an in-place ``op`` cannot alter the
    # values used for the chunked evaluation below.
    x1 = x.clone().reshape(20, 10)
    y = op(x)
    if not torch.is_tensor(y) or y.shape != x.shape:
        return False
    try:
        y1 = [op(x1[5 * i:5 * i + 5, :]) for i in range(4)]
    except RuntimeError:
        # ``op`` only works for one particular input shape, so it cannot be
        # elementwise.
        return False
    if not all(torch.is_tensor(c) and c.shape == (5, 10) for c in y1):
        return False
    y1 = torch.stack(y1, dim=0).reshape(10, 20)
    # for some unknown reason y is not always exactly same as y1 for some
    # op (e.g. torch.sigmoid). So we cannot use (y==y1).all()
    return bool(((y - y1).abs() < 1e-6).all())
@alf.configurable
def _numerical_calculate_gain(nonlinearity, dz=0.01, r=5.0):
    """Compute the gain in a numerical way by integration.

    Assume :math:`y` is the output, :math:`w` is the weight (mean=0, std=1),
    and :math:`x` is the input, then

    .. math::

        Var(y) = Var(w) * E(x^2)

    So we need to approximate :math:`E(x^2)` numerically. This is done with a
    Riemann sum over :math:`z \\in [-r, r)` with step :math:`dz`, where
    :math:`z` follows a standard normal distribution and
    :math:`x = nonlinearity(z)`. The returned gain is
    :math:`\\sqrt{1 / E(x^2)}`.

    Args:
        nonlinearity (Callable): any callable activation function
        dz (float): :math:`dz` in the integration
        r (float): :math:`z` range will be :math:`[-r, r]`
    Returns:
        float: a gain factor that will be applied to the init weights. It is
        1 if ``nonlinearity`` does not appear to be an elementwise op.
    """
    if not _is_elementwise_op(nonlinearity):
        logging.warning(
            "It seems that nonlinearity (%s) is not an elementwise operation. "
            "Calculating the gain of non-elementwise op is not supported. "
            "Will use 1 as its gain", str(nonlinearity))
        return 1.

    # No gradient is needed here. Disabling autograd also keeps this working
    # for activations that own learnable parameters.
    with torch.no_grad():
        dist = torch.distributions.normal.Normal(0, 1)
        z = torch.arange(-r, r, dz)
        # `nonlinearity` might be an inplace op, need to use `z` before
        # applying `nonlinearity` to `z`
        prob = torch.exp(dist.log_prob(z))
        x = nonlinearity(z)
        Ex2 = (prob * x**2).sum() * dz
        # ``item()`` gives a plain Python float, as documented.
        return torch.sqrt(1.0 / Ex2).item()
def _calculate_gain(nonlinearity, nonlinearity_param=0.01):
    """Deprecated: now use ``_numerical_calculate_gain`` instead.

    Look up the recommended gain of a named activation. ``"elu"`` and
    ``"sigmoid"`` use values hard-coded here; every other name is forwarded
    to ``nn.init.calculate_gain``.

    Args:
        nonlinearity (str): the name of the activation function
        nonlinearity_param (float): additional parameter of the nonlinearity;
            currently only used by ``leaky_relu`` as the negative slope
            (pytorch default 0.01)
    Returns:
        float: the gain for the given activation.
    """
    special_gains = (
        # ELU paper: "The weights have been initialized according to (He et
        # al., 2015)". Also there is another suggestion for math.sqrt(1.55) in:
        # https://stats.stackexchange.com/questions/229885/whats-the-recommended-weight-initialization-strategy-when-using-the-elu-activat
        ("elu", math.sqrt(1.55)),
        # pytorch's init.calculate_gain has 1.0 for sigmoid, which is
        # obviously wrong!
        ("sigmoid", math.sqrt(3.41)),
    )
    for name, value in special_gains:
        if nonlinearity == name:
            return value
    return nn.init.calculate_gain(nonlinearity, nonlinearity_param)
@alf.configurable
def variance_scaling_init(tensor,
                          gain=1.0,
                          mode="fan_in",
                          distribution="truncated_normal",
                          calc_gain_after_activation=True,
                          nonlinearity=math_ops.identity,
                          transposed=False):
    """Implements TensorFlow's `VarianceScaling` initializer.

    `<https://github.com/tensorflow/tensorflow/blob/e5bf8de410005de06a7ff5393fafdf832ef1d4ad/tensorflow/python/ops/init_ops.py#L437>`_

    A potential benefit of this initializer is that we can sample from a
    truncated normal distribution:
    ``scipy.stats.truncnorm(a=-2, b=2, loc=0., scale=1.)``.

    Also incorporates PyTorch's calculation of the recommended gains that
    take nonlinear activations into account, so that after N layers, the
    final output std (in linear space) will be a constant regardless of N's
    value (when N is large). This auto gain probably won't make much of a
    difference if the network is shallow, as in most RL cases.

    Example usage:

    .. code-block:: python

        from alf.networks.initializers import variance_scaling_init
        layer = nn.Linear(2, 2)
        variance_scaling_init(layer.weight.data,
                              nonlinearity=nn.functional.leaky_relu)
        nn.init.zeros_(layer.bias.data)

    Args:
        tensor (torch.Tensor): the weights to be initialized
        gain (float): a positive scaling factor for weight std. Different from
            tf's implementation, this number is applied outside of
            ``math.sqrt``. Note that if ``calc_gain_after_activation=True``,
            this number will be an additional gain factor on top of that.
        mode (str): one of "fan_in", "fan_out", and "fan_avg"
        distribution (str): one of "uniform", "untruncated_normal" and
            "truncated_normal". If the latter, the weights will be sampled
            from a normal distribution truncated at ``(-2, 2)``.
        calc_gain_after_activation (bool): whether automatically calculate the
            std gain of applying nonlinearity after this layer. A nonlinear
            activation (e.g., relu) might change std after the transformation,
            so we need to compensate for that. Only used when mode=="fan_in".
        nonlinearity (Callable): any callable activation function
        transposed (bool): a flag indicating if the weight tensor has been
            transposed (e.g., ``nn.ConvTranspose2d``). In that case, `fan_in`
            and `fan_out` should be swapped.
    Returns:
        torch.Tensor: the input ``tensor``, randomly initialized in place.
    Raises:
        AssertionError: if ``mode`` is not one of the supported values.
        ValueError: if ``distribution`` is not one of the supported values.
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    if transposed:
        fan_in, fan_out = fan_out, fan_in

    assert mode in ["fan_in", "fan_out", "fan_avg"], \
        "Unrecognized mode %s!" % mode
    if mode == "fan_in":
        size = max(1.0, fan_in)
    elif mode == "fan_out":
        size = max(1.0, fan_out)
    else:
        size = max(1.0, (fan_in + fan_out) / 2.0)

    if (calc_gain_after_activation and mode == "fan_in"):
        gain *= _numerical_calculate_gain(nonlinearity)

    std = gain / math.sqrt(size)
    if distribution == "truncated_normal":
        # Truncation shrinks the std of the samples; divide by the std of the
        # truncated unit normal so the final std matches the target.
        scale = 0.87962566  # scipy.stats.truncnorm.std(-2.0, 2.0)
        std /= scale
        # Do the in-place scaling under ``no_grad`` like the other branches,
        # so that a leaf tensor that requires grad can be initialized too.
        with torch.no_grad():
            nn.init.trunc_normal_(tensor, a=-2.0,
                                  b=2.0)  # truncate within 2 std
            return tensor.mul_(std)
    elif distribution == "uniform":
        # A uniform distribution on (-limit, limit) has std limit / sqrt(3).
        limit = math.sqrt(3.0) * std
        with torch.no_grad():
            return tensor.uniform_(-limit, limit)
    elif distribution == "untruncated_normal":
        with torch.no_grad():
            return tensor.normal_(0, std)
    else:
        raise ValueError(
            "Invalid `distribution` argument: %s" % (distribution, ))