# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import scipy.special
import torch
import torch.nn as nn
import torch.distributions as td
import alf
from alf.nest.utils import convert_device
def _gammaincinv(a, y):
    """Inverse of the regularized lower incomplete gamma function.

    PyTorch has no native ``gammaincinv``, so the computation is delegated
    to ``scipy.special.gammaincinv`` on CPU numpy arrays.

    Args:
        a: tensor holding the shape parameter of the incomplete gamma function.
        y: tensor holding the values at which the function is inverted.
    Returns:
        The result of ``convert_device`` applied to a CPU tensor holding
        ``scipy.special.gammaincinv(a, y)``.
        NOTE(review): ``convert_device`` presumably moves the tensor to
        ALF's default device -- confirm against ``alf.nest.utils``.
    """
    # ``Tensor.numpy()`` refuses tensors that require grad. The scipy call is
    # not differentiable anyway, so detach before converting.
    a_np = a.detach().cpu().numpy()
    y_np = y.detach().cpu().numpy()
    result = scipy.special.gammaincinv(a_np, y_np)
    return convert_device(torch.as_tensor(result, device='cpu'))
class _CategoricalSeedSamplerBase(alf.nn.Network):
    """Turn categorical probabilities into temporally consistent noisy ones.

    This class is kept separate from ``CategoricalSeedSampler`` so that the
    perturbed probabilities can be unit-tested more easily.

    Args:
        num_classes: number of classes of the categorical distribution. Both
            the input spec and the state spec have shape ``(num_classes,)``.
        new_noise_prob: the probability of drawing a fresh noise vector for a
            sample at each call.
        concentration: factor multiplied with the input probabilities to get
            the Dirichlet concentration parameters.
    """

    def __init__(self, num_classes, new_noise_prob=0.01, concentration=1):
        spec = alf.TensorSpec((num_classes, ))
        super().__init__(input_tensor_spec=spec, state_spec=spec)
        self._new_noise_prob = new_noise_prob
        self._concentration = concentration

    def forward(self, input, state):
        """
        Args:
            input: categorical probabilities
            state: the noise kept from the previous call. An all-zero row is
                treated as "not initialized yet".
        Returns:
            tuple:
            - the perturbed probabilities, normalized over the last dim
            - the noise to be carried over as the next state
        """
        # Candidate noise for every sample, distributed as Uniform(0,1).
        fresh_noise = torch.rand_like(input)
        # Each sample independently decides whether to renew its noise.
        resample = torch.rand(input.shape[0]) < self._new_noise_prob
        # The initial state is always 0, so such samples must get new noise.
        resample = (resample | (state == 0).all(dim=1)).unsqueeze(-1)
        noise = torch.where(resample, fresh_noise, state)

        # Inverse transform sampling: pushing Uniform(0,1) noise through the
        # inverse CDF gives samples distributed as Gamma(alpha, 1).
        alpha = (self._concentration * input).clamp(min=1e-30)
        gamma = _gammaincinv(alpha, noise).clamp(min=1e-30)

        # Normalized Gamma samples follow Dirichlet(alpha), see
        # https://en.wikipedia.org/wiki/Dirichlet_distribution#Related_distributions
        return gamma / gamma.sum(dim=-1, keepdim=True), noise
@alf.repr_wrapper
@alf.configurable
class CategoricalSeedSampler(_CategoricalSeedSamplerBase):
    r"""Sample actions with temporal consistency.

    In order to do so, we maintain an internal stateful noise vector :math:`\epsilon`
    and use it to modify the original categorical distribution :math:`\pi` to a new
    distribution :math:`\tilde{\pi}=f(\pi, \epsilon)`. The evolution of :math:`\epsilon`
    and :math:`f` are chosen so that :math:`E(\tilde{\pi})=\pi`. More specifically,
    :math:`f` is chosen so that :math:`\tilde{\pi}` follows Dirichlet distribution
    :math:`Dir(c \pi)`.

    Args:
        num_classes: number of classes for the categorical distribution
        new_noise_prob: the probability of generating a new :math:`\epsilon`
        concentration: the concentration scaling factor c. Larger ``concentration``
            tends to generate :math:`\tilde{\pi}` closer to :math:`\pi`.
    """

    def __init__(self,
                 num_classes: int,
                 new_noise_prob: float = 0.01,
                 concentration: float = 1):
        super().__init__(num_classes, new_noise_prob, concentration)

    def forward(self, input: torch.Tensor, state: torch.Tensor):
        r"""
        Args:
            input: the parameter of the categorical distribution with the shape of
                ``[batch_size, num_classes]``
            state: noise state (i.e. :math:`\epsilon`)
        Returns:
            tuple:
            - action ids sampled from the perturbed distribution, with the
              shape of ``[batch_size]``
            - the updated noise state
        """
        # The base class produces the perturbed probabilities and the new
        # noise state; one action per sample is then drawn from them.
        prob, state = super().forward(input, state)
        action_id = torch.multinomial(prob, num_samples=1).squeeze(1)
        return action_id, state
@alf.repr_wrapper
class EpsilonGreedySampler(nn.Module):
    """Epsilon greedy sampler.

    With probability ``1 - epsilon_greedy``, sample actions with the largest
    probability. With probability ``epsilon_greedy``, sample actions according
    to the given categorical distribution.

    Args:
        epsilon_greedy: see above.
    """

    def __init__(self, epsilon_greedy=0.1):
        super().__init__()
        self._epsilon_greedy = epsilon_greedy

    def forward(self, input):
        """
        Args:
            input: categorical probabilities with the shape of ``[batch_size, num_classes]``
        Returns:
            action ids with the shape of ``[batch_size]``
        """
        # Always draw the stochastic action first so that the consumption of
        # random numbers does not depend on ``epsilon_greedy``.
        action_id = torch.multinomial(input, num_samples=1).squeeze(1)
        if self._epsilon_greedy < 1:
            greedy_action_id = input.argmax(dim=1)
            if self._epsilon_greedy > 0:
                # Create the mask on the same device as the actions so that
                # it can be combined with them even when ``input`` does not
                # live on the default device.
                use_greedy = torch.rand(
                    action_id.shape,
                    device=action_id.device) >= self._epsilon_greedy
                action_id = torch.where(use_greedy, greedy_action_id,
                                        action_id)
            else:
                # epsilon_greedy <= 0: purely greedy.
                action_id = greedy_action_id
        return action_id
@alf.repr_wrapper
class MultinomialSampler(nn.Module):
    """Sample actions according to the given multinomial distribution.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """
        Args:
            input: categorical probabilities with the shape of ``[batch_size, num_classes]``
        Returns:
            action ids with the shape of ``[batch_size]``, one draw per row
            of ``input``
        """
        # ``multinomial`` returns ``[batch_size, 1]``; drop the sample dim.
        action_id = torch.multinomial(input, num_samples=1).squeeze(1)
        return action_id