# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimistic Actor Critic algorithm."""
import torch
import torch.distributions as td
import alf
from alf.algorithms.config import TrainerConfig
from alf.algorithms.sac_algorithm import SacAlgorithm, SacInfo, ActionType
from alf.algorithms.sac_algorithm import SacActionState, SacCriticState, SacState
from alf.data_structures import TimeStep
from alf.data_structures import AlgStep
from alf.nest import nest
import alf.nest.utils as nest_utils
from alf.networks import ActorDistributionNetwork, CriticNetwork
from alf.networks import QNetwork
from alf.networks.projection_networks import NormalProjectionNetwork
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import dist_utils
[docs]@alf.configurable
class OacAlgorithm(SacAlgorithm):
"""Optimistic Actor Critic algorithm, described in:
::
Ciosek et al "Better Exploration with Optimistic Actor-Critic", arXiv:1910.12807
"""
def __init__(self,
observation_spec,
action_spec: BoundedTensorSpec,
reward_spec=TensorSpec(()),
actor_network_cls=ActorDistributionNetwork,
critic_network_cls=CriticNetwork,
q_network_cls=QNetwork,
epsilon_greedy=None,
use_entropy_reward=True,
calculate_priority=False,
num_critic_replicas=2,
env=None,
config: TrainerConfig = None,
critic_loss_ctor=None,
target_entropy=None,
prior_actor_ctor=None,
target_kld_per_dim=3.,
initial_log_alpha=0.0,
explore=True,
explore_delta=6.8,
beta_ub=4.6,
max_log_alpha=None,
target_update_tau=0.05,
target_update_period=1,
dqda_clipping=None,
actor_optimizer=None,
critic_optimizer=None,
alpha_optimizer=None,
checkpoint=None,
debug_summaries=False,
name="OacAlgorithm"):
"""
Refer to SacAlgorithm for Args besides the following.
Args:
explore (bool): default is True for OAC algorithm, where only
continuous action space is supported. When 'explore' is False,
OAC is the same as SAC.
explore_delta (float): parameter controlling how optimistic in shifting
the mean of the target policy to get the mean of the explore policy.
beta_ub (float): parameter for computing the upperbound of Q value:
:math:`Q_ub(s,a) = \mu_Q(s,a) + \beta_ub * \sigma_Q(s,a)`
"""
super().__init__(
observation_spec,
action_spec,
reward_spec=reward_spec,
actor_network_cls=actor_network_cls,
critic_network_cls=critic_network_cls,
q_network_cls=q_network_cls,
epsilon_greedy=epsilon_greedy,
use_entropy_reward=use_entropy_reward,
calculate_priority=calculate_priority,
num_critic_replicas=num_critic_replicas,
env=env,
config=config,
critic_loss_ctor=critic_loss_ctor,
target_entropy=target_entropy,
prior_actor_ctor=prior_actor_ctor,
initial_log_alpha=initial_log_alpha,
max_log_alpha=max_log_alpha,
target_update_tau=target_update_tau,
target_update_period=target_update_period,
dqda_clipping=dqda_clipping,
actor_optimizer=actor_optimizer,
critic_optimizer=critic_optimizer,
alpha_optimizer=alpha_optimizer,
checkpoint=checkpoint,
debug_summaries=debug_summaries,
name=name)
if explore:
assert self._act_type == ActionType.Continuous, (
"Only continuous action space is supported for explore mode.")
self._explore = explore
self._explore_delta = explore_delta
self._beta_ub = beta_ub
def _predict_action(self,
observation,
state: SacActionState,
epsilon_greedy=None,
eps_greedy_sampling=False,
explore=False):
"""
Differences between SacAlgorithm._predict_action:
1. Only continuous actions are supported.
2. Add a switch for explore mode where OAC explore policy is constructed
from the target policy (actor_network) and used for action prediction.
"""
new_state = SacActionState()
action_dist, actor_network_state = self._actor_network(
observation, state=state.actor_network)
assert isinstance(action_dist, td.TransformedDistribution), (
"Squashed distribution is expected from actor_network.")
assert isinstance(
action_dist.base_dist, dist_utils.DiagMultivariateNormal
), ("the base distribution should be diagonal multivariate normal.")
normal_dist = action_dist.base_dist
unsquashed_mean = normal_dist.mean
unsquashed_std = normal_dist.stddev
unsquashed_var = normal_dist.variance
new_state = new_state._replace(actor_network=actor_network_state)
def mean_shift_fn(mu, dqda, sigma):
if self._dqda_clipping:
dqda = torch.clamp(dqda, -self._dqda_clipping,
self._dqda_clipping)
norm = torch.sqrt(torch.sum(torch.mul(dqda * dqda, sigma))) + 1e-6
shift = self._explore_delta * torch.mul(sigma, dqda) / norm
return mu + shift
if explore:
critic_action = normal_dist.mean.detach().clone()
critic_action.requires_grad = True
transformed_action = critic_action
with torch.enable_grad():
for transform in action_dist.transforms:
transformed_action = transform(transformed_action)
critics, critic_state = self._critic_networks(
(observation, transformed_action), state=state.critic)
new_state = new_state._replace(critic=critic_state)
if critics.ndim > 2:
critics = critics.squeeze()
assert critics.ndim == 2
q_mean = critics.mean(dim=1)
q_std = torch.abs(critics[:, 0] - critics[:, 1]) / 2.0
q_ub = q_mean + self._beta_ub * q_std
dqda = nest_utils.grad(critic_action, q_ub.sum())
shifted_mean = nest.map_structure(mean_shift_fn, unsquashed_mean,
dqda, unsquashed_var)
normal_dist = dist_utils.DiagMultivariateNormal(
loc=shifted_mean, scale=unsquashed_std)
action_dist = td.TransformedDistribution(
base_distribution=normal_dist,
transforms=action_dist.transforms)
action = dist_utils.rsample_action_distribution(action_dist)
else:
if eps_greedy_sampling:
action = dist_utils.epsilon_greedy_sample(
action_dist, epsilon_greedy)
else:
action = dist_utils.rsample_action_distribution(action_dist)
return action_dist, action, None, new_state
[docs] def rollout_step(self, inputs: TimeStep, state: SacState):
"""Same as SacAlgorithm.rollout_step except that `explore` is set to be
`self._explore` when calling `_predict_action`.
"""
action_dist, action, _, action_state = self._predict_action(
inputs.observation,
state=state.action,
epsilon_greedy=1.0,
eps_greedy_sampling=True,
explore=self._explore)
if self.need_full_rollout_state():
_, critics_state = self._compute_critics(
self._critic_networks, inputs.observation, action,
state.critic.critics)
_, target_critics_state = self._compute_critics(
self._target_critic_networks, inputs.observation, action,
state.critic.target_critics)
critic_state = SacCriticState(
critics=critics_state, target_critics=target_critics_state)
actor_state = critics_state
else:
actor_state = state.actor
critic_state = state.critic
new_state = SacState(
action=action_state, actor=actor_state, critic=critic_state)
return AlgStep(
output=action,
state=new_state,
info=SacInfo(action=action, action_distribution=action_dist))