# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import Tensor
import torch.distributions as td
class NormalITS(InverseTransformSampling):
    """Standard normal distribution (zero mean, unit variance).

    .. math::

        p(x) = 1/sqrt(2*pi) * exp(-x^2/2)
    """

    @staticmethod
    def cdf(x):
        """Evaluate the cumulative distribution function at ``x``.

        Args:
            x: tensor of points at which to evaluate the cdf.
        Returns:
            Tensor with the same shape as ``x``.
        """
        # sqrt(0.5) = 0.7071067811865476
        sqrt_half = 0.7071067811865476
        erf_term = torch.erf(x * sqrt_half)
        return 0.5 * (1 + erf_term)

    @staticmethod
    def icdf(x):
        """Evaluate the inverse cdf (quantile function) at ``x``.

        Args:
            x: tensor of probabilities.
        Returns:
            Tensor with the same shape as ``x``.
        """
        # sqrt(2) = 1.4142135623730951
        sqrt_two = 1.4142135623730951
        inv_erf = torch.erfinv(2 * x - 1)
        return inv_erf * sqrt_two

    @staticmethod
    def log_prob(x):
        """Evaluate the log of the probability density at ``x``.

        Args:
            x: points at which to evaluate the log density.
        Returns:
            The log density, with the same shape as ``x``.
        """
        # log(sqrt(2 * pi)) = 0.9189385332046727
        log_sqrt_two_pi = 0.9189385332046727
        return -0.5 * (x**2) - log_sqrt_two_pi
class CauchyITS(InverseTransformSampling):
    """Standard Cauchy distribution (zero location, unit scale).

    .. math::

        p(x) = 1 / (pi * (1 + x*x))
    """

    @staticmethod
    def cdf(x):
        """Evaluate the cumulative distribution function at ``x``.

        Args:
            x: tensor of points at which to evaluate the cdf.
        Returns:
            Tensor with the same shape as ``x``.
        """
        angle = torch.atan(x)
        return angle / math.pi + 0.5

    @staticmethod
    def icdf(x):
        """Evaluate the inverse cdf (quantile function) at ``x``.

        Args:
            x: tensor of probabilities.
        Returns:
            Tensor with the same shape as ``x``.
        """
        centered = x - 0.5
        return torch.tan(math.pi * centered)

    @staticmethod
    def log_prob(x):
        """Evaluate the log of the probability density at ``x``.

        Args:
            x: tensor of points at which to evaluate the log density.
        Returns:
            Tensor with the same shape as ``x``.
        """
        # The density is 1 / (pi * (1 + x^2)); its log is the negated log of
        # the denominator.
        denominator = math.pi * (1 + x**2)
        return -denominator.log()
class T2Cdf_(torch.autograd.Function):
    """Autograd function computing ``0.5 + 0.5 * x / sqrt(1 + x^2)``.

    The value is evaluated in a rearranged form and the gradient is supplied
    by hand; see the comments in ``forward`` for the reason.
    """

    @staticmethod
    def forward(ctx, x):
        """Compute the function value and stash ``x`` for the backward pass.

        Args:
            ctx: autograd context used to save tensors for ``backward``.
            x: input tensor.
        Returns:
            Tensor with the same shape as ``x``.
        """
        # The direct form 0.5 + 0.5 * x / (1 + x**2).sqrt() is not numerically
        # increasing (i.e. for some large values x2 > x1, f(x2) < f(x1)
        # because of insufficient numerical precision), so the equivalent
        # form below is used instead. A hand-written backward is needed
        # because autograd cannot calculate the gradient of this form when x
        # is 0.
        half_sign = 0.5 * x.sign()
        inv_norm = (1 + x.pow(-2)).rsqrt()
        ctx.save_for_backward(x)
        return 0.5 + half_sign * inv_norm

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` through the analytic derivative.

        Args:
            ctx: autograd context holding the tensor saved in ``forward``.
            grad_output: gradient with respect to the output of ``forward``.
        Returns:
            Gradient with respect to ``x``.
        """
        (saved_x, ) = ctx.saved_tensors
        # d/dx [0.5 + 0.5 * x / sqrt(1 + x^2)] = 0.5 * (1 + x^2)^(-1.5)
        derivative = 0.5 * (1 + saved_x.pow(2)).pow(-1.5)
        return derivative * grad_output
# Function-style alias for applying ``T2Cdf_`` to a tensor.
t2cdf = T2Cdf_.apply
class T2ITS(InverseTransformSampling):
    """Student's t-distribution with 2 degrees of freedom.

    .. math::

        p(x) = 1 / (2 * (1 + x*x) ** 1.5)
    """

    @staticmethod
    def cdf(x):
        """Evaluate the cumulative distribution function at ``x``.

        Args:
            x: tensor of points at which to evaluate the cdf.
        Returns:
            Tensor with the same shape as ``x``.
        """
        return t2cdf(x)

    @staticmethod
    def icdf(x):
        """Evaluate the inverse cdf (quantile function) at ``x``.

        Args:
            x: tensor of probabilities.
        Returns:
            Tensor with the same shape as ``x``.
        """
        # Inverse of 0.5 + 0.5 * t / sqrt(1 + t^2):
        # t = (x - 0.5) / sqrt(x * (1 - x))
        inv_spread = (x * (1 - x)).rsqrt()
        return (x - 0.5) * inv_spread

    @staticmethod
    def log_prob(x):
        """Evaluate the log of the probability density at ``x``.

        Args:
            x: tensor of points at which to evaluate the log density.
        Returns:
            Tensor with the same shape as ``x``.
        """
        # log(2) = 0.6931471805599453
        log_two = 0.6931471805599453
        return -1.5 * (1 + x**2).log() - log_two
class TruncatedDistribution(td.Distribution):
    r"""The base class of truncated distributions.

    A truncated distribution :math:`q(x)` is defined by a standard base
    distribution :math:`p(x)`, a location parameter :math:`\mu`, a scale
    parameter :math:`s`, a lower bound :math:`l` and an upper bound :math:`u`:

    .. math::

        q(x) = \frac{1}{s (P(u) - P(l))} p(\frac{x-\mu}{s})
            \quad \text{if } l \le x \le u

        q(x) = 0 \quad \text{otherwise}

    where :math:`P` is the cdf of :math:`p`.

    Args:
        loc: the location parameter. Its shape is batch_shape + event_shape.
        scale: the scale parameter. Its shape is batch_shape + event_shape.
        lower_bound: the lower bound. Its shape is event_shape.
        upper_bound: the upper bound. Its shape is event_shape.
        its: the standard distribution to be used. It must provide ``cdf``,
            ``icdf`` and ``log_prob``.
    """

    arg_constraints = {
        'loc': td.constraints.real,
        'scale': td.constraints.positive
    }

    has_rsample = True

    # ``its`` is annotated with a string because its type is declared outside
    # of this class; the annotation is then not evaluated at definition time.
    def __init__(self, loc: Tensor, scale: Tensor, lower_bound: Tensor,
                 upper_bound: Tensor, its: "InverseTransformSampling"):
        event_shape = torch.broadcast_shapes(lower_bound.shape,
                                             upper_bound.shape)
        batch_shape = torch.broadcast_shapes(scale.shape, loc.shape,
                                             event_shape)
        if len(event_shape) > 0:
            # Strip the trailing event dims; what is left is the batch shape.
            batch_shape = batch_shape[:-len(event_shape)]
        # NOTE(review): these are assigned before calling the base
        # constructor, presumably so that its argument validation can read
        # the ``loc`` and ``scale`` properties -- keep this order.
        self._scale = scale
        self._loc = loc
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)
        self._its = its
        self._lower_bound = lower_bound.to(loc.device)
        self._upper_bound = upper_bound.to(loc.device)
        # cdf of the base distribution at the standardized bounds
        self._cdf_lb = self._its.cdf((self._lower_bound - loc) / scale)
        self._cdf_ub = self._its.cdf((self._upper_bound - loc) / scale)
        # Log of the normalizer s * (P(u) - P(l)); the 1e-30 pad keeps the
        # log finite when the two cdf values coincide.
        self._logz = (scale * (self._cdf_ub - self._cdf_lb + 1e-30)).log()

    @property
    def scale(self):
        """Scale parameter of this distribution."""
        return self._scale

    @property
    def loc(self):
        """Location parameter of this distribution."""
        return self._loc

    @property
    def lower_bound(self):
        """Lower bound of this distribution."""
        return self._lower_bound

    @property
    def upper_bound(self):
        """Upper bound of this distribution."""
        return self._upper_bound

    @property
    def mode(self):
        """Mode of this distribution: ``loc`` clipped to the bounds."""
        result = torch.maximum(self._lower_bound, self._loc)
        result = torch.minimum(self._upper_bound, result)
        return result

    def rsample(self, sample_shape: torch.Size = torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.

        Args:
            sample_shape: sample shape. Any sequence of ints accepted by
                ``torch.Size`` (``torch.Size``, tuple or list).
        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape``
        """
        # Normalize ``sample_shape`` first so that lists are accepted too
        # (``list + torch.Size`` raises TypeError).
        shape = (torch.Size(sample_shape) + self._batch_shape +
                 self._event_shape)
        # Draw the noise on the same device as the parameters to avoid a
        # device mismatch in the arithmetic below.
        r = torch.rand(shape, device=self._loc.device)
        # Map the uniform noise into [P(l), P(u)]
        r = (1 - r) * self._cdf_lb + r * self._cdf_ub
        # Keep the probabilities away from 0 and 1 before applying icdf.
        r = r.clamp(0.001, 0.999)
        x = self._its.icdf(r) * self._scale + self._loc
        assert torch.isfinite(x).all()
        # because of the r.clamp() above, x may be out of bound
        x = torch.maximum(x, self._lower_bound)
        x = torch.minimum(x, self._upper_bound)
        return x

    def log_prob(self, value: Tensor):
        """The log of the probability density evaluated at ``value``.

        The density formula is applied to ``value`` as is; no check is made
        that ``value`` lies within the bounds.

        Args:
            value: its shape should be ``sample_shape + batch_shape + event_shape``
        Returns:
            Tensor of shape ``sample_shape + batch_shape``
        """
        y = self._its.log_prob((value - self._loc) / self._scale) - self._logz
        assert torch.isfinite(y).all()
        n = len(self._event_shape)
        if n > 0:
            # Sum the per-element log densities over the event dims.
            return y.sum(dim=list(range(-n, 0)))
        else:
            return y
class TruncatedNormal(TruncatedDistribution):
    r"""Truncated normal distribution.

    The truncated normal distribution :math:`q(x)` is defined by 4 parameters:
    location :math:`\mu`, scale :math:`s`, lower bound :math:`l` and
    upper bound :math:`u`.

    .. math::

        q(x) = \frac{1}{s (P(u) - P(l))}p(\frac{x-\mu}{s})

    where :math:`p` and :math:`P` are the pdf and cdf of the standard normal
    distribution respectively.

    Args:
        loc: the location parameter
        scale: the scale parameter
        lower_bound: the lower bound
        upper_bound: the upper bound
    """

    def __init__(self, loc, scale, lower_bound, upper_bound):
        # The base distribution is fixed to the standard normal.
        base = NormalITS()
        super().__init__(loc, scale, lower_bound, upper_bound, base)
@td.kl.register_kl(TruncatedNormal, TruncatedNormal)
def _kl_truncated_normal_trucated_normal(p, q):
    """Registered KL Divergence computation specialized for TruncatedNormal

    It is closed form w.r.t. torch.erf.

    Args:
        p: the first ``TruncatedNormal`` distribution.
        q: the second ``TruncatedNormal`` distribution. Its bounds must be
            (numerically) the same as those of ``p``.
    Returns:
        Tensor with the event dimensions summed out. When the event shape is
        empty nothing is summed and the per-element values are returned.
    """
    assert torch.all(
        torch.logical_and(
            torch.isclose(p.lower_bound, q.lower_bound),
            torch.isclose(p.upper_bound, q.upper_bound)))

    delta = p.loc - q.loc
    delta2 = delta**2
    sigma_p2 = p.scale**2
    # Pad sigma_q2 as it is positive will only be served as denominator
    sigma_q2 = q.scale**2 + 1e-30

    c1 = 0.5 * (torch.log(q.scale) - torch.log(p.scale)) + 0.25 * (
        delta2 + sigma_p2) / sigma_q2 - 0.25
    # 1 / sqrt(2 pi) = 0.3989422804014327
    c2 = -0.3989422804014327 * p.scale * delta / sigma_q2
    # 0.5 / sqrt(pi) = 0.28209479177387814
    c3 = (1.0 - sigma_p2 / sigma_q2) * 0.28209479177387814

    # sqrt(0.5) = 0.7071067811865475
    t_u = (p.upper_bound - p.loc) * 0.7071067811865475 / (p.scale + 1e-30)
    t_l = (p.lower_bound - p.loc) * 0.7071067811865475 / (p.scale + 1e-30)

    # The same expression evaluated at the upper and the lower bound
    upper = c1 * torch.erf(t_u) + (c3 * t_u + c2) * torch.exp(-t_u * t_u)
    lower = c1 * torch.erf(t_l) + (c3 * t_l + c2) * torch.exp(-t_l * t_l)

    # At this moment, before_normalization holds the integral part of
    # original gaussian (but both p and q are not normalized by area_p
    # and area_q respectively). This will be handled below.
    before_normalization = upper - lower

    # Pad area_p as it is positive will only be served as denominator
    area_p = p._cdf_ub - p._cdf_lb + 1e-30
    area_q = q._cdf_ub - q._cdf_lb

    kl = torch.log(area_q / area_p) + before_normalization / area_p

    # Only reduce when there are event dims: with an empty event shape the
    # dim list would be empty, and PyTorch's sum() treats an empty dim list
    # as a reduction over every dimension, which would collapse the batch.
    n = len(p._event_shape)
    if n > 0:
        return kl.sum(dim=list(range(-n, 0)))
    return kl
class TruncatedCauchy(TruncatedDistribution):
    r"""Truncated Cauchy distribution.

    The truncated Cauchy distribution :math:`q(x)` is defined by 4 parameters:
    location :math:`\mu`, scale :math:`s`, lower bound :math:`l` and
    upper bound :math:`u`.

    .. math::

        q(x) = \frac{1}{s (P(u) - P(l))}p(\frac{x-\mu}{s})

    where :math:`p` and :math:`P` are the pdf and cdf of the standard Cauchy
    distribution respectively.

    Args:
        loc: the location parameter
        scale: the scale parameter
        lower_bound: the lower bound
        upper_bound: the upper bound
    """

    def __init__(self, loc, scale, lower_bound, upper_bound):
        # The base distribution is fixed to the standard Cauchy.
        base = CauchyITS()
        super().__init__(loc, scale, lower_bound, upper_bound, base)
class TruncatedT2(TruncatedDistribution):
    r"""Truncated Student's t distribution with degree of freedom 2.

    The truncated t distribution :math:`q(x)` is defined by 4 parameters:
    location :math:`\mu`, scale :math:`s`, lower bound :math:`l` and
    upper bound :math:`u`.

    .. math::

        q(x) = \frac{1}{s (P(u) - P(l))}p(\frac{x-\mu}{s})

    where :math:`p(x)=1 / (2 * (1 + x^2)^{1.5})` and :math:`P` is the cdf of
    :math:`p(x)`.

    Args:
        loc: the location parameter
        scale: the scale parameter
        lower_bound: the lower bound
        upper_bound: the upper bound
    """

    def __init__(self, loc, scale, lower_bound, upper_bound):
        # The base distribution is fixed to the standard t with 2 DOF.
        base = T2ITS()
        super().__init__(loc, scale, lower_bound, upper_bound, base)