# Source code for alf.algorithms.dqn_algorithm

# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DQN Algorithm."""

import torch
import torch.distributions as td
from typing import Callable, Optional, Union

import alf
from alf.algorithms.config import TrainerConfig
from alf.algorithms.sac_algorithm import SacAlgorithm, ActionType, \
    SacState as DqnState, SacCriticState as DqnCriticState, \
    SacActionState as DqnActionState, \
    SacInfo as DqnInfo, SacCriticInfo as DqnCriticInfo, \
    SacLossInfo as DqnLossInfo
from alf.algorithms.td_loss import TDLoss
from alf.data_structures import AlgStep, LossInfo, TimeStep
from alf.environments.alf_environment import AlfEnvironment
from alf.networks import QNetwork
from alf.optimizers import AdamTF
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import common, dist_utils
from alf.utils.schedulers import as_scheduler, Scheduler


@alf.configurable
class DqnAlgorithm(SacAlgorithm):
    r"""DQN/DDQN algorithm:

    ::

        Mnih et al "Playing Atari with Deep Reinforcement Learning", arXiv:1312.5602
        Hasselt et al "Deep Reinforcement Learning with Double Q-learning", arXiv:1509.06461

    The difference with DQN is that a minimum is taken from the two critics,
    similar to TD3, instead of choosing the maximum action using the Q network
    and evaluating the action value using the target Q network.

    The implementation is based on the SAC algorithm: this class forwards to
    ``SacAlgorithm`` with no actor/critic network classes, with the entropy
    reward disabled and with an alpha optimizer whose learning rate is 0.
    """

    def __init__(self,
                 observation_spec: alf.tensor_specs.NestedTensorSpec,
                 action_spec: alf.tensor_specs.BoundedTensorSpec,
                 reward_spec: TensorSpec = TensorSpec(()),
                 q_network_cls: Callable[..., QNetwork] = QNetwork,
                 q_optimizer: Optional[torch.optim.Optimizer] = None,
                 rollout_epsilon_greedy: Union[float, Scheduler] = 0.1,
                 target_net_target_action: bool = True,
                 num_critic_replicas: int = 2,
                 env: Optional[AlfEnvironment] = None,
                 config: Optional[TrainerConfig] = None,
                 critic_loss_ctor: Optional[Callable[..., TDLoss]] = None,
                 checkpoint=None,
                 debug_summaries: bool = False,
                 name: str = "DqnAlgorithm"):
        """
        Args:
            observation_spec (nested TensorSpec): representing the observations.
            action_spec (BoundedTensorSpec): Only one discrete action allowed.
            reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec representing
                the reward(s).
            q_network_cls: is used to construct QNetwork for estimating ``Q(s,a)``
                given that the action is discrete. Its output spec must be
                consistent with the discrete action in ``action_spec``.
            q_optimizer: A custom optimizer for the q network. It is passed to
                ``SacAlgorithm`` as ``critic_optimizer``. Uses the enclosing
                algorithm's optimizer if None.
            rollout_epsilon_greedy: epsilon of the epsilon-greedy policy used
                by ``rollout_step()``. Either a float or a ``Scheduler``; it is
                converted with ``as_scheduler()`` and evaluated at every
                rollout step.
            target_net_target_action: when ``True`` uses target critic network
                to get target action (similar as DDPG). When ``False``, uses
                critic network to get target action (similar as DDQN/SAC).
                Epsilon-greedy sampling always uses the critic network,
                regardless of this flag.
            num_critic_replicas: number of critics to be used. Default is 2.
            env: The environment to interact with. ``env`` is a batched
                environment, which means that it runs multiple simulations
                simultaneously. ``env`` only needs to be provided to the root
                algorithm.
            config: config for training. It only needs to be provided to the
                algorithm which performs ``train_iter()`` by itself.
            critic_loss_ctor: a critic loss constructor. If ``None``, a default
                ``OneStepTDLoss`` will be used.
            checkpoint (None|str): a string in the format of "prefix@path",
                where the "prefix" is the multi-step path to the contents in the
                checkpoint to be loaded. "path" is the full path to the
                checkpoint file saved by ALF. Refer to ``Algorithm`` for more
                details.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): The name of this algorithm.
        """
        # These attributes are set before ``super().__init__()``, as in the
        # original implementation.
        self._rollout_epsilon_greedy = as_scheduler(rollout_epsilon_greedy)
        self._target_net_target_action = target_net_target_action
        # Disable alpha learning:
        alpha_optimizer = AdamTF(lr=0)
        super().__init__(
            observation_spec=observation_spec,
            action_spec=action_spec,
            reward_spec=reward_spec,
            actor_network_cls=None,
            critic_network_cls=None,
            q_network_cls=q_network_cls,
            # Do not use entropy reward:
            use_entropy_reward=False,
            num_critic_replicas=num_critic_replicas,
            env=env,
            config=config,
            critic_loss_ctor=critic_loss_ctor,
            # Allow custom optimizer for q_network:
            critic_optimizer=q_optimizer,
            alpha_optimizer=alpha_optimizer,
            checkpoint=checkpoint,
            debug_summaries=debug_summaries,
            name=name)
        assert self._act_type == ActionType.Discrete

    # Copied and modified from sac_algorithm (discrete actions).
    def _predict_action(self,
                        observation,
                        state: DqnActionState,
                        epsilon_greedy=None,
                        eps_greedy_sampling=False):
        """Compute Q values and derive a discrete action from them.

        Args:
            observation: observation fed to the critic networks.
            state (DqnActionState): its ``critic`` field is passed to the
                critic networks as their state.
            epsilon_greedy (float|None): exploration rate. Required when
                ``eps_greedy_sampling`` is True; ignored otherwise.
            eps_greedy_sampling (bool): if True, sample from an epsilon-greedy
                distribution built on the critic networks (rollout or
                evaluation). If False, return the greedy action, computed from
                the target critic networks when
                ``target_net_target_action`` is True and from the critic
                networks otherwise.

        Returns:
            tuple:
            - action_dist (td.Categorical): distribution the action was
              sampled from.
            - action: the sampled action.
            - q_values: Q values the distribution was built from.
            - new_state (DqnActionState): state carrying the new critic state.

        Raises:
            ValueError: if ``eps_greedy_sampling`` is True but
                ``epsilon_greedy`` is None.
        """
        if eps_greedy_sampling and epsilon_greedy is None:
            # Fail early with a clear message instead of an opaque
            # ``TypeError`` from ``None / int`` further below.
            raise ValueError(
                "epsilon_greedy must be provided when eps_greedy_sampling "
                "is True")

        new_state = DqnActionState()
        critic_network_inputs = (observation, None)
        # NOTE: This block departs from SAC:
        if eps_greedy_sampling or not self._target_net_target_action:
            # SAC always uses critic_networks to obtain action,
            # even during training
            nets = self._critic_networks
        else:
            nets = self._target_critic_networks
        q_values, critic_state = self._compute_critics(
            nets, *critic_network_inputs, state.critic)
        new_state = new_state._replace(critic=critic_state)

        # NOTE: This block departs from SAC:
        size = q_values.shape  # [B, actions]
        if eps_greedy_sampling:
            # Epsilon greedy for rollout or evaluation: every action gets
            # probability epsilon / num_actions, the greedy one gets the rest.
            rand_act_prob = epsilon_greedy / size[-1]
            probs = torch.ones_like(q_values) * rand_act_prob
            if epsilon_greedy >= 1:
                # Uniform random action distribution
                greedy_act_prob = rand_act_prob
            else:
                # Epsilon greedy
                greedy_act_prob = 1 - epsilon_greedy + rand_act_prob
        else:
            # Greedy for train_step to obtain target value from target network.
            # The greedy action here is the maximizer of q values, and will be
            # used to obtain target value using the target network.
            probs = torch.zeros_like(q_values)
            greedy_act_prob = 1
        greedy_action = torch.argmax(q_values, dim=-1)
        # Keep the row index on the same device as the tensor being indexed.
        batch_index = torch.arange(size[0], device=q_values.device)
        probs[batch_index, greedy_action] = greedy_act_prob
        action_dist = td.Categorical(probs=probs)
        action = dist_utils.sample_action_distribution(action_dist)

        return action_dist, action, q_values, new_state

    # Copied and modified from sac_algorithm (discrete actions).
    def rollout_step(self, inputs: TimeStep, state: DqnState):
        """``rollout_step()`` basically predicts actions like what is done by
        ``predict_step()``. Additionally, if states are to be stored in a
        replay buffer, then this function also calls ``_critic_networks`` and
        ``_target_critic_networks`` to maintain their states.

        Args:
            inputs (TimeStep): current time step; only ``observation`` is used.
            state (DqnState): state from the previous step.

        Returns:
            AlgStep: ``output`` is the epsilon-greedy action, ``state`` is the
            updated ``DqnState`` and ``info`` is a ``DqnInfo`` holding the
            action and its distribution.
        """
        action_dist, action, _, action_state = self._predict_action(
            inputs.observation,
            state=state.action,
            # NOTE: This is the only departure from SAC.
            epsilon_greedy=self._rollout_epsilon_greedy(),
            eps_greedy_sampling=True)

        if self.need_full_rollout_state():
            _, critics_state = self._compute_critics(
                self._critic_networks, inputs.observation, action,
                state.critic.critics)
            _, target_critics_state = self._compute_critics(
                self._target_critic_networks, inputs.observation, action,
                state.critic.target_critics)
            critic_state = DqnCriticState(
                critics=critics_state, target_critics=target_critics_state)
            actor_state = ()
        else:
            actor_state = state.actor
            critic_state = state.critic

        new_state = DqnState(
            action=action_state, actor=actor_state, critic=critic_state)
        return AlgStep(
            output=action,
            state=new_state,
            info=DqnInfo(action=action, action_distribution=action_dist))

    def calc_loss(self, info: DqnInfo):
        """Compute the training loss, which consists of the critic loss only.

        Args:
            info (DqnInfo): information passed to ``_calc_critic_loss()``.

        Returns:
            LossInfo: ``loss`` and ``priority`` are taken from the critic
            loss; ``extra`` is a ``DqnLossInfo`` whose ``actor`` and ``alpha``
            fields are empty.
        """
        # Adapted from SAC: Removes irrelevant losses and logging.
        critic_loss = self._calc_critic_loss(info)
        return LossInfo(
            loss=critic_loss.loss,
            priority=critic_loss.priority,
            extra=DqnLossInfo(critic=critic_loss.extra, actor=(), alpha=()))