# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import torch
import numpy as np
import alf
from alf.data_structures import LossInfo
from alf.utils.losses import element_wise_squared_loss
from alf.utils.summary_utils import safe_mean_hist_summary
from alf.utils import tensor_utils, dist_utils, value_ops
from .algorithm import Loss
# Per-component breakdown that ``ActorCriticLoss.forward`` stores in
# ``LossInfo.extra``: policy-gradient loss, TD loss and negated entropy.
ActorCriticLossInfo = namedtuple(
    "ActorCriticLossInfo", ("pg_loss", "td_loss", "neg_entropy"))
def _normalize_advantages(advantages, variance_epsilon=1e-8):
    """Standardize advantages to zero mean and unit variance.

    The first two dimensions of ``advantages`` are merged into a single one
    and everything that remains becomes columns. The mean and the biased
    variance are then computed per column over the merged dimension, so for a
    ``[T, B, N]`` input each of the ``N`` reward dimensions is normalized
    independently.

    Args:
        advantages (Tensor): advantages of shape ``[T, B]`` or ``[T, B, N]``,
            where ``N`` is the reward dimension.
        variance_epsilon (float): small constant added to the standard
            deviation before dividing, to avoid division by zero.

    Returns:
        Tensor: the normalized advantages, with the same shape as the input.
    """
    original_shape = advantages.shape
    merged_len = np.prod(original_shape[:2])
    # [T*B, 1] for a 2-D input, [T*B, N] for a 3-D input
    flat = advantages.reshape(merged_len, -1)
    centered = flat - flat.mean(0)
    std = torch.sqrt(torch.var(flat, dim=0, unbiased=False))
    scaled = centered / (std + variance_epsilon)
    return scaled.reshape(*original_shape)
@alf.configurable
class ActorCriticLoss(Loss):
    """Actor-critic loss combining policy-gradient, TD and entropy terms."""

    def __init__(self,
                 gamma=0.99,
                 td_error_loss_fn=element_wise_squared_loss,
                 use_gae=False,
                 td_lambda=0.95,
                 use_td_lambda_return=True,
                 normalize_advantages=False,
                 advantage_clip=None,
                 entropy_regularization=None,
                 td_loss_weight=1.0,
                 debug_summaries=False,
                 name="ActorCriticLoss"):
        """An actor-critic loss equals to

        .. code-block:: python

            (policy_gradient_loss
             + td_loss_weight * td_loss
             - entropy_regularization * entropy)

        Args:
            gamma (float|list[float]): A discount factor for future rewards. For
                multi-dim reward, this can also be a list of discounts, each
                discount applies to a reward dim.
            td_error_loss_fn (Callable): A function for computing the TD errors
                loss. It is called with the (detached) target returns and the
                predicted values and returns the element-wise loss.
            use_gae (bool): If True, uses generalized advantage estimation for
                computing per-timestep advantage. Else, just subtracts value
                predictions from empirical return.
            td_lambda (float): Lambda parameter for TD-lambda computation.
            use_td_lambda_return (bool): Only effective if use_gae is True.
                If True, uses ``td_lambda_return`` for training value function.
                ``(td_lambda_return = gae_advantage + value_predictions)``.
            normalize_advantages (bool): If True, normalize advantage to zero
                mean and unit variance within batch for calculating policy
                gradient. This is commonly used for PPO.
            advantage_clip (float): If set, clip advantages to :math:`[-x, x]`.
                Must be positive when provided.
            entropy_regularization (float): Coefficient for entropy
                regularization loss term. If None, no entropy term is added.
            td_loss_weight (float): the weight for the loss of td error.
            debug_summaries (bool): If True, ``forward()`` writes summaries of
                values, returns, advantages and the explained variance of the
                returns whenever summary recording is enabled.
            name (str): name of this loss; also used as the summary scope.
        """
        super().__init__(name=name)

        self._td_loss_weight = td_loss_weight
        self._name = name
        self._gamma = torch.tensor(gamma)
        self._td_error_loss_fn = td_error_loss_fn
        self._use_gae = use_gae
        self._lambda = td_lambda
        self._use_td_lambda_return = use_td_lambda_return
        self._normalize_advantages = normalize_advantages
        assert advantage_clip is None or advantage_clip > 0, (
            "Clipping value should be positive!")
        self._advantage_clip = advantage_clip
        self._entropy_regularization = entropy_regularization
        self._debug_summaries = debug_summaries

    @property
    def gamma(self):
        """A copy of the discount factor tensor built from ``gamma``."""
        # Return a clone so callers cannot modify the stored tensor in place.
        return self._gamma.clone()

    def forward(self, info):
        """Calculate actor critic loss. The first dimension of all the tensors
        is time dimension and the second dimension is the batch dimension.

        Args:
            info (namedtuple): information for calculating loss. All tensors are
                time-major. It should contain the following fields:

                - reward:
                - step_type:
                - discount:
                - action:
                - action_distribution:
                - value:
                - reward_weights: if not ``()``, the advantages are multiplied
                  by it and summed over the last dimension before computing
                  the policy-gradient loss.

        Returns:
            LossInfo: with ``extra`` being ``ActorCriticLossInfo``.
        """
        value = info.value
        returns, advantages = self._calc_returns_and_advantages(info, value)

        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):

                def _summarize(v, r, adv, suffix):
                    alf.summary.scalar("values" + suffix, v.mean())
                    alf.summary.scalar("returns" + suffix, r.mean())
                    safe_mean_hist_summary('advantages' + suffix, adv)
                    alf.summary.scalar(
                        "explained_variance_of_return_by_value" + suffix,
                        tensor_utils.explained_variance(v, r))

                if value.ndim == 2:
                    _summarize(value, returns, advantages, '')
                else:
                    # One set of summaries per trailing (reward) dimension.
                    for i in range(value.shape[2]):
                        suffix = '/' + str(i)
                        _summarize(value[..., i], returns[..., i],
                                   advantages[..., i], suffix)

        if self._normalize_advantages:
            advantages = _normalize_advantages(advantages)

        if self._advantage_clip:
            advantages = torch.clamp(advantages, -self._advantage_clip,
                                     self._advantage_clip)

        if info.reward_weights != ():
            # Collapse the last dimension into a weighted sum.
            advantages = (advantages * info.reward_weights).sum(-1)
        # Advantages are detached so the policy-gradient term does not
        # back-propagate through them.
        pg_loss = self._pg_loss(info, advantages.detach())

        # The returns act as a fixed regression target for the value.
        td_loss = self._td_error_loss_fn(returns.detach(), value)

        if td_loss.ndim == 3:
            # Average over the last dimension so td_loss is 2-D like pg_loss.
            td_loss = td_loss.mean(dim=2)

        loss = pg_loss + self._td_loss_weight * td_loss

        entropy_loss = ()
        if self._entropy_regularization is not None:
            entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
                info.action_distribution, return_sum=False)
            # ``entropy`` is only reported (negated); ``entropy_for_gradient``
            # is what actually enters the optimized loss.
            entropy_loss = alf.nest.map_structure(lambda x: -x, entropy)
            loss -= self._entropy_regularization * sum(
                alf.nest.flatten(entropy_for_gradient))

        return LossInfo(
            loss=loss,
            extra=ActorCriticLossInfo(
                td_loss=td_loss, pg_loss=pg_loss, neg_entropy=entropy_loss))

    def _pg_loss(self, info, advantages):
        """Policy-gradient loss: ``-advantages * log_prob(action)``.

        Args:
            info (namedtuple): provides ``action_distribution`` and ``action``.
            advantages (Tensor): advantages multiplied with the log
                probability of the action.

        Returns:
            Tensor: the element-wise policy-gradient loss.
        """
        action_log_prob = dist_utils.compute_log_probability(
            info.action_distribution, info.action)
        return -advantages * action_log_prob

    def _calc_returns_and_advantages(self, info, value):
        """Compute the value targets and the advantages.

        Args:
            info (namedtuple): provides ``reward``, ``discount`` and
                ``step_type``.
            value (Tensor): value predictions.

        Returns:
            tuple: ``(returns, advantages)``. Without GAE the advantages are
            ``returns - value``. With GAE the advantages come from
            ``value_ops.generalized_advantage_estimation`` and, if
            ``use_td_lambda_return`` is set, the returns are replaced by
            ``advantages + value``.
        """
        if info.reward.ndim == 3:
            # [T, B, D] or [T, B, 1]
            discounts = info.discount.unsqueeze(-1) * self._gamma
        else:
            # [T, B]
            discounts = info.discount * self._gamma

        returns = value_ops.discounted_return(
            rewards=info.reward,
            values=value,
            step_types=info.step_type,
            discounts=discounts)
        # NOTE(review): presumably pads the time dimension with the last value
        # so that ``returns`` lines up with ``value`` -- confirm against
        # ``tensor_utils.tensor_extend``.
        returns = tensor_utils.tensor_extend(returns, value[-1])

        if not self._use_gae:
            advantages = returns - value
        else:
            advantages = value_ops.generalized_advantage_estimation(
                rewards=info.reward,
                values=value,
                step_types=info.step_type,
                discounts=discounts,
                td_lambda=self._lambda)
            advantages = tensor_utils.tensor_extend_zero(advantages)
            if self._use_td_lambda_return:
                returns = advantages + value

        return returns, advantages

    def calc_loss(self, info):
        """Same as calling this loss directly: returns ``self(info)``."""
        return self(info)