# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deep Deterministic Policy Gradient (DDPG)."""
import functools
import numpy as np
import torch
import torch.nn as nn
import torch.distributions as td
from typing import Callable
import alf
from alf.algorithms.config import TrainerConfig
from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
from alf.algorithms.one_step_loss import OneStepTDLoss
from alf.algorithms.rl_algorithm import RLAlgorithm
from alf.data_structures import TimeStep, Experience, LossInfo, namedtuple
from alf.data_structures import AlgStep, StepType
from alf.nest import nest
import alf.nest.utils as nest_utils
from alf.networks import ActorNetwork, CriticNetwork
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import losses, common, dist_utils, math_ops, spec_utils
# State carried through the critic training step: network states of the
# online critics, the target actor and the target critics.
DdpgCriticState = namedtuple("DdpgCriticState",
['critics', 'target_actor', 'target_critics'])
# Result of the critic training step: Q-values of the online critics for the
# replayed action, and target Q-values already reduced over critic replicas
# with ``min``.
DdpgCriticInfo = namedtuple("DdpgCriticInfo", ["q_values", "target_q_values"])
# State carried through the actor training step: the actor network state and
# the state of the critics evaluated on the actor's action.
DdpgActorState = namedtuple("DdpgActorState", ['actor', 'critics'])
# Full train state: ``actor`` holds a ``DdpgActorState`` and ``critics`` holds
# a ``DdpgCriticState``.
DdpgState = namedtuple("DdpgState", ['actor', 'critics'])
# Per-step info. Prediction/rollout fills ``action`` and
# ``action_distribution``; ``train_step`` fills the remaining fields.
# NOTE(review): ``default_value=()`` presumably makes every field that is not
# supplied default to ``()`` -- confirm against ``alf.data_structures.namedtuple``.
DdpgInfo = namedtuple(
"DdpgInfo", [
"reward", "step_type", "discount", "action", "action_distribution",
"actor_loss", "critic", "discounted_return"
],
default_value=())
# Breakdown of the total loss, reported as ``LossInfo.extra`` by ``calc_loss``.
DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic'))
@alf.configurable
class DdpgAlgorithm(OffPolicyAlgorithm):
    """Deep Deterministic Policy Gradient (DDPG).

    The algorithm trains a deterministic actor network together with one or
    more parallel critic (Q) network replicas. Target copies of the actor and
    of the critics are maintained and periodically soft-updated towards the
    online networks. Exploration is obtained by adding Ornstein-Uhlenbeck
    noise to the actor output and, optionally, by replacing some rollout
    actions with uniformly random ones.

    Reference:
        Lillicrap et al "Continuous control with deep reinforcement learning"
        https://arxiv.org/abs/1509.02971
    """
def __init__(self,
             observation_spec,
             action_spec: BoundedTensorSpec,
             reward_spec=TensorSpec(()),
             actor_network_ctor=ActorNetwork,
             critic_network_ctor=CriticNetwork,
             reward_weights=None,
             epsilon_greedy=None,
             calculate_priority=False,
             env=None,
             config: TrainerConfig = None,
             ou_stddev=0.2,
             ou_damping=0.15,
             critic_loss_ctor=None,
             num_critic_replicas=1,
             target_update_tau=0.05,
             target_update_period=1,
             rollout_random_action=0.,
             dqda_clipping=None,
             action_l2=0,
             actor_optimizer=None,
             critic_optimizer=None,
             checkpoint=None,
             debug_summaries=False,
             name="DdpgAlgorithm"):
    """Build the actor, the critic replicas, their targets and the losses.

    Args:
        observation_spec (nested TensorSpec): representing the observations.
        action_spec (nested BoundedTensorSpec): representing the actions.
        reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec representing
            the reward(s).
        actor_network_ctor (Callable): Function to construct the actor
            network. ``actor_network_ctor`` needs to accept
            ``input_tensor_spec`` and ``action_spec`` as its arguments and
            return an actor network. The constructed network will be called
            with ``forward(observation, state)``.
        critic_network_ctor (Callable): Function to construct the critic
            network. ``critic_network_ctor`` needs to accept
            ``input_tensor_spec``, which is a tuple of
            ``(observation_spec, action_spec)``, and ``output_tensor_spec``.
            The constructed network will be called with
            ``forward((observation, action), state)``.
        reward_weights (list[float]): this is only used when the reward is
            multidimensional. In that case, the weighted sum of the q values
            is used for training the actor.
        epsilon_greedy (float): a floating value in [0,1], representing the
            chance of action sampling instead of taking argmax. This can
            help prevent a dead loop in some deterministic environment like
            Breakout. Only used for evaluation. If None, its value is taken
            from ``config.epsilon_greedy`` and then
            ``alf.get_config_value(TrainerConfig.epsilon_greedy)``.
        calculate_priority (bool): whether to calculate priority. This is
            only useful if priority replay is enabled.
        env (Environment): The environment to interact with. ``env`` is a
            batched environment, which means that it runs multiple
            simulations simultaneously. ``env`` only needs to be provided to
            the root algorithm.
        config (TrainerConfig): config for training. ``config`` only needs to
            be provided to the algorithm which performs ``train_iter()`` by
            itself.
        ou_stddev (float): Standard deviation for the Ornstein-Uhlenbeck
            (OU) noise added in the default collect policy.
        ou_damping (float): Damping factor for the OU noise added in the
            default collect policy.
        critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
            constructor. If ``None``, a default ``OneStepTDLoss`` will be
            used. One loss instance is created per critic replica.
        num_critic_replicas (int): number of critics to be used. Default is 1.
        target_update_tau (float): Factor for soft update of the target
            networks.
        target_update_period (int): Period for soft update of the target
            networks.
        rollout_random_action (float): the probability of taking a uniform
            random action during a ``rollout_step()``. 0 means always
            directly taking actions added with OU noises and 1 means always
            sample uniformly random actions. A bigger value results in more
            exploration during rollout.
        dqda_clipping (float): when computing the actor loss, clips the
            gradient dqda element-wise between
            ``[-dqda_clipping, dqda_clipping]``. No clipping is performed
            when the value is falsy (``None`` or ``0``).
        action_l2 (float): weight of squared action l2-norm on actor loss.
        actor_optimizer (torch.optim.optimizer): The optimizer for actor.
        critic_optimizer (torch.optim.optimizer): The optimizer for critic.
        checkpoint (None|str): a string in the format of "prefix@path",
            where the "prefix" is the multi-step path to the contents in the
            checkpoint to be loaded. "path" is the full path to the
            checkpoint file saved by ALF. Refer to ``Algorithm`` for more
            details.
        debug_summaries (bool): True if debug summaries should be created.
        name (str): The name of this algorithm.
    """
    self._calculate_priority = calculate_priority
    self._epsilon_greedy = (common.get_epsilon_greedy(config)
                            if epsilon_greedy is None else epsilon_greedy)

    # Construction order of the networks is kept fixed (critic, actor, then
    # the parallel critic replicas).
    critic_network = critic_network_ctor(
        input_tensor_spec=(observation_spec, action_spec),
        output_tensor_spec=reward_spec)
    actor_network = actor_network_ctor(
        input_tensor_spec=observation_spec, action_spec=action_spec)
    critic_networks = critic_network.make_parallel(num_critic_replicas)

    self._action_l2 = action_l2

    # The train state mirrors what the actor and critic train steps consume.
    actor_train_state = DdpgActorState(
        actor=actor_network.state_spec, critics=critic_networks.state_spec)
    critic_train_state = DdpgCriticState(
        critics=critic_networks.state_spec,
        target_actor=actor_network.state_spec,
        target_critics=critic_networks.state_spec)

    super().__init__(
        observation_spec=observation_spec,
        action_spec=action_spec,
        reward_spec=reward_spec,
        train_state_spec=DdpgState(
            actor=actor_train_state, critics=critic_train_state),
        reward_weights=reward_weights,
        env=env,
        config=config,
        checkpoint=checkpoint,
        debug_summaries=debug_summaries,
        name=name)

    # Register only the optimizers that were actually supplied.
    for optimizer, modules in ((actor_optimizer, [actor_network]),
                               (critic_optimizer, [critic_networks])):
        if optimizer is not None:
            self.add_optimizer(optimizer, modules)

    self._actor_network = actor_network
    self._num_critic_replicas = num_critic_replicas
    self._critic_networks = critic_networks

    self._target_actor_network = actor_network.copy(
        name='target_actor_networks')
    self._target_critic_networks = critic_networks.copy(
        name='target_critic_networks')

    self._rollout_random_action = float(rollout_random_action)

    # One independently named critic loss per replica.
    make_critic_loss = functools.partial(
        OneStepTDLoss if critic_loss_ctor is None else critic_loss_ctor,
        debug_summaries=debug_summaries)
    self._critic_losses = [
        make_critic_loss(name=("critic_loss" + str(replica)))
        for replica in range(num_critic_replicas)
    ]

    self._ou_process = common.create_ou_process(action_spec, ou_stddev,
                                                ou_damping)

    online_models = [self._actor_network, self._critic_networks]
    target_models = [
        self._target_actor_network, self._target_critic_networks
    ]
    self._update_target = common.TargetUpdater(
        models=online_models,
        target_models=target_models,
        tau=target_update_tau,
        period=target_update_period)

    self._dqda_clipping = dqda_clipping
def predict_step(self, inputs: TimeStep, state):
    """Run one prediction step using the configured ``epsilon_greedy``.

    Args:
        inputs (TimeStep): the current time step.
        state: the algorithm state; its ``actor.actor`` entry is fed to the
            actor network.

    Returns:
        AlgStep: the result of ``_predict_step`` for these inputs.
    """
    exploration_prob = self._epsilon_greedy
    return self._predict_step(inputs, state, exploration_prob)
def _predict_step(self, time_step: TimeStep, state, epsilon_greedy=1.):
    """Compute an (optionally noisy) action from the actor network.

    Args:
        time_step (TimeStep): the current time step; only ``observation``
            is read here.
        state: the algorithm state; ``state.actor.actor`` is passed to the
            actor network.
        epsilon_greedy (float): exploration level. ``0`` returns the actor
            output unchanged, a value ``>= 1`` adds OU noise to every batch
            entry, and a value in between adds OU noise to each batch entry
            independently with this probability.

    Returns:
        AlgStep:
        - output: the action after noise and clipping to the action spec.
        - state: an otherwise empty train state in which only the actor
          network state is filled.
        - info: ``DdpgInfo`` with ``action`` set to the returned action and
          ``action_distribution`` set to the raw actor output.
    """
    action, state = self._actor_network(
        time_step.observation, state=state.actor.actor)
    empty_state = nest.map_structure(lambda x: (), self.train_state_spec)

    def _sample(a, ou):
        if epsilon_greedy == 0:
            return a
        if epsilon_greedy >= 1.0:
            return a + ou()
        # Draw the exploration mask before the OU sample so the random
        # number consumption order stays the same as before.
        explore = torch.rand(a.shape[:1]) < epsilon_greedy
        noisy_a = a + ou()
        # Broadcast the per-sample mask over the action dimensions and keep
        # it on the same device as the action (no-op when already there).
        explore = explore.reshape(explore.shape + (1, ) * (a.ndim - 1))
        explore = explore.to(a.device)
        # Select instead of writing into ``a``: ``a`` is the actor output
        # that is also reported as ``action_distribution`` below, so it must
        # not be modified in place.
        return torch.where(explore, noisy_a, a)

    noisy_action = nest.map_structure(_sample, action, self._ou_process)
    noisy_action = nest.map_structure(spec_utils.clip_to_spec,
                                      noisy_action, self._action_spec)
    state = empty_state._replace(
        actor=DdpgActorState(actor=state, critics=()))
    return AlgStep(
        output=noisy_action,
        state=state,
        info=DdpgInfo(action=noisy_action, action_distribution=action))
def rollout_step(self, time_step: TimeStep, state=None):
    """Collect one step of experience with exploration noise.

    The action is the actor output plus OU noise (``epsilon_greedy=1.0``).
    When ``rollout_random_action`` is positive, each batch entry is then
    replaced, with that probability, by a uniformly random action scaled to
    the action spec.

    Args:
        time_step (TimeStep): the current time step.
        state: the algorithm state forwarded to ``_predict_step``.

    Returns:
        AlgStep: the prediction step, with ``output`` possibly overwritten
        for some batch entries.

    Raises:
        NotImplementedError: if the full rollout state would need to be
            stored, which this algorithm does not support.
    """
    if self.need_full_rollout_state():
        raise NotImplementedError("Storing RNN state to replay buffer "
                                  "is not supported by DdpgAlgorithm")

    random_prob = self._rollout_random_action

    def _overwrite_with_uniform(spec, noisy_action):
        # Uniform sample in [-1, 1) mapped onto the bounds of ``spec``.
        uniform_action = spec_utils.scale_to_spec(
            torch.rand_like(noisy_action) * 2 - 1, spec)
        chosen = torch.where(
            torch.rand(noisy_action.shape[:1]) < random_prob)[0]
        # Written in place on purpose: the tensor is modified rather than
        # replaced, so every reference to it sees the random actions.
        noisy_action[chosen, :] = uniform_action[chosen, :]

    pred_step = self._predict_step(time_step, state, epsilon_greedy=1.0)
    if random_prob > 0:
        nest.map_structure(_overwrite_with_uniform, self._action_spec,
                           pred_step.output)
    return pred_step
def _critic_train_step(self, inputs: TimeStep, state: DdpgCriticState,
                       rollout_info: DdpgInfo):
    """Evaluate online and target critics for the critic loss.

    Args:
        inputs (TimeStep): the training time step; only ``observation`` is
            read here.
        state (DdpgCriticState): states of the critics, the target actor and
            the target critics.
        rollout_info (DdpgInfo): info stored at rollout; its ``action`` is
            the action evaluated by the online critics.

    Returns:
        tuple:
        - DdpgCriticState: the updated network states.
        - DdpgCriticInfo: online Q-values and target Q-values reduced over
          the replica dimension (dim 1).
    """
    observation = inputs.observation

    # Target value: target critics evaluated at the target actor's action.
    next_action, new_target_actor_state = self._target_actor_network(
        observation, state=state.target_actor)
    target_q, new_target_critic_states = self._target_critic_networks(
        (observation, next_action), state=state.target_critics)

    if not self.has_multidim_reward():
        target_q = target_q.min(dim=1).values
    else:
        # Multiplying by the sign of the reward weights before and after the
        # ``min`` turns it into a ``max`` for negatively weighted rewards.
        weight_sign = self.reward_weights.sign()
        target_q = (target_q * weight_sign).min(dim=1).values * weight_sign

    online_q, new_critic_states = self._critic_networks(
        (observation, rollout_info.action), state=state.critics)

    new_state = DdpgCriticState(
        critics=new_critic_states,
        target_actor=new_target_actor_state,
        target_critics=new_target_critic_states)
    critic_info = DdpgCriticInfo(q_values=online_q, target_q_values=target_q)
    return new_state, critic_info
def _actor_train_step(self, inputs: TimeStep, state: DdpgActorState):
    """Compute the actor loss from the critics' gradient w.r.t. the action.

    Args:
        inputs (TimeStep): the training time step; only ``observation`` is
            read here.
        state (DdpgActorState): states of the actor and of the critics.

    Returns:
        AlgStep:
        - output: the actor's action for the observation.
        - state: the updated ``DdpgActorState``.
        - info: ``LossInfo`` whose ``loss`` is the sum of the per-action
          losses over the action nest and whose ``extra`` is that nest.
    """
    action, actor_state = self._actor_network(
        inputs.observation, state=state.actor)
    q_values, critic_states = self._critic_networks(
        (inputs.observation, action), state=state.critics)

    if self.has_multidim_reward():
        # Multidimensional reward: [B, replicas, reward_dim]
        q_values = q_values * self.reward_weights
    # Reduce over the replica dimension with ``min``.
    min_q = q_values.min(dim=1).values

    # ``sum()`` collapses every remaining dim, so ``min_q`` may be any rank.
    dqda = nest_utils.grad(action, min_q.sum())

    clip_value = self._dqda_clipping
    l2_weight = self._action_l2

    def _loss_for(grad, act):
        if clip_value:
            grad = torch.clamp(grad, -clip_value, clip_value)
        # Regress the action towards the detached ``action + dqda``.
        # NOTE(review): assuming ``element_wise_squared_loss`` is the plain
        # squared difference, the gradient of this term w.r.t. the action is
        # ``-dqda``.
        shifted_target = (grad + act).detach()
        per_element = 0.5 * losses.element_wise_squared_loss(
            shifted_target, act)
        if l2_weight > 0:
            assert act.requires_grad
            per_element = per_element + l2_weight * (act**2)
        # Sum over every dim except the first one.
        reduce_dims = list(range(1, per_element.ndim))
        return per_element.sum(reduce_dims)

    actor_loss = nest.map_structure(_loss_for, dqda, action)
    new_state = DdpgActorState(actor=actor_state, critics=critic_states)
    loss_info = LossInfo(loss=sum(nest.flatten(actor_loss)), extra=actor_loss)
    return AlgStep(output=action, state=new_state, info=loss_info)
def train_step(self, inputs: TimeStep, state: DdpgState,
               rollout_info: DdpgInfo):
    """Run the critic and actor training steps for one time step.

    Args:
        inputs (TimeStep): the training time step.
        state (DdpgState): states for the actor and critic train steps.
        rollout_info (DdpgInfo): info stored at rollout.

    Returns:
        AlgStep: the actor step with ``state`` replaced by the combined
        ``DdpgState`` and ``info`` replaced by a ``DdpgInfo`` carrying the
        quantities needed by ``calc_loss``.
    """
    critic_states, critic_info = self._critic_train_step(
        inputs=inputs, state=state.critics, rollout_info=rollout_info)
    actor_step = self._actor_train_step(inputs=inputs, state=state.actor)

    combined_state = DdpgState(
        actor=actor_step.state, critics=critic_states)
    combined_info = DdpgInfo(
        reward=inputs.reward,
        step_type=inputs.step_type,
        discount=inputs.discount,
        action_distribution=actor_step.output,
        critic=critic_info,
        actor_loss=actor_step.info,
        discounted_return=rollout_info.discounted_return)
    return actor_step._replace(state=combined_state, info=combined_info)
def calc_loss(self, info: DdpgInfo):
    """Combine the critic losses of all replicas with the actor loss.

    Args:
        info (DdpgInfo): info collected from ``train_step``.

    Returns:
        LossInfo:
        - loss: summed critic loss plus actor loss.
        - priority: ``()`` unless priority calculation is enabled, in which
          case it is the square root of the critic loss averaged along dim 0
          over the steps that are not ``StepType.LAST``.
        - extra: ``DdpgLossInfo`` with the critic and actor parts.
    """
    critic_info = info.critic
    # The replica index lives on the third axis of ``q_values``.
    replica_losses = [
        self._critic_losses[replica](
            info=info,
            value=critic_info.q_values[:, :, replica, ...],
            target_value=critic_info.target_q_values).loss
        for replica in range(self._num_critic_replicas)
    ]
    critic_loss = math_ops.add_n(replica_losses)

    priority = ()
    if self._calculate_priority:
        not_last = (info.step_type != StepType.LAST).to(torch.float32)
        # Clamp so that a column without any valid step does not divide by 0.
        valid_count = not_last.sum(dim=0).clamp(min=1.0)
        mean_loss = (critic_loss * not_last).sum(dim=0) / valid_count
        priority = mean_loss.sqrt()

    actor_loss = info.actor_loss
    loss_breakdown = DdpgLossInfo(critic=critic_loss, actor=actor_loss.extra)
    return LossInfo(
        loss=critic_loss + actor_loss.loss,
        priority=priority,
        extra=loss_breakdown)
def after_update(self, root_inputs, info: DdpgInfo):
    """Invoke the target-network updater.

    Args:
        root_inputs: unused here.
        info (DdpgInfo): unused here.
    """
    update_targets = self._update_target
    update_targets()
def _trainable_attributes_to_ignore(self):
return ['_target_actor_network', '_target_critic_networks']