# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Dimensional Q-Learning Algorithm."""
import functools
import torch
import torch.nn as nn
from typing import Callable
import alf
from alf.algorithms.config import TrainerConfig
from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
from alf.algorithms.one_step_loss import OneStepTDLoss
from alf.algorithms.rl_algorithm import RLAlgorithm
from alf.algorithms.sac_algorithm import _set_target_entropy
from alf.data_structures import TimeStep, Experience, LossInfo, namedtuple
from alf.data_structures import AlgStep
from alf.nest import nest
from alf.networks import MdqCriticNetwork
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import (losses, common, dist_utils, math_ops, spec_utils,
tensor_utils)
# Recurrent state carried through training for the critic networks and their
# target copy (see ``train_state_spec`` built in ``MdqAlgorithm.__init__``).
MdqCriticState = namedtuple("MdqCriticState", ['critic', 'target_critic'])
# Quantities produced by ``MdqAlgorithm._critic_train_step`` and consumed by
# ``MdqAlgorithm._calc_critic_loss``:
#   critic_free_form:        critic output for the rollout action
#   target_critic_free_form: target-critic output for the newly sampled action
#   critic_adv_form:         critic output (``free_form=False``) for the noisy
#                            distillation action
#   distill_target:          target-critic output for the noisy distillation
#                            action
#   kl_wrt_prior:            reverse cumulative sum over action dimensions of
#                            ``log_pi - log_pi_uniform_prior``
MdqCriticInfo = namedtuple("MdqCriticInfo", [
"critic_free_form", "target_critic_free_form", "critic_adv_form",
"distill_target", "kl_wrt_prior"
])
# Top-level train state of the algorithm; only the critic part is stateful.
MdqState = namedtuple("MdqState", ['critic'])
# Extra information reported alongside the alpha (entropy temperature) loss.
MdqAlphaInfo = namedtuple("MdqAlphaInfo", ["alpha_loss", "neg_entropy"])
# Per-step info returned by predict/rollout/train steps. Not every producer
# fills every field (e.g. ``_predict`` only sets ``action``).
# NOTE(review): ``default_value=()`` is an extension of ALF's ``namedtuple``
# helper; presumably unset fields default to ``()`` -- confirm against
# ``alf.data_structures.namedtuple``.
MdqInfo = namedtuple(
"MdqInfo",
["reward", "step_type", "discount", "action", "critic", "alpha"],
default_value=())
# Breakdown of the total loss reported in ``LossInfo.extra`` by ``calc_loss``.
MdqLossInfo = namedtuple('MdqLossInfo', ['critic', 'distill', 'alpha'])
[docs]@alf.configurable
class MdqAlgorithm(OffPolicyAlgorithm):
"""Multi-Dimensional Q-Learning Algorithm.

An off-policy algorithm built around an ``MdqCriticNetwork``. Actions are
obtained from ``critic_network.get_action()`` using the temperature
``exp(log_alpha)``. A copy of the critic network is kept as the target
network and is updated by ``common.TargetUpdater`` in ``after_update()``.

The total training loss returned by ``calc_loss()`` is the sum of:

* a TD critic loss, summed over the critic replicas;
* a distillation loss that regresses the last action dimension of the
  critic's non-free-form output towards the target critic evaluated at a
  noise-perturbed action;
* an alpha loss that adjusts ``log_alpha`` towards the target entropy.
"""
def __init__(
        self,
        observation_spec,
        action_spec: BoundedTensorSpec,
        critic_network: MdqCriticNetwork,
        reward_spec=TensorSpec(()),
        epsilon_greedy=None,
        env=None,
        config: TrainerConfig = None,
        critic_loss_ctor=None,
        target_entropy=dist_utils.calc_default_target_entropy_quantized,
        initial_log_alpha=0.0,
        target_update_tau=0.05,
        target_update_period=1,
        distill_noise=0.01,
        critic_optimizer=None,
        alpha_optimizer=None,
        debug_summaries=False,
        name="MdqAlgorithm"):
    """
    Args:
        observation_spec (nested TensorSpec): representing the observations.
        action_spec (nested BoundedTensorSpec): representing the actions.
        critic_network (MdqCriticNetwork): an instance of
            ``MdqCriticNetwork``. It is used directly as the critic, and a
            copy of it is created to serve as the target critic.
        reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec
            representing the reward(s).
        epsilon_greedy (float): a floating value in [0,1], representing
            the chance of action sampling instead of taking argmax. This
            can help prevent a dead loop in some deterministic environment
            like Breakout. Only used by ``predict_step()``; greedy action
            selection is requested from the network only when the value is
            exactly 0. If None, the value is resolved through
            ``alf.utils.common.get_epsilon_greedy(config)``.
        env (Environment): The environment to interact with. ``env`` is a
            batched environment, which means that it runs multiple
            simulations simultaneously. ``env`` only needs to be provided
            to the root algorithm.
        config (TrainerConfig): config for training. It only needs to be
            provided to the algorithm which performs ``train_iter()`` by
            itself.
        critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
            constructor. If ``None``, a default ``OneStepTDLoss`` will be
            used. One loss instance is created per critic replica.
        target_entropy (float|Callable): If a floating value, it's the
            target average policy entropy, for updating ``alpha``. If a
            callable function, then it will be called on the action spec
            to calculate a target entropy. Note that in MDQ algorithm, as
            the continuous action is represented by a discrete
            distribution for each action dimension,
            ``calc_default_target_entropy_quantized`` is used to compute
            the target entropy by default.
        initial_log_alpha (float): initial value for variable
            ``log_alpha``.
        target_update_tau (float): Factor for soft update of the target
            networks.
        target_update_period (int): Period for soft update of the target
            networks.
        distill_noise (float): the std of random Gaussian noise added to
            the action used for distillation (scaled by the maximum of
            the action spec).
        critic_optimizer (torch.optim.optimizer): The optimizer for
            critic.
        alpha_optimizer (torch.optim.optimizer): The optimizer for alpha.
            If ``None``, ``log_alpha`` is not attached to any optimizer
            here.
        debug_summaries (bool): True if debug summaries should be created.
        name (str): The name of this algorithm.
    """
    # Resolve the evaluation-time epsilon before anything else.
    if epsilon_greedy is None:
        epsilon_greedy = alf.utils.common.get_epsilon_greedy(config)
    self._epsilon_greedy = epsilon_greedy

    # The target network starts as a copy of the critic network.
    target_network = critic_network.copy(name='target_critic_networks')

    critic_state_spec = MdqCriticState(
        critic=critic_network.state_spec,
        target_critic=critic_network.state_spec)

    super().__init__(
        observation_spec,
        action_spec,
        reward_spec=reward_spec,
        train_state_spec=MdqState(critic=critic_state_spec),
        env=env,
        config=config,
        debug_summaries=debug_summaries,
        name=name)

    self._critic_networks = critic_network
    self._target_critic_networks = target_network
    self.add_optimizer(critic_optimizer, [critic_network])

    # Bind ``debug_summaries`` once so that every replica's loss shares it.
    if critic_loss_ctor is None:
        critic_loss_ctor = OneStepTDLoss
    make_critic_loss = functools.partial(
        critic_loss_ctor, debug_summaries=debug_summaries)

    flat_action_spec = nest.flatten(self._action_spec)
    first_action_spec = flat_action_spec[0]
    self._flat_action_spec = flat_action_spec
    self._action_dim = first_action_spec.shape[0]

    self._log_pi_uniform_prior = (
        self._critic_networks.get_uniform_prior_logpi())
    self._num_critic_replicas = self._critic_networks._num_critic_replicas

    # One TD loss per critic replica, named critic_loss1, critic_loss2, ...
    self._critic_losses = [
        make_critic_loss(name="critic_loss%d" % replica_id)
        for replica_id in range(1, self._num_critic_replicas + 1)
    ]

    self._is_continuous = first_action_spec.is_continuous
    self._target_entropy = _set_target_entropy(self.name, target_entropy,
                                               flat_action_spec)

    # Learnable log-temperature of the entropy regularization.
    log_alpha = nn.Parameter(torch.tensor(float(initial_log_alpha)))
    self._log_alpha = log_alpha

    self._update_target = common.TargetUpdater(
        models=[self._critic_networks],
        target_models=[self._target_critic_networks],
        tau=target_update_tau,
        period=target_update_period)

    if alpha_optimizer is not None:
        self.add_optimizer(alpha_optimizer, [log_alpha])

    self._distill_noise = distill_noise
def _predict(self, time_step: TimeStep, state=None, epsilon_greedy=1.):
    """Sample an action for ``time_step`` from the critic networks.

    Args:
        time_step (TimeStep): only ``time_step.observation`` is used.
        state: unused; an empty state is always returned.
        epsilon_greedy (float): greedy action selection is requested from
            the network only if this value equals 0. This option is
            provided for evaluation purposes when greedy sampling is
            desirable; any other value results in sampling.

    Returns:
        AlgStep: ``output`` and ``info.action`` are the action taken from
        the first critic replica; ``state`` is an empty nest matching
        ``train_state_spec``.
    """
    use_greedy = (epsilon_greedy == 0)
    temperature = self._log_alpha.exp().detach()
    replica_actions, _ = self._critic_networks.get_action(
        time_step.observation, alpha=temperature, greedy=use_greedy)

    # When num_critic_replicas > 1 the network returns one action per
    # replica; keep only the first one: [B, n, d] -> [B, d]
    action = replica_actions[:, 0, :]

    no_state = nest.map_structure(lambda _: (), self.train_state_spec)
    return AlgStep(
        output=action, state=no_state, info=MdqInfo(action=action))
def predict_step(self, time_step: TimeStep, state):
    """Predict an action for evaluation.

    Delegates to ``_predict()`` with the ``epsilon_greedy`` value that was
    resolved in the constructor.

    Args:
        time_step (TimeStep): the current time step.
        state: forwarded to ``_predict()``.

    Returns:
        AlgStep: the result of ``_predict()``.
    """
    return self._predict(
        time_step, state=state, epsilon_greedy=self._epsilon_greedy)
def rollout_step(self, time_step: TimeStep, state):
    """Produce an action for data collection.

    Actions are always sampled during rollout (``epsilon_greedy=1.0``).

    Args:
        time_step (TimeStep): the current time step.
        state: forwarded to ``_predict()``.

    Returns:
        AlgStep: the result of ``_predict()``.

    Raises:
        NotImplementedError: if ``need_full_rollout_state()`` is true,
            since storing RNN state to the replay buffer is not supported.
    """
    if self.need_full_rollout_state():
        # The message used to name SacAlgorithm (copy-paste leftover), which
        # pointed users at the wrong algorithm.
        raise NotImplementedError("Storing RNN state to replay buffer "
                                  "is not supported by MdqAlgorithm")
    return self._predict(time_step, state, epsilon_greedy=1.0)
def _critic_train_step(self, inputs: TimeStep, state: MdqCriticState,
                       rollout_action, action, log_pi_per_dim):
    """Evaluate the critic and target critic for one training step.

    Args:
        inputs (TimeStep): only ``inputs.observation`` is used.
        state (MdqCriticState): states for the critic and target critic.
        rollout_action (Tensor): the action stored during rollout; it is
            cast to ``float32`` before being fed to the critic.
        action (Tensor): the action sampled in ``train_step()``
            (``[B, n, d]`` according to the comments in ``_predict``).
        log_pi_per_dim (Tensor): per-dimension log-probability of
            ``action`` (``[B, n, action_dim]``).

    Returns:
        tuple:
        - MdqCriticState: new critic / target-critic states. The critic
          state is the one returned by the non-free-form critic call.
        - MdqCriticInfo: the values needed by ``_calc_critic_loss()``.
    """
    alpha = self._log_alpha.exp().detach()
    observation = inputs.observation
    num_replicas = self._num_critic_replicas

    # Critic value of the action actually taken during rollout: [B, n].
    # The state returned by this call is not kept.
    critic, _ = self._critic_networks(
        torch.cat((observation, rollout_action.to(torch.float32)), -1),
        alpha=alpha,
        state=state.critic,
        free_form=True)

    # Perturb the sampled action with Gaussian noise for distillation.
    distill_action = self._get_noisy_action(
        action,
        self._action_spec,
        self._distill_noise,
        noise_clip=0,
        spec_clip=True).detach()

    # [B, n, action_dim]
    critic_adv_form, critic_state = self._critic_networks(
        (observation, distill_action),
        alpha=alpha,
        state=state.critic,
        free_form=False)

    # The target network is fed observations with an extra replica
    # dimension inserted at dim 1 so they can be concatenated with the
    # per-replica actions.
    target_critic_input = torch.cat((tensor_utils.tensor_extend_new_dim(
        observation, dim=1, n=num_replicas), action.detach()), -1)
    distill_target_input = torch.cat((tensor_utils.tensor_extend_new_dim(
        observation, dim=1, n=num_replicas), distill_action), -1)

    target_critic, target_critic_state = self._target_critic_networks(
        target_critic_input,
        alpha=alpha,
        state=state.target_critic,
        free_form=True)

    # Note that in MDQ we distill from the target_critic_network.
    distill_target, _ = self._target_critic_networks(
        distill_target_input,
        alpha=alpha,
        state=state.target_critic,
        free_form=True)

    # Keep the KL of all action dimensions in case it is useful in the
    # future, e.g., per-action target correction using the corresponding
    # KL. After the reverse cumulative sum, index 0 along the last
    # dimension holds the sum over all action dimensions.
    kl_wrt_prior = tensor_utils.reverse_cumsum(
        log_pi_per_dim - self._log_pi_uniform_prior, dim=-1)

    new_state = MdqCriticState(
        critic=critic_state, target_critic=target_critic_state)
    critic_info = MdqCriticInfo(
        critic_free_form=critic,
        target_critic_free_form=target_critic,
        distill_target=distill_target,
        critic_adv_form=critic_adv_form,
        kl_wrt_prior=kl_wrt_prior)
    return new_state, critic_info
def _alpha_train_step(self, log_pi_per_dim):
    """Adjust alpha according to the target entropy.

    Args:
        log_pi_per_dim (torch.Tensor): a tensor of the shape
            [B, n, action_dim] representing the log_pi for each dimension
            of the sampled multi-dimensional action.

    Returns:
        LossInfo: ``loss`` is the per-sample alpha loss (mean over the
        critic replicas, flattened to 1-D); ``extra`` is an
        ``MdqAlphaInfo`` holding the same loss and the per-sample negative
        entropy estimate.
    """
    # Joint log-probability of the full action: [B, n, action_dim] -> [B, n]
    log_pi_full = log_pi_per_dim.sum(dim=-1)

    # Only ``log_alpha`` receives gradient; the entropy gap is a constant.
    alpha_loss = self._log_alpha * (
        -log_pi_full - self._target_entropy).detach()

    # mean over critic replicas: [B, n] -> [B]
    alpha_loss = torch.mean(alpha_loss, -1).view(-1)

    # Reduce over the replica dimension exactly like ``alpha_loss``. An
    # extra ``squeeze(-1)`` here would drop the replica dimension when
    # n == 1, making the mean run over the batch and returning shape [1]
    # instead of [B]; for n > 1 it would be a no-op.
    neg_entropy = torch.mean(log_pi_full, -1).view(-1)

    return LossInfo(
        loss=alpha_loss,
        extra=MdqAlphaInfo(alpha_loss=alpha_loss, neg_entropy=neg_entropy))
def train_step(self, inputs: TimeStep, state: MdqState, rollout_info):
    """Run one training step on a batch of experience.

    Args:
        inputs (TimeStep): experience; ``observation``, ``reward``,
            ``step_type`` and ``discount`` are used.
        state (MdqState): the training state.
        rollout_info: info recorded at rollout time; only
            ``rollout_info.action`` is used.

    Returns:
        AlgStep: ``output`` is the newly sampled action (replica 0's sample
        expanded to all replicas), ``state`` is the updated ``MdqState``
        and ``info`` is an ``MdqInfo`` carrying the critic and alpha
        information needed by ``calc_loss()``.
    """
    alpha = self._log_alpha.exp().detach()
    sampled_action, sampled_log_pi = self._critic_networks.get_action(
        inputs.observation, alpha=alpha, greedy=False)

    # Use the first replica's sample for every replica so that all
    # replicas are evaluated on the same action.
    action = sampled_action[:, :1, :].expand_as(sampled_action)
    log_pi_per_dim = sampled_log_pi[:, :1, :].expand_as(sampled_log_pi)

    critic_state, critic_info = self._critic_train_step(
        inputs, state.critic, rollout_info.action, action, log_pi_per_dim)
    alpha_info = self._alpha_train_step(log_pi_per_dim)

    return AlgStep(
        output=action,
        state=MdqState(critic=critic_state),
        info=MdqInfo(
            reward=inputs.reward,
            step_type=inputs.step_type,
            discount=inputs.discount,
            critic=critic_info,
            alpha=alpha_info))
def after_update(self, root_inputs, info: MdqInfo):
    """Synchronize the critic networks and update the target networks.

    Args:
        root_inputs: unused.
        info (MdqInfo): unused.
    """
    # Sync parallel/non-parallel network parameters first. The sync has to
    # happen before the target update in case the target net is used as
    # the policy.
    self._critic_networks.sync_net()
    self._update_target()
def calc_loss(self, info: MdqInfo):
    """Combine the critic, distillation and alpha losses.

    Args:
        info (MdqInfo): training info collected from ``train_step()``;
            ``info.alpha`` is the ``LossInfo`` produced by
            ``_alpha_train_step()`` and ``info.critic`` is consumed by
            ``_calc_critic_loss()``.

    Returns:
        LossInfo: ``loss`` is the element-wise sum of the three losses and
        ``extra`` is an ``MdqLossInfo`` with the individual parts.
    """
    alpha_loss = info.alpha
    critic_loss, distill_loss = self._calc_critic_loss(info)

    # ``alpha_loss.loss`` is already flattened to one value per sample by
    # ``_alpha_train_step()``, so it lines up with the other two terms
    # without any squeeze. Squeezing the last dimension would only have an
    # effect when that dimension (the batch) has size 1, in which case the
    # sum would silently broadcast to a wrong shape.
    # NOTE(review): assumes all three terms share the same [T, B] layout,
    # as the shape comments in ``_calc_critic_loss`` indicate.
    total_loss = critic_loss.loss + distill_loss + alpha_loss.loss

    return LossInfo(
        loss=total_loss,
        extra=MdqLossInfo(
            critic=critic_loss.extra,
            alpha=alpha_loss.extra,
            distill=distill_loss))
def _calc_critic_loss(self, train_info: MdqInfo):
    """Compute the TD critic loss and the distillation loss.

    Args:
        train_info (MdqInfo): training info; ``train_info.critic`` must be
            the ``MdqCriticInfo`` produced by ``_critic_train_step()``.
            The whole ``train_info`` is also forwarded to each critic
            loss.

    Returns:
        tuple:
        - LossInfo: ``loss`` is the sum of the per-replica critic losses
          and ``extra`` is that sum divided by the number of replicas.
        - Tensor: the distillation loss, averaged over replicas.
    """
    critic_info = train_info.critic
    alpha = torch.exp(self._log_alpha).detach()

    # [t, B, n]
    critic_free_form = critic_info.critic_free_form
    # [t, B, n, action_dim]
    critic_adv_form = critic_info.critic_adv_form

    # [t, B, n, action_dim] -> [t, B]
    # Note that currently the kl_wrt_prior is independent of ensembles, we
    # therefore slice over ensemble by taking the first element; for the
    # action dimension, the first element is the full KL.
    full_kl = critic_info.kl_wrt_prior[..., 0, 0]

    # Take the minimum over replicas: [t, B, n] -> [t, B]
    target_critic, _ = torch.min(
        critic_info.target_critic_free_form, dim=2)
    # [t, B, n] -> [t, B]
    distill_target, _ = torch.min(critic_info.distill_target, dim=2)

    # Entropy-regularized target shared by all replicas.
    target_critic_corrected = target_critic - alpha * full_kl

    critic_losses = [
        self._critic_losses[replica](
            info=train_info,
            value=critic_free_form[:, :, replica],
            target_value=target_critic_corrected).loss
        for replica in range(critic_free_form.shape[2])
    ]
    critic_loss = math_ops.add_n(critic_losses)

    # Regress the last action dimension of the non-free-form critic output
    # towards the (detached) distillation target, then average over
    # replicas.
    distill_error = (
        critic_adv_form[..., -1] - distill_target.unsqueeze(2).detach())
    distill_loss = distill_error.pow(2).mean(dim=2)

    return LossInfo(
        loss=critic_loss,
        extra=critic_loss / len(critic_losses)), distill_loss
def _get_noisy_action(self,
                      actions,
                      action_specs,
                      noise_level,
                      noise_clip=0.0,
                      spec_clip=True):
    """Add Gaussian noise to ``actions``.

    Args:
        actions (Tensor): the actions to perturb.
        action_specs: spec of the actions; ``action_specs.maximum`` scales
            both the noise and the noise clipping range.
        noise_level (float): std of the noise relative to the action
            maximum. If not positive, ``actions`` is returned unchanged.
        noise_clip (float): if positive, the noise is clamped to
            ``[-noise_clip * maximum, noise_clip * maximum]``.
        spec_clip (bool): whether to clip the noisy action with
            ``spec_utils.clip_to_spec``.

    Returns:
        Tensor: the (optionally clipped) noisy action, or ``actions``
        itself when no noise is requested.
    """
    # Nothing to do without a positive noise level.
    if not (noise_level > 0):
        return actions

    max_action = torch.as_tensor(action_specs.maximum)
    noise = torch.randn_like(actions) * noise_level * max_action
    if noise_clip > 0:
        noise = torch.clamp(
            noise, min=-noise_clip * max_action, max=noise_clip * max_action)

    perturbed = actions + noise
    if not spec_clip:
        return perturbed
    return spec_utils.clip_to_spec(perturbed, action_specs)
def _trainable_attributes_to_ignore(self):
    """Return the attribute names to be ignored as trainable.

    Only the target critic networks are listed; they are updated from the
    critic networks by the target updater rather than by an optimizer.

    Returns:
        list[str]: the attribute names to ignore.
    """
    ignored = ['_target_critic_networks']
    return ignored