# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Union, List, Callable
import alf
from alf.data_structures import LossInfo, namedtuple, StepType
from alf.utils.losses import element_wise_squared_loss
from alf.utils import losses, tensor_utils, value_ops
from alf.utils.summary_utils import safe_mean_hist_summary
from alf.utils.normalizers import AdaptiveNormalizer
@alf.configurable
class TDLoss(nn.Module):
    """Temporal difference loss."""

    def __init__(self,
                 gamma: Union[float, List[float]] = 0.99,
                 td_error_loss_fn: Callable = element_wise_squared_loss,
                 td_lambda: float = 0.95,
                 normalize_target: bool = False,
                 debug_summaries: bool = False,
                 name: str = "TDLoss",
                 default_return: float = -1000.):
        r"""
        Let :math:`G_{t:T}` be the bootstrapped return from t to T:

        .. math::

            G_{t:T} = \sum_{i=t+1}^T \gamma^{i-t-1}R_i + \gamma^{T-t} V(s_T)

        If ``td_lambda`` = 1, the target for step t is :math:`G_{t:T}`.

        If ``td_lambda`` = 0, the target for step t is :math:`G_{t:t+1}`

        If 0 < ``td_lambda`` < 1, the target for step t is the :math:`\lambda`-return:

        .. math::

            G_t^\lambda = (1 - \lambda) \sum_{i=t+1}^{T-1} \lambda^{i-t-1}G_{t:i} + \lambda^{T-t-1} G_{t:T}

        There is a simple relationship between :math:`\lambda`-return and
        the generalized advantage estimation :math:`\hat{A}^{GAE}_t`:

        .. math::

            G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)

        where the generalized advantage estimation is defined as:

        .. math::

            \hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))

        References:

        Schulman et al. `High-Dimensional Continuous Control Using Generalized Advantage Estimation
        <https://arxiv.org/abs/1506.02438>`_

        Sutton et al. `Reinforcement Learning: An Introduction
        <http://incompleteideas.net/book/the-book.html>`_, Chapter 12, 2018

        Args:
            gamma: A discount factor for future rewards. For
                multi-dim reward, this can also be a list of discounts, each
                discount applies to a reward dim.
            td_error_loss_fn: A function for computing the TD errors
                loss. This function takes as input the target and the estimated
                Q values and returns the loss for each element of the batch.
            td_lambda: Lambda parameter for TD-lambda computation.
            normalize_target (bool): whether to normalize target.
                Note that the effect of this is to change the loss. The critic
                value itself is not normalized.
            debug_summaries: True if debug summaries should be created.
            name: The name of this loss.
            default_return: placeholder value found in
                ``info.discounted_return`` for steps whose episode has not
                finished. Only entries strictly greater than this value are
                included in the episodic-return summaries.
                NOTE(review): the default is assumed to match the placeholder
                used by whatever fills ``discounted_return`` -- confirm.
        """
        super().__init__()
        self._name = name
        self._gamma = torch.tensor(gamma)
        self._td_error_loss_fn = td_error_loss_fn
        self._lambda = td_lambda
        self._debug_summaries = debug_summaries
        self._normalize_target = normalize_target
        # Created lazily in ``forward()`` because the value shape is only
        # known once the first batch is seen.
        self._target_normalizer = None
        self._default_return = default_return

    @property
    def gamma(self):
        r"""Return the :math:`\gamma` value for discounting future rewards.

        Returns:
            Tensor: a rank-0 or rank-1 (multi-dim reward) floating tensor.
        """
        # Return a copy so callers cannot modify the stored discount in place.
        return self._gamma.clone()

    def compute_td_target(self,
                          info: namedtuple,
                          target_value: torch.Tensor,
                          value: Union[torch.Tensor, None] = None):
        """Calculate the td target.

        The first dimension of all the tensors is time dimension and the second
        dimension is the batch dimension.

        Args:
            info (namedtuple): experience collected from ``unroll()`` or
                a replay buffer. All tensors are time-major. ``info`` should
                contain the following fields:
                - reward:
                - step_type:
                - discount:
                If it additionally has a tensor field ``discounted_return``,
                summaries about finished episodes are written.
            target_value (torch.Tensor): the time-major tensor for the value at
                each time step. This is used to calculate return. ``target_value``
                can be same as ``value``.
            value (torch.Tensor): optional time-major tensor of the values
                being trained. It is only used for the
                ``value_episode_ended_all`` summary; that summary is skipped
                when ``value`` is not provided.
        Returns:
            td_target
        """
        discounts = info.discount
        if self._gamma.ndim > 0 and info.reward.ndim > discounts.ndim:
            # Per-reward-dim discounts: add a trailing axis so that the
            # rank-1 gamma lines up with the reward dimension rather than the
            # batch dimension.
            # NOTE(review): assumes value_ops accepts discounts that carry
            # the reward dimension -- confirm.
            discounts = discounts.unsqueeze(-1)
        discounts = discounts * self._gamma

        if self._lambda == 1.0:
            returns = value_ops.discounted_return(
                rewards=info.reward,
                values=target_value,
                step_types=info.step_type,
                discounts=discounts)
        elif self._lambda == 0.0:
            returns = value_ops.one_step_discounted_return(
                rewards=info.reward,
                values=target_value,
                step_types=info.step_type,
                discounts=discounts)
        else:
            advantages = value_ops.generalized_advantage_estimation(
                rewards=info.reward,
                values=target_value,
                step_types=info.step_type,
                discounts=discounts,
                td_lambda=self._lambda)
            # lambda-return = GAE advantage + V(s_t)
            returns = advantages + target_value[:-1]

        disc_ret = getattr(info, "discounted_return", ())
        # An absent/unfilled field is ``()``; only a real tensor is summarized.
        if (isinstance(disc_ret, torch.Tensor)
                and alf.summary.should_record_summaries()):
            with alf.summary.scope(self._name):
                episode_ended = disc_ret > self._default_return
                # Skip empty selections; the mean of nothing would be NaN.
                if episode_ended.any():
                    alf.summary.scalar("episodic_discounted_return_all",
                                       torch.mean(disc_ret[episode_ended]))
                ended_at_first_step = episode_ended[0, :]
                if value is not None and ended_at_first_step.any():
                    alf.summary.scalar(
                        "value_episode_ended_all",
                        torch.mean(value[:-1][:, ended_at_first_step]))
        return returns

    def forward(self, info: namedtuple, value: torch.Tensor,
                target_value: torch.Tensor):
        """Calculate the loss.

        The first dimension of all the tensors is time dimension and the second
        dimension is the batch dimension.

        Args:
            info: experience collected from ``unroll()`` or
                a replay buffer. All tensors are time-major. ``info`` should
                contain the following fields:
                - reward:
                - step_type:
                - discount:
            value: the time-major tensor for the value at each time
                step. The loss is between this and the calculated return.
            target_value: the time-major tensor for the value at
                each time step. This is used to calculate return. ``target_value``
                can be same as ``value``.
        Returns:
            LossInfo: with the ``extra`` field same as ``loss``.
        """
        returns = self.compute_td_target(info, target_value, value=value)
        # The last step has no target, so it is dropped from the prediction.
        value = value[:-1]

        if self._normalize_target:
            if self._target_normalizer is None:
                self._target_normalizer = AdaptiveNormalizer(
                    alf.TensorSpec(value.shape[2:]),
                    auto_update=False,
                    debug_summaries=self._debug_summaries,
                    name=self._name + ".target_normalizer")

            # Statistics come from the targets only; the same transform is
            # then applied to both the targets and the predictions.
            self._target_normalizer.update(returns)
            returns = self._target_normalizer.normalize(returns)
            value = self._target_normalizer.normalize(value)

        if self._debug_summaries and alf.summary.should_record_summaries():
            mask = info.step_type[:-1] != StepType.LAST
            with alf.summary.scope(self._name):

                def _summarize(v, r, td, suffix):
                    alf.summary.scalar(
                        "explained_variance_of_return_by_value" + suffix,
                        tensor_utils.explained_variance(v, r, mask))
                    safe_mean_hist_summary('values' + suffix, v, mask)
                    safe_mean_hist_summary('returns' + suffix, r, mask)
                    safe_mean_hist_summary("td_error" + suffix, td, mask)

                if value.ndim == 2:
                    _summarize(value, returns, returns - value, '')
                else:
                    # One set of summaries per reward dimension.
                    td = returns - value
                    for i in range(value.shape[2]):
                        suffix = '/' + str(i)
                        _summarize(value[..., i], returns[..., i], td[..., i],
                                   suffix)

        loss = self._td_error_loss_fn(returns.detach(), value)

        if loss.ndim == 3:
            # Multidimensional reward. Average over the critic loss for all dimensions
            loss = loss.mean(dim=2)

        # The shape of the loss expected by Algorithm.update_with_gradient is
        # [T, B], so we need to augment it with additional zeros.
        loss = tensor_utils.tensor_extend_zero(loss)
        return LossInfo(loss=loss, extra=loss)
@alf.configurable
class TDQRLoss(TDLoss):
    """Temporal difference quantile regression loss.

    Compared to TDLoss, GAE support has not been implemented.
    """

    def __init__(self,
                 num_quantiles: int = 50,
                 gamma: Union[float, List[float]] = 0.99,
                 td_error_loss_fn: Callable = losses.huber_function,
                 td_lambda: float = 1.0,
                 sum_over_quantiles: bool = False,
                 debug_summaries: bool = False,
                 name: str = "TDQRLoss"):
        """
        Args:
            num_quantiles: the number of quantiles.
            gamma: A discount factor for future rewards. For
                multi-dim reward, this can also be a list of discounts, each
                discount applies to a reward dim.
            td_error_loss_fn: A function for computing the TD errors
                loss. Here it is called with a single argument, the pairwise
                difference between target quantiles and predicted quantiles,
                and its result is weighted element-wise.
            td_lambda: Lambda parameter for TD-lambda computation. Currently
                only supports 1 and 0.
            sum_over_quantiles: If True, the quantile regression loss will
                be summed along the quantile dimension. Otherwise, it will be
                averaged along the quantile dimension instead. Default is False.
            debug_summaries: True if debug summaries should be created
            name: The name of this loss.
        """
        assert td_lambda in (0, 1), (
            "Currently GAE is not supported, so td_lambda has to be 0 or 1.")
        super().__init__(
            gamma=gamma,
            td_error_loss_fn=td_error_loss_fn,
            td_lambda=td_lambda,
            debug_summaries=debug_summaries,
            name=name)

        self._num_quantiles = num_quantiles
        # Midpoints of ``num_quantiles`` equal-width bins on [0, 1]:
        # (i + 0.5) / num_quantiles.
        self._cdf_midpoints = (torch.arange(
            num_quantiles, dtype=torch.float32) + 0.5) / num_quantiles
        self._sum_over_quantiles = sum_over_quantiles

    def forward(self, info: namedtuple, value: torch.Tensor,
                target_value: torch.Tensor):
        """Calculate the loss.

        The first dimension of all the tensors is time dimension and the second
        dimension is the batch dimension. The last dimension of ``value`` and
        ``target_value`` is the quantile dimension.

        Args:
            info: experience collected from ``unroll()`` or
                a replay buffer. All tensors are time-major. ``info`` should
                contain the following fields:
                - reward:
                - step_type:
                - discount:
            value: the time-major tensor for the value at each time
                step. The loss is between this and the calculated return.
            target_value: the time-major tensor for the value at
                each time step. This is used to calculate return. ``target_value``
                can be same as ``value``.
        Returns:
            LossInfo: with the ``extra`` field same as ``loss``.
        """
        assert value.shape[-1] == self._num_quantiles, (
            "The input value should have same num_quantiles as pre-defiend.")
        assert target_value.shape[-1] == self._num_quantiles, (
            "The input target_value should have same num_quantiles as pre-defiend."
        )
        returns = self.compute_td_target(info, target_value)
        # The last step has no target, so it is dropped from the prediction.
        value = value[:-1]

        # for quantile regression TD, the value and target both have shape
        # (T-1, B, n_quantiles) for scalar reward and
        # (T-1, B, reward_dim, n_quantiles) for multi-dim reward.
        # The quantile TD has shape
        # (T-1, B, n_quantiles, n_quantiles) for scalar reward and
        # (T-1, B, reward_dim, n_quantiles, n_quantiles) for multi-dim reward
        # In ``diff`` the second-to-last axis indexes target quantiles and the
        # last axis indexes predicted quantiles.
        quantiles = value.unsqueeze(-2)
        quantiles_target = returns.detach().unsqueeze(-1)
        diff = quantiles_target - quantiles

        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):

                def _summarize(d, suffix):
                    # For every predicted quantile: the fraction of target
                    # quantiles that do not exceed it, then averaged over the
                    # two leading (time and batch) axes.
                    cdf = (d <= 0).float().mean(-2)
                    mean_cdf = cdf.mean(0).mean(0)
                    alf.summary.histogram(
                        "explained_cdf_of_return_by_value_quantile" + suffix,
                        mean_cdf)

                if value.ndim == 3:
                    _summarize(diff, '')
                else:
                    # One histogram per reward dimension.
                    for i in range(value.shape[-2]):
                        _summarize(diff[..., i, :, :], '/' + str(i))

        huber_loss = self._td_error_loss_fn(diff)
        # Asymmetric weight |tau - 1{diff < 0}|; the indicator is detached so
        # that gradients only flow through ``huber_loss``.
        loss = torch.abs(
            (self._cdf_midpoints - (diff.detach() < 0).float())) * huber_loss

        if self._sum_over_quantiles:
            # Mean over target quantiles, sum over predicted quantiles.
            loss = loss.mean(-2).sum(-1)
        else:
            loss = loss.mean(dim=(-2, -1))

        if loss.ndim == 3:
            # Multidimensional reward. Average over the critic loss for all dimensions
            loss = loss.mean(dim=2)

        # The shape of the loss expected by Algorithm.update_with_gradient is
        # [T, B], so we need to augment it with additional zeros.
        loss = tensor_utils.tensor_extend_zero(loss)
        return LossInfo(loss=loss, extra=loss)