# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import numbers
import numpy as np
import math
import torch
import torch.distributions as td
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
import torch.nn as nn
from typing import Union
import alf
import alf.nest as nest
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from .distributions import TruncatedDistribution, TruncatedNormal, TruncatedCauchy, TruncatedT2
def get_invertible(cls):
    """A helper function to turn on the cache mechanism for transformation.

    This is useful as some transformations (say :math:`g`) may not be able to
    provide an accurate inversion therefore the difference between :math:`x` and
    :math:`g^{-1}(g(x))` is large. This could lead to unstable training in
    practice. For a torch transformation :math:`y=g(x)`, when ``cache_size`` is
    set to one, the latest value for :math:`(x, y)` is cached and will be used
    later for future computations. E.g. for inversion, a call to
    :math:`g^{-1}(y)` will return :math:`x`, solving the inversion error issue
    mentioned above. Note that in the case of having a chain of transformations
    (:math:`G`), all the element transformations need to turn on the cache to
    ensure the composite transformation :math:`G` satisfies:
    :math:`x=G^{-1}(G(x))`.

    Args:
        cls (type): a ``td.Transform`` subclass whose constructor accepts a
            ``cache_size`` keyword argument.

    Returns:
        type: a subclass of ``cls`` whose constructor uses ``cache_size=1``
        unless the caller passes ``cache_size`` explicitly. The subclass takes
        over ``cls``'s name, qualified name and docstring, so that e.g.
        ``repr()`` shows ``ExpTransform()`` instead of ``NewCls()``.
    """

    class NewCls(cls):
        # Keyword arguments given at call time override those stored in the
        # partial, so an explicit ``cache_size=...`` still takes precedence.
        __init__ = functools.partialmethod(cls.__init__, cache_size=1)

    # Without this every wrapped class reports itself as ``NewCls`` and has the
    # qualified name ``get_invertible.<locals>.NewCls``, which cannot be looked
    # up by name. Reusing ``cls.__name__`` makes the class discoverable by name
    # when it is bound to a module-level variable of the same name.
    NewCls.__name__ = cls.__name__
    NewCls.__qualname__ = cls.__name__
    NewCls.__doc__ = cls.__doc__
    return NewCls
# WARNING: If you need to train policy gradient with a ``TransformedDistribution``,
# then make sure to detach the sampled action when the transforms have trainable
# parameters.
# For detailed reasons, please refer to ``alf/docs/notes/compute_probs_of_transformed_dist.rst``.

# Versions of the torch transforms below whose constructors default to
# ``cache_size=1`` (see ``get_invertible()``).
AbsTransform = get_invertible(td.AbsTransform)
# ``AffineTransform`` is used by ``Softlower``, ``Softupper`` and ``SoftclipTF``
# in this module, so it has to be defined alongside the other transforms.
AffineTransform = get_invertible(td.AffineTransform)
ExpTransform = get_invertible(td.ExpTransform)
PowerTransform = get_invertible(td.PowerTransform)
SigmoidTransform = get_invertible(td.SigmoidTransform)
SoftmaxTransform = get_invertible(td.SoftmaxTransform)
@alf.configurable
class Softplus(td.Transform):
    r"""Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.

    With ``hinge_softness`` :math:`s` the mapping computed by ``_call`` is
    ``softplus(x, beta=1/s)``, i.e. :math:`y = s \log(1 + \exp(x / s))`.

    Code adapted from `pyro <https://docs.pyro.ai/en/latest/_modules/pyro/distributions/transforms/softplus.html>`_
    and `tensorflow <https://github.com/tensorflow/probability/blob/v0.12.2/tensorflow_probability/python/bijectors/softplus.py#L61-L189>`_.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __init__(self, hinge_softness=1., cache_size=1):
        """
        Args:
            hinge_softness (float): this positive parameter changes the transition
                slope. A higher softness results in a smoother transition from
                0 to identity.
            cache_size (int): cache size passed to ``td.Transform``; 1 by
                default so that the latest ``(x, y)`` pair is cached.
        """
        super().__init__(cache_size=cache_size)
        softness = float(hinge_softness)
        self._hinge_softness = softness
        assert softness > 0, "Must be a positive softness number!"

    def __eq__(self, other):
        if not isinstance(other, Softplus):
            return False
        return self._hinge_softness == other._hinge_softness

    def _call(self, x):
        beta = 1. / self._hinge_softness
        return nn.functional.softplus(x, beta=beta)

    def _inverse(self, y):
        # y = s * log(1 + exp(x / s))  =>  x = s * log(exp(y / s) - 1)
        s = self._hinge_softness
        return (y / s).expm1().log() * s

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = sigmoid(x / s), and log(sigmoid(z)) = -softplus(-z)
        s = self._hinge_softness
        return -nn.functional.softplus(-x / s)

    def with_cache(self, cache_size=1):
        if self._cache_size != cache_size:
            return Softplus(self._hinge_softness, cache_size)
        return self
@alf.configurable
def Softlower(low, hinge_softness=1.):
    """Create a Softlower transform by composing the Softplus and Affine
    transforms. Mathematically, ``softlower(x, low) = softplus(x - low) + low``.

    Args:
        low (float|Tensor): the lower bound
        hinge_softness (float): this positive parameter changes the transition
            slope. A higher softness results in a smoother transition from
            ``low`` to identity.
    Returns:
        td.transforms.ComposeTransform: the composed transform.
    """
    shift_down = AffineTransform(loc=-low, scale=1.)  # x - low
    hinge = Softplus(hinge_softness=hinge_softness)  # softplus(x - low)
    shift_up = AffineTransform(loc=low, scale=1.)  # ... + low
    return td.transforms.ComposeTransform([shift_down, hinge, shift_up])
@alf.configurable
def Softupper(high, hinge_softness=1.):
    """Create a Softupper transform by composing the Softplus and Affine
    transforms. Mathematically, ``softupper(x, high) = -softplus(high - x) + high``.

    Args:
        high (float|Tensor): the upper bound
        hinge_softness (float): this positive parameter changes the transition
            slope. A higher softness results in a smoother transition from
            identity to ``high``.
    Returns:
        td.transforms.ComposeTransform: the composed transform.
    """
    reflect_in = AffineTransform(loc=high, scale=-1.)  # high - x
    hinge = Softplus(hinge_softness=hinge_softness)  # softplus(high - x)
    reflect_out = AffineTransform(loc=high, scale=-1.)  # high - softplus(...)
    return td.transforms.ComposeTransform([reflect_in, hinge, reflect_out])
@alf.configurable
def SoftclipTF(low, high, hinge_softness=1.):
    """Create a Softclip transform by composing Softlower, Softupper, and Affine
    transforms, adapted from `tensorflow <https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/SoftClip>`_.

    Mathematically,

    .. code-block:: python

        clipped = softupper(softlower(x, low), high)
        softclip(x) = (clipped - high) / (high - softupper(low, high)) * (high - low) + high

    The second scaling step is because we will have
    ``softupper(low, high) < low`` due to distortion of softplus, so we need to
    shrink the interval slightly by ``(high - low) / (high - softupper(low, high))``
    to preserve the lower bound. Due to this rescaling, the bijector can be mildly
    asymmetric.

    Args:
        low (float|Tensor): the lower bound
        high (float|Tensor): the upper bound
        hinge_softness (float): this positive parameter changes the transition
            slope. A higher softness results in a smoother transition from
            ``low`` to ``high``.
    Returns:
        td.transforms.ComposeTransform: the composed transform.
    """
    low = low if isinstance(low, torch.Tensor) else torch.tensor(low)
    assert torch.all(high > low), "Invalid clipping range"
    # Value of ``low`` after soft upper-clipping at ``high``; it is slightly
    # below ``low``, hence the rescaling factor below.
    softupper_at_low = Softupper(high, hinge_softness=hinge_softness)(low)
    rescale = (high - low) / (high - softupper_at_low)
    return td.transforms.ComposeTransform([
        Softlower(low=low, hinge_softness=hinge_softness),
        Softupper(high=high, hinge_softness=hinge_softness),  # clipped
        AffineTransform(loc=-high, scale=1.),
        AffineTransform(loc=high, scale=rescale),
    ])
@alf.configurable
class Softclip(td.Transform):
    r"""Transform via the mapping defined in ``alf.math_ops.softclip()``.

    Unlike ``SoftclipTF``, this transform is symmetric regarding the lower and
    upper bound when squashing. Its derivative (see ``log_abs_det_jacobian``)
    is :math:`\sigma((x - l)/s) - \sigma((x - h)/s)` where :math:`\sigma` is
    the sigmoid, :math:`l, h` the bounds and :math:`s` the hinge softness.
    """
    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def __init__(self, low, high, hinge_softness=1., cache_size=1):
        """
        Args:
            low (float): the lower bound
            high (float): the upper bound
            hinge_softness (float): this positive parameter changes the transition
                slope. A higher softness results in a smoother transition from
                ``low`` to ``high``.
            cache_size (int): cache size passed to ``td.Transform``; 1 by
                default so that inversion can reuse the cached input.
        """
        super().__init__(cache_size=cache_size)
        self._hinge_softness = float(hinge_softness)
        assert self._hinge_softness > 0, "Must be a positive softness number!"
        self._l = float(low)
        self._h = float(high)
        # The output range depends on the instance, so it overrides the
        # class-level ``codomain``.
        self.codomain = constraints.interval(self._l, self._h)

    def __eq__(self, other):
        return (isinstance(other, Softclip)
                and self._hinge_softness == other._hinge_softness
                and self._l == other._l and self._h == other._h)

    def get_builder(self):
        """Return a no-argument callable that rebuilds an equivalent transform.

        All constructor arguments are bound. In particular ``hinge_softness``
        must be bound explicitly; otherwise the rebuilt transform would fall
        back to the default softness of 1.
        """
        return functools.partial(
            Softclip,
            low=self._l,
            high=self._h,
            hinge_softness=self._hinge_softness,
            cache_size=self._cache_size)

    def _call(self, x):
        return alf.math.softclip(x, self._l, self._h, self._hinge_softness)

    def _inverse(self, y):
        """``y`` should be in ``[self._l, self._h]``. Note that when ``y`` is
        close to boundaries, this inverse function might have numerical issues.
        Since we use ``cache_size=1`` in the init function, here we don't clip
        ``y``.
        """
        s = self._hinge_softness
        return (y + s * (((self._l - y) / s).expm1() / (
            (y - self._h) / s).expm1()).log())

    def log_abs_det_jacobian(self, x, y):
        r"""Compute ``log|dy/dx|``.

        ``dy/dx = sigmoid((x - l) / s) - sigmoid((x - h) / s)``, written with
        ``sigmoid(z) = 1 - 1 / (1 + exp(z))``.
        """
        s = self._hinge_softness
        return (1 - 1 / (1 + ((x - self._l) / s).exp()) - 1 / (1 + (
            (self._h - x) / s).exp())).log()

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return Softclip(self._l, self._h, self._hinge_softness, cache_size)
@alf.configurable
class Softsign(td.Transform):
    r"""Transform via ``alf.math.softsign``, mapping reals into ``(-1, 1)``.

    The inverse and Jacobian implemented below correspond to
    :math:`y = x / (1 + |x|)`.
    """
    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    def __eq__(self, other):
        return isinstance(other, Softsign)

    def _call(self, x):
        return alf.math.softsign(x)

    def _inverse(self, y):
        r"""
        .. math::

            \begin{array}{lll}
                y = \frac{x}{1+x} \rightarrow x = \frac{y}{1 - y}, &\text{if} &y > 0\\
                y = \frac{x}{1-x} \rightarrow x = \frac{y}{1 + y}, &\text{else}&\\
            \end{array}
        """
        positive = y > 0
        return torch.where(positive, y / (1 - y), y / (1 + y))

    def log_abs_det_jacobian(self, x, y):
        r"""
        .. math::

            \begin{array}{lll}
                y = \frac{x}{1+x} \rightarrow \frac{dy}{dx} = \frac{1}{(1+x)^2}, &\text{if} &x > 0\\
                y = \frac{x}{1-x} \rightarrow \frac{dy}{dx} = \frac{1}{(1-x)^2}, &\text{else}&\\
            \end{array}
        """
        return torch.log(1 + x.abs()) * -2.

    def with_cache(self, cache_size=1):
        if self._cache_size != cache_size:
            return Softsign(cache_size)
        return self
@alf.configurable
class StableTanh(td.Transform):
    r"""Invertible transformation (bijector) that computes :math:`Y = tanh(X)`,
    therefore :math:`Y \in (-1, 1)`.

    This can be achieved by an affine transform of the Sigmoid transformation,
    i.e., it is equivalent to applying a list of transformations sequentially:

    .. code-block:: python

        transforms = [AffineTransform(loc=0, scale=2),
                      SigmoidTransform(),
                      AffineTransform(loc=-1, scale=2)]

    However, using the ``StableTanh`` transformation directly is more numerically
    stable.
    """
    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        # Cache is on by default because inverting tanh is numerically
        # unstable near +/-1.
        super().__init__(cache_size=cache_size)

    def __eq__(self, other):
        return isinstance(other, StableTanh)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # Based on https://github.com/tensorflow/agents/commit/dfb8c85a01d65832b05315928c010336df13f7b9#diff-a572e559b953f965c5c2cd1b9ded2c7b
        # 0.99999997 is the maximum value such that atanh(x) is valid for both
        # float32 and float64
        bound = 0.99999997
        inside = torch.abs(y) <= 1.0
        y = torch.where(inside, torch.clamp(y, -bound, bound), y)
        # atanh(y) = 0.5 * log((1 + y) / (1 - y))
        return 0.5 * torch.log((1 + y) / (1 - y))

    def log_abs_det_jacobian(self, x, y):
        # log(1 - tanh(x)^2) = 2 * (log(2) - x - softplus(-2x))
        log_two = torch.log(
            torch.tensor(2.0, dtype=x.dtype, requires_grad=False))
        return 2.0 * (log_two - x - nn.functional.softplus(-2.0 * x))

    def with_cache(self, cache_size=1):
        if self._cache_size != cache_size:
            return StableTanh(cache_size)
        return self
class OUProcess(nn.Module):
    """A zero-mean Ornstein-Uhlenbeck process for generating noises."""

    def __init__(self, initial_value, damping=0.15, stddev=0.2):
        """
        The Ornstein-Uhlenbeck process is a process that generates temporally
        correlated noise via a random walk with damping. This process describes
        the velocity of a particle undergoing brownian motion in the presence of
        friction. This can be useful for exploration in continuous action
        environments with momentum.

        The temporal update equation is:

        .. code-block:: python

            x_next = (1 - damping) * x + N(0, std_dev)

        Args:
            initial_value (Tensor): Initial value of the process. It is copied,
                so the passed tensor is never modified.
            damping (float): The rate at which the noise trajectory is damped
                towards the mean. We must have :math:`0 <= damping <= 1`, where
                a value of 0 gives an undamped random walk and a value of 1 gives
                uncorrelated Gaussian noise. Hence in most applications a small
                non-zero value is appropriate.
            stddev (float): Standard deviation of the Gaussian component.
        """
        super(OUProcess, self).__init__()
        self._damping = damping
        self._stddev = stddev
        # Register the state as a buffer so that ``module.to(device)``,
        # ``.double()`` etc. convert it together with the module. It is
        # non-persistent so ``state_dict()`` stays as before (without ``_x``).
        self.register_buffer(
            '_x', initial_value.clone().detach(), persistent=False)

    def forward(self):
        """Advance the process by one step and return the new value.

        Returns:
            Tensor: the updated state. It shares storage with the internal
            state and is overwritten in place by the next call, so clone it if
            it needs to be kept.
        """
        noise = torch.randn_like(self._x) * self._stddev
        with torch.no_grad():
            return self._x.copy_((1 - self._damping) * self._x + noise)
class DiagMultivariateNormal(td.Independent):
    """Multivariate normal distribution with a diagonal covariance matrix.

    Implemented as ``td.Independent`` over a ``td.Normal`` with the last
    dimension reinterpreted as the event dimension.
    """

    def __init__(self, loc, scale):
        """Create multivariate normal distribution with diagonal variance.

        Args:
            loc (Tensor): mean of the distribution
            scale (Tensor): standard deviation. Should have same shape as ``loc``.
        """
        # Argument validation is disabled on the base Normal so that a zero
        # ``scale`` is accepted.
        base_normal = td.Normal(loc, scale, validate_args=False)
        super().__init__(base_normal, reinterpreted_batch_ndims=1)

    @property
    def stddev(self):
        """Per-dimension standard deviation, taken from the base ``Normal``."""
        base_normal = self.base_dist
        return base_normal.stddev
@alf.configurable(whitelist=['eps'])
class Beta(td.Beta):
    r"""Beta distribution parameterized by ``concentration1`` and ``concentration0``.

    Note: we need to wrap ``td.Beta`` so that ``self.concentration1`` and
    ``self.concentration0`` are the actual tensors passed in to construct the
    distribution. This is important in certain situation. For example, if you want
    to register a hook to process the gradient to ``concentration1`` and ``concentration0``,
    ``td.Beta.concentration0.register_hook()`` will not work because gradient will
    not be backpropped to ``td.Beta.concentration0`` since it is sliced from
    ``td.Dirichlet.concentration`` and gradient will only be backpropped to
    ``td.Dirichlet.concentration`` instead of ``td.Beta.concentration0`` or
    ``td.Beta.concentration1``.
    """

    def __init__(self,
                 concentration1,
                 concentration0,
                 eps=None,
                 validate_args=None):
        """
        Args:
            concentration1 (float or Tensor): 1st concentration parameter of the distribution
                (often referred to as alpha)
            concentration0 (float or Tensor): 2nd concentration parameter of the distribution
                (often referred to as beta)
            eps (float): a very small value indicating the interval ``[eps, 1-eps]``
                into which the sampled values will be clipped. This clipping can
                prevent ``NaN`` and ``Inf`` values in the gradients. If None,
                a small value defined by PyTorch will be used.
            validate_args (bool|None): passed to ``td.Beta``.
        """
        self._concentration1 = concentration1
        self._concentration0 = concentration0
        super().__init__(concentration1, concentration0, validate_args)
        if eps is None:
            self._eps = torch.finfo(self._dirichlet.concentration.dtype).eps
        else:
            self._eps = float(eps)

    @property
    def concentration0(self):
        """``concentration0`` exactly as passed to ``__init__``."""
        return self._concentration0

    @property
    def concentration1(self):
        """``concentration1`` exactly as passed to ``__init__``."""
        return self._concentration1

    @property
    def mode(self):
        """Mode of the distribution.

        ``(alpha - 1) / (alpha + beta - 2)`` where ``alpha > 1`` and
        ``beta > 1``; otherwise 0 where ``alpha < beta`` and 1 elsewhere.
        The result has the broadcast shape, dtype and device of the
        concentrations.
        """
        # Read the (broadcast) tensors from the underlying Dirichlet rather than
        # ``self.concentration1/0``: those may be Python floats. Gradients
        # still flow back to the tensors given to ``__init__``.
        concentration = self._dirichlet.concentration
        alpha = concentration[..., 0]
        beta = concentration[..., 1]
        interior = (alpha > 1) & (beta > 1)
        # Only use the real denominator where the interior formula is selected.
        # Otherwise a zero denominator (e.g. alpha = beta = 1) in the unselected
        # branch of ``torch.where`` would inject NaN into the gradients.
        denom = torch.where(interior, alpha + beta - 2, torch.ones_like(alpha))
        boundary = torch.where(alpha < beta, torch.zeros_like(alpha),
                               torch.ones_like(alpha))
        return torch.where(interior, (alpha - 1) / denom, boundary)

    def rsample(self, sample_shape=()):
        """We override the original ``rsample()`` in order to clamp the output
        to avoid `NaN` and `Inf` values in the gradients. See Pyro's
        ``rsample()`` implementation in
        `<https://docs.pyro.ai/en/dev/_modules/pyro/distributions/affine_beta.html#AffineBeta>`_.
        """
        x = super(Beta, self).rsample(sample_shape)
        return torch.clamp(x, min=self._eps, max=1 - self._eps)
class DiagMultivariateBeta(td.Independent):
    """Independent Beta distributions with the last dimension as event dim."""

    def __init__(self, concentration1, concentration0):
        """Create multivariate independent beta distribution.

        Args:
            concentration1 (float or Tensor): 1st concentration parameter of the
                distribution (often referred to as alpha)
            concentration0 (float or Tensor): 2nd concentration parameter of the
                distribution (often referred to as beta)
        """
        base_beta = Beta(concentration1, concentration0)
        super().__init__(base_beta, reinterpreted_batch_ndims=1)
class StableCauchy(td.Cauchy):
    """``td.Cauchy`` whose ``rsample`` clips the underlying uniform noise."""

    def rsample(self, sample_shape=torch.Size(), clipping_value=0.49):
        r"""Overwrite Pytorch's Cauchy rsample for a more stable result. Basically
        the sampled number is clipped to fall within a reasonable range.

        For reference::

            > np.tan(math.pi * -0.499)
            -318.30883898554157
            > np.tan(math.pi * -0.49)
            -31.820515953773853

        Args:
            sample_shape (torch.Size): leading shape of the samples.
            clipping_value (float): suppose eps is sampled from ``(-0.5,0.5)``.
                It will be clipped to ``[-clipping_value, clipping_value]`` to
                avoid values with huge magnitudes.
        Returns:
            Tensor: reparameterized samples ``loc + scale * tan(pi * eps)``.
        """
        shape = self._extended_shape(sample_shape)
        uniform = torch.empty(
            shape, dtype=self.loc.dtype, device=self.loc.device).uniform_()
        centered = torch.clamp(
            uniform - 0.5, min=-clipping_value, max=clipping_value)
        return torch.tan(centered * math.pi) * self.scale + self.loc
class DiagMultivariateCauchy(td.Independent):
    """Independent ``StableCauchy`` distributions over the last dimension."""

    def __init__(self, loc, scale):
        """Create multivariate cauchy distribution with diagonal scale matrix.

        Args:
            loc (Tensor): median of the distribution. Note that Cauchy doesn't
                have a mean (divergent).
            scale (Tensor): also known as "half width". Should have the same
                shape as ``loc``.
        """
        base_cauchy = StableCauchy(loc, scale)
        super().__init__(base_cauchy, reinterpreted_batch_ndims=1)

    @property
    def loc(self):
        """Median of the base Cauchy distribution."""
        base_cauchy = self.base_dist
        return base_cauchy.loc

    @property
    def scale(self):
        """Half width of the base Cauchy distribution."""
        base_cauchy = self.base_dist
        return base_cauchy.scale
class OneHotCategoricalStraightThrough(td.OneHotCategoricalStraightThrough):
    """Provide an additional property ``mode`` with gradient enabled.
    """

    @property
    def mode(self):
        """One-hot argmax of ``logits`` with a straight-through gradient.

        The forward value equals the one-hot vector (up to float rounding),
        while the gradient is that of ``self.probs``.
        """
        logits = self.logits
        best = torch.argmax(logits, -1)
        one_hot = torch.nn.functional.one_hot(
            best, num_classes=logits.shape[-1]).to(logits)
        return one_hot + self.probs - self.probs.detach()
@alf.configurable
class OneHotCategoricalGumbelSoftmax(td.OneHotCategorical):
    r"""Create a reparameterizable ``td.OneHotCategorical`` distribution based on
    the Gumbel-softmax gradient estimator from

    ::

        Jang et al., "CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX", 2017.
    """
    has_rsample = True

    def __init__(self, hard_sample: bool = True, tau: float = 1., **kwargs):
        """
        Args:
            hard_sample: If False, the rsampled result will be a "soft" vector
                of Gumbel softmax distribution, which naturally supports gradient
                backprop. If True, ``argmax`` will be applied on top of it and then
                a straight-through gradient estimator is used.
            tau: the Gumbel-softmax temperature for ``rsample``. A higher
                temperature leads to a more uniform sample.
            **kwargs: forwarded to ``td.OneHotCategorical`` (``probs`` or
                ``logits``, ``validate_args``).
        """
        super(OneHotCategoricalGumbelSoftmax, self).__init__(**kwargs)
        self._hard_sample = hard_sample
        self._tau = tau

    def rsample(self, sample_shape=torch.Size()):
        """Draw Gumbel-softmax samples of shape ``sample_shape + param_shape``."""
        sample_shape = torch.Size(sample_shape)
        # Tile the logits so that the leading dims match ``sample_shape``.
        repeats = sample_shape + (1, ) * len(self.param_shape)
        tiled_logits = self.logits.repeat(*repeats)
        return torch.nn.functional.gumbel_softmax(
            logits=tiled_logits,
            tau=self._tau,
            hard=self._hard_sample,
            dim=-1)

    @property
    def mode(self):
        """One-hot argmax of ``logits`` with a straight-through gradient."""
        logits = self.logits
        best = torch.argmax(logits, -1)
        one_hot = torch.nn.functional.one_hot(
            best, num_classes=logits.shape[-1]).to(logits)
        return one_hot + self.probs - self.probs.detach()
def _builder_independent(base_builder, reinterpreted_batch_ndims_, **kwargs):
return td.Independent(base_builder(**kwargs), reinterpreted_batch_ndims_)
def _builder_transformed(base_builder, transform_builders, params_,
transforms_params_):
transforms = [
b(**p) for b, p in zip(transform_builders, transforms_params_)
]
return td.TransformedDistribution(base_builder(**params_), transforms)
def _get_categorical_builder(obj: Union[
td.Categorical, td.OneHotCategorical, td.
OneHotCategoricalStraightThrough, OneHotCategoricalStraightThrough]):
dist_cls = type(obj)
if 'probs' in obj.__dict__ and id(obj.probs) == id(obj._param):
# This means that obj is constructed using probs
return dist_cls, {'probs': obj.probs}
else:
return dist_cls, {'logits': obj.logits}
def _get_gumbelsoftmax_categorical_builder(
        obj: OneHotCategoricalGumbelSoftmax):
    """Return ``(builder, params)`` for an ``OneHotCategoricalGumbelSoftmax``.

    ``hard_sample`` and ``tau`` are bound into the builder. The tensor
    parameter is ``probs`` if ``obj`` was constructed with ``probs``,
    otherwise ``logits``.
    """
    builder = functools.partial(
        OneHotCategoricalGumbelSoftmax,
        hard_sample=obj._hard_sample,
        tau=obj._tau)
    # ``probs``/``logits``/``_param`` of a one-hot categorical forward to the
    # wrapped ``td.Categorical``, so ``'probs'`` is never in ``obj.__dict__``
    # itself. Check the wrapped categorical instead.
    categorical = getattr(obj, '_categorical', obj)
    if ('probs' in categorical.__dict__
            and categorical.probs is categorical._param):
        # This means that obj is constructed using probs
        return builder, {'probs': obj.probs}
    return builder, {'logits': obj.logits}
def _get_independent_builder(obj: td.Independent):
    """Return ``(builder, params)`` for a ``td.Independent`` distribution.

    ``params`` are those of the base distribution; ``reinterpreted_batch_ndims``
    is bound into the builder.
    """
    base_builder, base_params = _get_builder(obj.base_dist)
    builder = functools.partial(_builder_independent, base_builder,
                                obj.reinterpreted_batch_ndims)
    return builder, base_params
def _get_transform_builders_params(transforms):
"""Return a nested structure where each node is a non-composed transform,
after expanding any composed transform in ``transforms``.
"""
def _get_transform_builder(transform):
if hasattr(transform, "get_builder"):
return transform.get_builder()
return transform.__class__
def _get_transform_params(transform):
if hasattr(transforms, 'params') and transforms.params is not None:
# We assume that if a td.Transform has attribute 'params', then they are the
# parameters we'll extract and store.
assert isinstance(
transforms.params,
dict), ("Transform params must be provided as a dict! "
f"Got {transforms.params}")
return transforms.params
return {} # the transform doesn't have any parameter
if isinstance(transforms, td.Transform):
if isinstance(transforms, td.ComposeTransform):
builders, params = _get_transform_builders_params(transforms.parts)
compose_transform_builder = lambda parts_params: td.ComposeTransform(
[b(**p) for b, p in zip(builders, parts_params)])
return compose_transform_builder, {'parts_params': params}
else:
builder = _get_transform_builder(transforms)
params = _get_transform_params(transforms)
return builder, params
assert isinstance(transforms, list), f"Incorrect transforms {transforms}!"
builders_and_params = [
_get_transform_builders_params(t) for t in transforms
]
builders, params = zip(*builders_and_params)
return list(builders), list(params)
def _get_transformed_builder(obj: td.TransformedDistribution):
    """Return ``(builder, params)`` for a ``td.TransformedDistribution``.

    ``params`` holds the base distribution params (``params_``) and the params
    of every (expanded) transform (``transforms_params_``). The builders
    themselves are bound into the returned builder.
    """
    base_builder, base_params = _get_builder(obj.base_dist)
    transform_builders, transform_params = _get_transform_builders_params(
        obj.transforms)
    builder = functools.partial(_builder_transformed, base_builder,
                                transform_builders)
    return builder, {
        "params_": base_params,
        'transforms_params_': transform_params
    }
def _builder_affine_transformed(base_builder, loc_, scale_, **kwargs):
    """Build ``AffineTransformedDistribution(base_builder(**kwargs), loc_, scale_)``.

    The trailing underscores avoid clashes with ``loc``/``scale`` keys of the
    base distribution's parameters in ``kwargs``.
    """
    base_dist = base_builder(**kwargs)
    return AffineTransformedDistribution(base_dist, loc_, scale_)
def _get_affine_transformed_builder(obj: AffineTransformedDistribution):
    """Return ``(builder, params)`` for an ``AffineTransformedDistribution``.

    ``obj.loc`` and ``obj.scale`` are bound into the builder; ``params`` are
    those of the base distribution.
    """
    base_builder, base_params = _get_builder(obj.base_dist)
    builder = functools.partial(_builder_affine_transformed, base_builder,
                                obj.loc, obj.scale)
    return builder, base_params
def _get_mixture_same_family_builder(obj: td.MixtureSameFamily):
    """Return ``(builder, params)`` for a ``td.MixtureSameFamily``.

    ``params`` is ``{"mixture": ..., "components": ...}`` holding the params
    of the mixture and component distributions respectively.
    """
    build_mixture, mixture_params = _get_builder(obj.mixture_distribution)
    build_components, components_params = _get_builder(
        obj.component_distribution)

    def _mixture_builder(mixture, components):
        mixture_dist = build_mixture(**mixture)
        component_dist = build_components(**components)
        return td.MixtureSameFamily(mixture_dist, component_dist)

    params = {"mixture": mixture_params, "components": components_params}
    return _mixture_builder, params
def _truncated_builder_getter(dist_cls):
    """Return a getter ``obj -> (builder, params)`` for truncated ``dist_cls``.

    The bounds are bound into the builder with ``functools.partial``, while
    ``loc``/``scale`` are returned as the parameters.
    """

    def _getter(obj):
        builder = functools.partial(
            dist_cls,
            lower_bound=obj.lower_bound,
            upper_bound=obj.upper_bound)
        return builder, {'loc': obj.loc, 'scale': obj.scale}

    return _getter


# Maps an exact distribution type (``_get_builder`` looks up ``type(obj)``, so
# subclasses need their own entry) to a function ``obj -> (builder, params)``
# such that ``builder(**params)`` rebuilds an equivalent distribution.
# NOTE(review): ``AffineTransformedDistribution`` is not defined in the visible
# part of this module; it must be defined before this dict is evaluated.
_get_builder_map = {
    td.Categorical: _get_categorical_builder,
    td.OneHotCategorical: _get_categorical_builder,
    td.OneHotCategoricalStraightThrough: _get_categorical_builder,
    OneHotCategoricalStraightThrough: _get_categorical_builder,
    OneHotCategoricalGumbelSoftmax: _get_gumbelsoftmax_categorical_builder,
    td.Normal:
    lambda obj: (td.Normal, {'loc': obj.mean, 'scale': obj.stddev}),
    StableCauchy:
    lambda obj: (StableCauchy, {'loc': obj.loc, 'scale': obj.scale}),
    td.Independent: _get_independent_builder,
    DiagMultivariateNormal:
    lambda obj: (DiagMultivariateNormal, {'loc': obj.mean, 'scale': obj.stddev}),
    DiagMultivariateCauchy:
    lambda obj: (DiagMultivariateCauchy, {'loc': obj.loc, 'scale': obj.scale}),
    td.TransformedDistribution: _get_transformed_builder,
    AffineTransformedDistribution: _get_affine_transformed_builder,
    Beta:
    lambda obj: (Beta, {
        'concentration1': obj.concentration1,
        'concentration0': obj.concentration0
    }),
    DiagMultivariateBeta:
    lambda obj: (DiagMultivariateBeta, {
        'concentration1': obj.base_dist.concentration1,
        'concentration0': obj.base_dist.concentration0
    }),
    TruncatedNormal: _truncated_builder_getter(TruncatedNormal),
    TruncatedCauchy: _truncated_builder_getter(TruncatedCauchy),
    TruncatedT2: _truncated_builder_getter(TruncatedT2),
    td.MixtureSameFamily: _get_mixture_same_family_builder,
}
def _get_builder(obj):
    """Return ``(builder, params)`` such that ``builder(**params)`` rebuilds ``obj``.

    The lookup in ``_get_builder_map`` uses the exact type of ``obj``, so a
    subclass of a supported distribution needs its own entry.

    Raises:
        KeyError: if no builder is registered for ``type(obj)``.
    """
    try:
        get_builder_fn = _get_builder_map[type(obj)]
    except KeyError:
        # Re-raise with a descriptive message instead of a bare ``KeyError``
        # carrying only the class.
        raise KeyError("No builder registered in _get_builder_map for "
                       f"distribution type {type(obj)}") from None
    return get_builder_fn(obj)
class DistributionSpec(object):
    """Describes how to build a distribution from its (tensor) parameters."""

    def __init__(self, builder, input_params_spec):
        """
        Args:
            builder (Callable): the function which is used to build the
                distribution. ``builder(**input_params)`` returns a
                ``Distribution`` with ``input_params`` as its parameters.
            input_params_spec (nested TensorSpec): the spec for the keyword
                arguments of ``builder``.
        """
        self.builder = builder
        self.input_params_spec = input_params_spec

    def build_distribution(self, input_params):
        """Build a Distribution using ``input_params``.

        Args:
            input_params (nested Tensor): the parameters for building the
                distribution. It should match ``input_params_spec`` provided to
                ``__init__``.
        Returns:
            Distribution:
        """
        nest.assert_same_structure(input_params, self.input_params_spec)
        build = self.builder
        return build(**input_params)

    @classmethod
    def from_distribution(cls, dist, from_dim=0):
        """Create a ``DistributionSpec`` from a ``Distribution``.

        Args:
            dist (Distribution): the ``Distribution`` from which the spec is
                extracted.
            from_dim (int): only use the dimensions from this. The reason of
                using ``from_dim>0`` is that ``[0, from_dim)`` might be batch
                dimension in some scenario.
        Returns:
            DistributionSpec:
        """
        dist_builder, params = _get_builder(dist)
        params_spec = extract_spec(params, from_dim)
        return cls(dist_builder, params_spec)
def to_distribution_param_spec(nests):
    """Convert the ``DistributionSpecs`` in nests to their parameter specs.

    Args:
        nests (nested DistributionSpec or TensorSpec): Each ``DistributionSpec``
            will be converted to a dictionary of the spec of its input ``Tensor``
            parameters.
    Returns:
        nested TensorSpec: Each leaf is a ``TensorSpec`` or a ``dict``
        corresponding to one distribution, with keys as parameter name and
        values as ``TensorSpecs`` for the parameters.
    Raises:
        ValueError: if a leaf is neither a ``DistributionSpec`` nor a
            ``TensorSpec``.
    """

    def _param_spec_of(spec):
        if isinstance(spec, DistributionSpec):
            return spec.input_params_spec
        if isinstance(spec, TensorSpec):
            return spec
        raise ValueError("Only TensorSpec or DistributionSpec is allowed "
                         "in nest, got %s. nest is %s" % (spec, nests))

    return nest.map_structure(_param_spec_of, nests)
def params_to_distributions(nests, nest_spec):
    """Convert distribution parameters to ``Distribution``, keep tensors unchanged.

    Args:
        nests (nested Tensor): a nested ``Tensor`` and dictionary of tensor
            parameters of ``Distribution``. Typically, ``nest`` is obtained using
            ``distributions_to_params()``.
        nest_spec (nested DistributionSpec and TensorSpec): The distribution
            params will be converted to ``Distribution`` according to the
            corresponding ``DistributionSpec`` in ``nest_spec``.
    Returns:
        nested Distribution or Tensor:
    Raises:
        ValueError: if a leaf of ``nest_spec`` is neither a
            ``DistributionSpec`` nor a ``TensorSpec``.
    """

    def _to_dist(spec, params):
        if isinstance(spec, DistributionSpec):
            return spec.build_distribution(params)
        if isinstance(spec, TensorSpec):
            return params
        raise ValueError(
            "Only DistributionSpec or TensorSpec is allowed "
            "in nest_spec, got %s. nest_spec is %s" % (spec, nest_spec))

    # Map only down to the leaves of ``nest_spec`` so each distribution's
    # parameter dict is passed to ``_to_dist`` as a whole.
    return nest.map_structure_up_to(nest_spec, _to_dist, nest_spec, nests)
def distributions_to_params(nests):
    """Convert distributions to its parameters, and keep tensors unchanged.

    Only returns parameters that have ``Tensor`` values.

    Args:
        nests (nested Distribution and Tensor): Each ``Distribution`` will be
            converted to dictionary of its ``Tensor`` parameters.
    Returns:
        nested Tensor/Distribution: Each leaf is a ``Tensor`` or a ``dict``
        corresponding to one distribution, with keys as parameter name and
        values as tensors containing parameter values.
    Raises:
        ValueError: if a leaf is neither a ``Distribution`` nor a ``Tensor``.
    """

    def _to_params(dist_or_tensor):
        if isinstance(dist_or_tensor, td.Distribution):
            return extract_distribution_parameters(dist_or_tensor)
        if isinstance(dist_or_tensor, torch.Tensor):
            return dist_or_tensor
        # A single formatted string (the former stray comma passed two
        # arguments to ValueError, leaving a tuple as the message).
        raise ValueError("Only Tensor or Distribution is allowed in nest, "
                         "got %s. nest is %s" % (dist_or_tensor, nests))

    return nest.map_structure(_to_params, nests)
def compute_entropy(distributions):
    """Computes total entropy of nested distribution.

    Args:
        distributions (nested Distribution): A possibly batched tuple of
            distributions.
    Returns:
        entropy: the sum of ``entropy()`` over all distributions in the nest.
    """
    per_dist_entropy = nest.map_structure(lambda dist: dist.entropy(),
                                          distributions)
    return sum(nest.flatten(per_dist_entropy))
def compute_log_probability(distributions, actions):
    """Computes log probability of actions given distribution.

    Args:
        distributions: A possibly batched tuple of distributions.
        actions: A possibly batched action tuple, with the same nest structure
            as ``distributions``.
    Returns:
        Tensor: the per-distribution ``log_prob`` values summed over the nest
        elements (not over the batch).
    """
    nest.assert_same_structure(distributions, actions)
    per_action_log_prob = nest.map_structure(
        lambda dist, action: dist.log_prob(action), distributions, actions)
    return sum(nest.flatten(per_action_log_prob))
def rsample_action_distribution(nested_distributions, return_log_prob=False):
    """Sample actions from distributions with reparameterization-based sampling.

    It uses ``Distribution.rsample()`` to do the sampling to enable backpropagation.

    Args:
        nested_distributions (nested Distribution): action distributions.
        return_log_prob (bool): whether to compute and return the log
            probability of the sampled actions, in addition to the sampled
            actions. In some cases, it is useful to compute the log probability
            immediately after the actions are sampled, as some subsequent
            operations might make the cache mechanism (if turned on) invalid.
            Some example scenarios include 1) additional sampling operation
            applied on ``nested_distributions``, 2) some operations applied to
            the actions sampled from ``nested_distributions`` (e.g., cloning).
            This could cause numerical issues if we want to compute the
            log probability for actions sampled at an early stage,
            especially for actions that are close to action bounds.
            For more details on PyTorch Transform, its cache mechanism, and its
            impacts on RL algorithms, please check
            `<https://alf.readthedocs.io/en/latest/notes/pytorch_notes.html#transform-bijector>`_.
    Returns:
        - rsampled actions if return_log_prob is False
        - rsampled actions and log_prob if return_log_prob is True
    """
    rsample_flags = nest.flatten(
        nest.map_structure(lambda d: d.has_rsample, nested_distributions))
    assert all(rsample_flags), (
        "all the distributions need to support rsample in order to enable "
        "backpropagation")
    actions = nest.map_structure(lambda d: d.rsample(), nested_distributions)
    if not return_log_prob:
        return actions
    # Computed right away, while the transforms' caches still hold ``actions``.
    log_prob = compute_log_probability(nested_distributions, actions)
    return actions, log_prob
def sample_action_distribution(nested_distributions, return_log_prob=False):
    """Sample actions from distributions with conventional sampling without
    enabling backpropagation.

    Args:
        nested_distributions (nested Distribution): action distributions.
        return_log_prob (bool): whether to compute and return the log
            probability of the sampled actions, in addition to the sampled
            actions. In some cases, it is useful to compute the log probability
            immediately after the actions are sampled, as some subsequent
            operations might make the cache mechanism (if turned on) invalid.
            Some example scenarios include 1) additional sampling operation
            applied on ``nested_distributions``, 2) some operations applied to
            the actions sampled from ``nested_distributions`` (e.g., cloning).
            This could cause numerical issues if we want to compute the
            log probability for actions sampled at an early stage,
            especially for actions that are close to action bounds.
            For more details on PyTorch Transform, its cache mechanism, and its
            impacts on RL algorithms, please check
            `<https://alf.readthedocs.io/en/latest/notes/pytorch_notes.html#transform-bijector>`_.
    Returns:
        - sampled actions if return_log_prob is False
        - sampled actions and log_prob if return_log_prob is True
    """
    actions = nest.map_structure(lambda d: d.sample(), nested_distributions)
    if not return_log_prob:
        return actions
    # Computed right away, while the transforms' caches still hold ``actions``.
    log_prob = compute_log_probability(nested_distributions, actions)
    return actions, log_prob
def epsilon_greedy_sample(nested_distributions, eps=0.1):
    """Generate greedy sample that maximizes the probability.

    Args:
        nested_distributions (nested Distribution): distribution to sample from
        eps (float): a floating value in :math:`[0,1]`, representing the chance of
            action sampling instead of taking argmax. This can help prevent
            a dead loop in some deterministic environment like `Breakout`.
    Returns:
        (nested) Tensor:
    """
    if eps >= 1.0:
        # Always sample: no greedy action is needed.
        return sample_action_distribution(nested_distributions)

    def _epsilon_greedy(dist):
        # pytorch distribution has no 'mode' operation
        greedy = get_mode(dist)
        if eps == 0.0:
            return greedy
        action = dist.sample()
        # Per batch entry (first dim): keep the greedy action w.p. 1 - eps.
        use_greedy = torch.rand(action.shape[0]) > eps
        action[use_greedy] = greedy[use_greedy]
        return action

    return nest.map_structure(_epsilon_greedy, nested_distributions)
def get_mode(dist):
    """Get the mode of the distribution.

    Note that for some distribution types the result is only an approximation
    of the actual mode (see Returns).

    Args:
        dist (td.Distribution): the distribution whose mode is computed.

    Returns:
        The mode of the distribution. If ``dist`` is a transformed distribution,
        the result is calculated by transforming the mode of its base
        distribution (the transforms are applied under ``torch.no_grad()``)
        and may not be the actual mode for ``dist``. For
        ``td.MixtureSameFamily``, the result is the mode of the component with
        the highest mixture probability, which is also an approximation.

    Raises:
        NotImplementedError: if ``dist`` (or a nested base/component
            distribution) is not one of ``td.Categorical``,
            ``OneHotCategoricalStraightThrough``,
            ``OneHotCategoricalGumbelSoftmax``, ``td.OneHotCategorical``,
            ``td.OneHotCategoricalStraightThrough``, ``td.Normal``,
            ``td.MixtureSameFamily``, ``StableCauchy``, ``td.Independent``,
            ``td.TransformedDistribution``, ``Beta`` or
            ``TruncatedDistribution``.
    """
    if isinstance(dist, td.categorical.Categorical):
        mode = torch.argmax(dist.logits, -1)
    elif isinstance(
            dist,
        (OneHotCategoricalStraightThrough, OneHotCategoricalGumbelSoftmax)):
        # Our version of one-hot st supports mode with grad
        mode = dist.mode
    elif isinstance(
            dist, (td.OneHotCategorical, td.OneHotCategoricalStraightThrough)):
        mode = torch.nn.functional.one_hot(
            torch.argmax(dist.logits, -1), num_classes=dist.logits.shape[-1])
    elif isinstance(dist, td.normal.Normal):
        mode = dist.mean
    elif isinstance(dist, td.MixtureSameFamily):
        # Note that this just computes an approximate mode: the mode of the
        # component distribution that has the highest component probability.
        # ind: [*batch_shape]
        ind = get_mode(dist.mixture_distribution)
        # component_mode: [*batch_shape, num_components, *event_shape]
        component_mode = get_mode(dist.component_distribution)
        # Pick the selected component with ``gather`` along the component dim,
        # which works for any number of batch dims (including none).
        comp_dim = ind.dim()
        index = ind.reshape(
            tuple(ind.shape) + (1, ) * (component_mode.dim() - comp_dim))
        index = index.expand(
            tuple(ind.shape) + (1, ) +
            tuple(component_mode.shape[comp_dim + 1:]))
        mode = component_mode.gather(comp_dim, index).squeeze(comp_dim)
    elif isinstance(dist, StableCauchy):
        mode = dist.loc
    elif isinstance(dist, td.Independent):
        mode = get_mode(dist.base_dist)
    elif isinstance(dist, td.TransformedDistribution):
        base_mode = get_mode(dist.base_dist)
        with torch.no_grad():
            mode = base_mode
            for transform in dist.transforms:
                mode = transform(mode)
    elif isinstance(dist, (Beta, TruncatedDistribution)):
        mode = dist.mode
    else:
        raise NotImplementedError(
            "Distribution type %s is not supported" % type(dist))
    return mode
def get_rmode(dist):
    """Get the mode of the distribution in a way that supports backpropagation.

    Note that if ``dist`` is a transformed distribution, the result may not be
    the actual mode of ``dist``.

    Args:
        dist (td.Distribution): the distribution whose mode is computed.

    Returns:
        The mode of the distribution. If ``dist`` is a transformed distribution,
        the result is calculated by transforming the mode of its base
        distribution and may not be the actual mode for ``dist``. For
        ``td.MixtureSameFamily``, the result is the mode of the component with
        the highest mixture probability. The component selection itself
        carries no gradient; the selected component's mode does.

    Raises:
        NotImplementedError: if ``dist`` or its base distribution is not
            ``td.Normal``, ``td.MixtureSameFamily``, ``StableCauchy``,
            ``Beta``, ``TruncatedDistribution``, ``td.Independent`` or
            ``td.TransformedDistribution``.
    """
    if isinstance(dist, td.normal.Normal):
        mode = dist.mean
    elif isinstance(dist, td.MixtureSameFamily):
        # note that for the mixture distribution, there is no gradient
        # back-propagation through the component selection.
        # ind: [*batch_shape]
        ind = get_mode(dist.mixture_distribution)
        # component_mode: [*batch_shape, num_components, *event_shape]
        component_mode = get_rmode(dist.component_distribution)
        # Pick the selected component with ``gather`` along the component dim,
        # which works for any number of batch dims (including none).
        comp_dim = ind.dim()
        index = ind.reshape(
            tuple(ind.shape) + (1, ) * (component_mode.dim() - comp_dim))
        index = index.expand(
            tuple(ind.shape) + (1, ) +
            tuple(component_mode.shape[comp_dim + 1:]))
        mode = component_mode.gather(comp_dim, index).squeeze(comp_dim)
    elif isinstance(dist, StableCauchy):
        mode = dist.loc
    elif isinstance(dist, (Beta, TruncatedDistribution)):
        mode = dist.mode
    elif isinstance(dist, td.Independent):
        mode = get_rmode(dist.base_dist)
    elif isinstance(dist, td.TransformedDistribution):
        mode = get_rmode(dist.base_dist)
        for transform in dist.transforms:
            mode = transform(mode)
    else:
        raise NotImplementedError(
            "Distribution type %s is not supported" % type(dist))
    return mode
def get_base_dist(dist):
    """Get the innermost (base) distribution of ``dist``.

    ``td.Independent`` and ``td.TransformedDistribution`` wrappers are peeled
    off through their ``base_dist`` attribute until a supported leaf
    distribution is reached.

    Args:
        dist (td.Distribution): the distribution to unwrap.

    Returns:
        ``dist`` itself if it is a ``td.Normal``, ``td.Categorical``,
        ``StableCauchy``, ``Beta`` or ``TruncatedDistribution``; otherwise the
        first such distribution reached by following ``base_dist``.

    Raises:
        NotImplementedError: if a distribution that is neither one of the leaf
            types above nor ``td.Independent`` / ``td.TransformedDistribution``
            is encountered.
    """
    leaf_types = (td.Normal, td.Categorical, StableCauchy, Beta,
                  TruncatedDistribution)
    wrapper_types = (td.Independent, td.TransformedDistribution)
    current = dist
    while not isinstance(current, leaf_types):
        if not isinstance(current, wrapper_types):
            raise NotImplementedError(
                "Distribution type %s is not supported" % type(current))
        current = current.base_dist
    return current
@alf.configurable
def estimated_entropy(dist, num_samples=1, check_numerics=False):
    r"""Estimate entropy by sampling.

    The unbiased estimator for entropy is :math:`-\log(p(x))` where :math:`x`
    is an unbiased sample of :math:`p`. However, the gradient of
    :math:`-\log(p(x))` is not an unbiased estimator of the gradient of
    entropy, so a second value whose gradient is an unbiased estimator of the
    entropy gradient is also returned. See
    ``notes/subtleties_of_estimating_entropy.py`` for detail.

    Args:
        dist (torch.distributions.Distribution): concerned distribution
        num_samples (int): number of random samples used for estimating entropy.
        check_numerics (bool): If true, find NaN / Inf values. For debugging only.

    Returns:
        tuple:
        - entropy
        - entropy_for_gradient: for calculating gradient.
    """
    sample_shape = (num_samples, )
    if dist.has_rsample:
        samples = dist.rsample(sample_shape=sample_shape)
    else:
        samples = dist.sample(sample_shape=sample_shape)

    # With reparameterized floating-point samples the gradient can flow
    # through -log p(x) directly. Otherwise the samples are detached and the
    # surrogate -0.5 * (log p(x))^2 provides the gradient instead.
    reparameterized = samples.dtype.is_floating_point and dist.has_rsample
    if not reparameterized:
        samples = samples.detach()

    neg_log_prob = -dist.log_prob(samples)
    if check_numerics:
        assert torch.all(torch.isfinite(neg_log_prob))

    entropy = neg_log_prob.mean(dim=0)
    if reparameterized:
        return entropy, entropy
    entropy_for_gradient = (-0.5 * neg_log_prob**2).mean(dim=0)
    return entropy, entropy_for_gradient
# NOTE(hnyu): It might be possible to get a closed-form entropy when the base
# dist is a Normal and only affine transformations are applied?
# That would be better (lower variance) than this estimated one.
#
# Something like what TFP does:
# https://github.com/tensorflow/probability/blob/356cfddef026b3339b8f2a81e600acd2ff8e22b4/tensorflow_probability/python/distributions/transformed_distribution.py#L636
# (Probably it's complicated, but we need to spend time figuring out whether the
# current estimation is the best way to do this).
# For now, we compute the entropy of transformed distributions using sampling.
def entropy_with_fallback(distributions, return_sum=True):
    r"""Computes total entropy of nested distribution.

    If ``entropy()`` of a distribution is not implemented, this function will
    fall back to sampling to calculate the entropy. It returns two values:
    ``(entropy, entropy_for_gradient)``.

    There are three situations:

    - ``AffineTransformedDistribution``: the entropy of its base distribution
      is computed (recursively) and shifted by the distribution's
      ``_log_abs_scale``.
    - ``td.TransformedDistribution`` / ``TruncatedDistribution``: sampling is
      used to calculate the entropy. The unbiased estimator for entropy is
      :math:`-\log(p(x))`. However, the gradient of :math:`-\log(p(x))` is not
      an unbiased estimator of the gradient of entropy. So we also calculate a
      value whose gradient is an unbiased estimator of the gradient of
      entropy. See ``estimated_entropy()`` for detail.
    - otherwise, ``entropy()`` is used and it's the same as
      ``entropy_for_gradient``.

    Examples:

    .. code-block:: python

        ent, ent_for_grad = entropy_with_fallback(dist)
        alf.summary.scalar("entropy", ent)
        ent_for_grad.backward()

    Args:
        distributions (nested Distribution): A possibly batched tuple of
            distributions.
        return_sum (bool): if True, return the total entropy. If not True,
            return the entropy for each distribution in the nest.

    Returns:
        tuple:
        - entropy
        - entropy_for_gradient: You should use ``entropy`` in situations where its
          value is needed, and ``entropy_for_gradient`` where you need to calculate the
          gradient of entropy.
    """

    def _compute_entropy(dist: td.Distribution):
        if isinstance(dist, AffineTransformedDistribution):
            # An affine map changes the entropy by the log absolute scale of
            # the map; ``_log_abs_scale`` is assumed to hold that quantity.
            entropy, entropy_for_gradient = _compute_entropy(dist.base_dist)
            entropy = entropy + dist._log_abs_scale
            entropy_for_gradient = entropy_for_gradient + dist._log_abs_scale
        elif isinstance(dist,
                        (td.TransformedDistribution, TruncatedDistribution)):
            # TransformedDistribution is used by NormalProjectionNetwork with
            # scale_distribution=True, in which case we estimate with sampling.
            entropy, entropy_for_gradient = estimated_entropy(dist)
        else:
            entropy = dist.entropy()
            entropy_for_gradient = entropy
        return entropy, entropy_for_gradient

    entropies = [_compute_entropy(d) for d in nest.flatten(distributions)]
    entropies, entropies_for_gradient = zip(*entropies)

    if return_sum:
        return sum(entropies), sum(entropies_for_gradient)
    return (nest.pack_sequence_as(distributions, entropies),
            nest.pack_sequence_as(distributions, entropies_for_gradient))
@alf.configurable
def calc_default_target_entropy(spec, min_prob=0.1):
    """Calculate default target entropy.

    Args:
        spec (TensorSpec): action spec
        min_prob (float): If continuous spec, we suppose the prob concentrates on
            a delta of ``min_prob * (M-m)``; if discrete spec, we uniformly
            distribute ``min_prob`` on all entries except the peak which has
            a probability of ``1 - min_prob``.

    Returns:
        target entropy
    """

    def _xlogx(x):
        # x * log(x) using its limit value 0 at x == 0, so that min_prob == 1
        # does not produce 0 * -inf = nan.
        return 0.0 if x == 0 else x * np.log(x)

    def _calc_discrete_entropy(m, M, log_mp):
        # N values in [m, M]: one peak with probability 1 - min_prob, and the
        # remaining N - 1 values share min_prob uniformly.
        N = M - m + 1
        if N == 1:
            return 0
        return min_prob * (np.log(N - 1) - log_mp) - _xlogx(1 - min_prob)

    zeros = np.zeros(spec.shape)
    # Broadcast scalar or per-dimension bounds to the full action shape.
    min_max = np.broadcast(spec.minimum, spec.maximum, zeros)
    cont = spec.is_continuous
    # The tiny epsilon keeps the log finite when min_prob == 0.
    log_mp = np.log(min_prob + 1e-30)
    e = np.sum([(np.log(M - m) + log_mp if cont else _calc_discrete_entropy(
        m, M, log_mp)) for m, M, _ in min_max])
    return e
@alf.configurable
def calc_default_target_entropy_quantized(spec,
                                          num_bins,
                                          ent_per_action_dim=-1.0):
    """Calc default target entropy for quantized continuous action.

    Args:
        spec (TensorSpec): action spec
        num_bins (int): number of quantization bins used to represent the
            continuous action
        ent_per_action_dim (float): desired entropy per action dimension
            for the non-quantized continuous action; default value is -1.0
            as suggested by the SAC paper.

    Returns:
        target entropy for quantized representation

    Raises:
        AssertionError: if ``spec`` is not continuous or the resulting target
            entropy is not positive.
    """
    cont = spec.is_continuous
    assert cont, "only support continuous action-based computation"
    # Broadcast the bounds to the full action shape so that scalar bounds and
    # per-dimension (array-valued) bounds both contribute exactly one term per
    # action element.
    minimum = np.broadcast_to(spec.minimum, spec.shape)
    maximum = np.broadcast_to(spec.maximum, spec.shape)
    log_Mn = np.log(maximum - minimum)
    log_mp = ent_per_action_dim - log_Mn
    log_B = np.log(num_bins)
    e = np.sum(log_mp + log_B)
    assert e > 0, "wrong target entropy for discrete distribution {}".format(e)
    return e
def calc_default_max_entropy(spec, fraction=0.8):
    """Calc default max entropy.

    The theoretical upper bound is the entropy of a uniform distribution over
    each action element's range.

    Args:
        spec (TensorSpec): action spec
        fraction (float): this fraction of the theoretical entropy upper bound
            will be used as the max entropy. For continuous elements whose
            range is not larger than 1 (non-positive bound), the bound is
            divided by ``fraction`` instead.

    Returns:
        A default max entropy for adjusting the entropy weight
    """
    assert 0 < fraction <= 1.0
    bounds = np.broadcast(spec.minimum, spec.maximum, np.zeros(spec.shape))
    continuous = spec.is_continuous
    terms = []
    for low, high, _ in bounds:
        width = high - low
        if continuous:
            scale = fraction if width > 1 else 1.0 / fraction
            terms.append(np.log(width) * scale)
        else:
            # a discrete element can take ``width + 1`` distinct values
            terms.append(np.log(width + 1) * fraction)
    return np.sum(terms)