# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implicit Q-Learning Algorithm."""
import numpy as np
import functools
import torch
import alf
from alf.algorithms.config import TrainerConfig
from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
from alf.algorithms.one_step_loss import OneStepTDLoss
from alf.data_structures import TimeStep, LossInfo, namedtuple
from alf.data_structures import AlgStep, StepType
from alf.nest import nest
from alf.networks import ActorDistributionNetwork, CriticNetwork
from alf.networks import ValueNetwork
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import common, dist_utils, math_ops
# State of the action-prediction path. ``IqlAlgorithm.__init__`` builds its
# spec from the actor network's state spec and leaves ``critic`` empty;
# ``_predict_action()`` only ever fills ``actor_network``.
IqlActionState = namedtuple(
"IqlActionState", ["actor_network", "critic"], default_value=())
# States of the critic networks and of the target critic networks.
IqlCriticState = namedtuple("IqlCriticState", ["critics", "target_critics"])
# Complete algorithm state: ``action`` is an ``IqlActionState``, ``actor`` is
# the critic-network state consumed by the actor training step, and
# ``critic`` is an ``IqlCriticState``.
IqlState = namedtuple(
"IqlState", ["action", "actor", "critic"], default_value=())
# Output of ``_critic_train_step()``: the per-replica critic outputs
# (``critics``), the target critics' output for the dataset action
# (``target_value``) and the value network's output (``value``).
IqlCriticInfo = namedtuple("IqlCriticInfo",
["critics", "target_value", "value"])
# Extra information attached to the actor loss.
IqlActorInfo = namedtuple("IqlActorInfo", ["actor_loss"], default_value=())
# Info produced by ``rollout_step()`` (only ``action`` and
# ``action_distribution`` are filled) and by ``train_step()`` (all fields).
IqlInfo = namedtuple(
"IqlInfo", [
"reward", "step_type", "discount", "action", "action_distribution",
"actor", "critic"
],
default_value=())
# Per-component extras reported by ``calc_loss()``.
IqlLossInfo = namedtuple('IqlLossInfo', ('actor', 'critic'))
[docs]@alf.configurable
class IqlAlgorithm(OffPolicyAlgorithm):
    r"""Implicit Q-learning algorithm (IQL).

    IQL is an offline reinforcement learning method. Instead of constraining
    the critic network or the policy to avoid the value function
    extrapolation issue, IQL conducts learning using only in-sample data,
    thus avoiding the issues that arise when querying the critic network
    with out-of-distribution actions, a problem commonly faced in offline RL.

    Reference:

    ::

        Kostrikov, et al. "Offline Reinforcement Learning with Implicit Q-Learning",
        arXiv:2110.06169
    """
def __init__(self,
             observation_spec,
             action_spec: BoundedTensorSpec,
             reward_spec=TensorSpec(()),
             actor_network_cls=ActorDistributionNetwork,
             critic_network_cls=CriticNetwork,
             v_network_cls=ValueNetwork,
             reward_weights=None,
             epsilon_greedy=None,
             calculate_priority=False,
             num_critic_replicas=2,
             env=None,
             config: TrainerConfig = None,
             critic_loss_ctor=None,
             target_update_tau=0.05,
             target_update_period=1,
             temperature=1.0,
             actor_optimizer=None,
             critic_optimizer=None,
             value_optimizer=None,
             expectile=0.8,
             max_exp_advantage=100,
             checkpoint=None,
             debug_summaries=False,
             name="IqlAlgorithm"):
    """Build the networks, optimizers, losses and target updater of IQL.

    Args:
        observation_spec (nested TensorSpec): representing the observations.
        action_spec (BoundedTensorSpec): representing the actions. Only
            continuous action is supported currently.
        reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec
            representing the reward(s).
        actor_network_cls (Callable): is used to construct the actor
            network. The constructed actor network will be called to
            sample continuous actions. All of its output specs must be
            continuous. Discrete actor network is not supported.
        critic_network_cls (Callable): is used to construct the critic
            network.
        v_network_cls (Callable): is used to construct a value network
            for estimating the expectile of q values.
        reward_weights (None|list[float]): this is only used when the
            reward is multidimensional. In that case, the weighted sum of
            the q values is used for training the actor if
            ``reward_weights`` is not None. Otherwise, the sum of the q
            values is used.
        epsilon_greedy (float): a floating value in [0,1], representing
            the chance of action sampling instead of taking argmax. This
            can help prevent a dead loop in some deterministic environment
            like Breakout. Only used for evaluation. If None, its value is
            taken from ``config.epsilon_greedy`` and then
            ``alf.get_config_value(TrainerConfig.epsilon_greedy)``.
        calculate_priority (bool): whether to calculate priority. This is
            only useful if priority replay is enabled.
        num_critic_replicas (int): number of critics to be used. Default
            is 2. This is only applied for critic networks. The value
            network is not replicated.
        env (Environment): The environment to interact with. ``env`` is a
            batched environment, which means that it runs multiple
            simulations simultaneously. ``env`` only needs to be provided
            to the root algorithm.
        config (TrainerConfig): config for training. It only needs to be
            provided to the algorithm which performs ``train_iter()`` by
            itself.
        critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
            constructor. If ``None``, a default ``OneStepTDLoss`` will be
            used.
        target_update_tau (float): Factor for soft update of the target
            networks.
        target_update_period (int): Period for soft update of the target
            networks.
        temperature (float): the hyper-parameter for scaling the
            advantages. It corresponds to 1/beta in Eqn.(7) of the paper.
        actor_optimizer (torch.optim.optimizer): The optimizer for actor.
        critic_optimizer (torch.optim.optimizer): The optimizer for
            critic.
        value_optimizer (torch.optim.optimizer): The optimizer for value
            network.
        expectile (float): the expectile value for value learning.
        max_exp_advantage (float): clamp the exponentiated advantages with
            this value before being applied to weight the actor loss.
        checkpoint (None|str): a string in the format of "prefix@path",
            where the "prefix" is the multi-step path to the contents in
            the checkpoint to be loaded. "path" is the full path to the
            checkpoint file saved by ALF. Refer to ``Algorithm`` for more
            details.
        debug_summaries (bool): True if debug summaries should be created.
        name (str): The name of this algorithm.
    """
    # These plain attributes are needed before the parent constructor runs:
    # ``_make_networks()`` reads ``_num_critic_replicas``.
    self._calculate_priority = calculate_priority
    self._num_critic_replicas = num_critic_replicas
    self._epsilon_greedy = (
        alf.utils.common.get_epsilon_greedy(config)
        if epsilon_greedy is None else epsilon_greedy)

    critic_networks, actor_network, v_network = self._make_networks(
        observation_spec, action_spec, reward_spec, actor_network_cls,
        critic_network_cls, v_network_cls)

    action_state_spec = IqlActionState(
        actor_network=actor_network.state_spec, critic=())
    train_state_spec = IqlState(
        action=action_state_spec,
        actor=critic_networks.state_spec,
        critic=IqlCriticState(
            critics=critic_networks.state_spec,
            target_critics=critic_networks.state_spec))
    predict_state_spec = IqlState(action=action_state_spec)

    super().__init__(
        observation_spec=observation_spec,
        action_spec=action_spec,
        reward_spec=reward_spec,
        train_state_spec=train_state_spec,
        predict_state_spec=predict_state_spec,
        reward_weights=reward_weights,
        env=env,
        config=config,
        checkpoint=checkpoint,
        debug_summaries=debug_summaries,
        name=name)

    # Each optimizer is optional; register only the ones that were given.
    if not (actor_optimizer is None or actor_network is None):
        self.add_optimizer(actor_optimizer, [actor_network])
    if critic_optimizer is not None:
        self.add_optimizer(critic_optimizer, [critic_networks])
    if value_optimizer is not None:
        self.add_optimizer(value_optimizer, [v_network])

    self._temperature = temperature

    self._actor_network = actor_network
    self._critic_networks = critic_networks
    self._target_critic_networks = self._critic_networks.copy(
        name='target_critic_networks')
    self._v_network = v_network

    make_critic_loss = functools.partial(
        OneStepTDLoss if critic_loss_ctor is None else critic_loss_ctor,
        debug_summaries=debug_summaries)
    # One loss object per replica; distinct names keep their summary
    # curves apart.
    self._critic_losses = [
        make_critic_loss(name="critic_loss%d" % replica_id)
        for replica_id in range(1, num_critic_replicas + 1)
    ]

    self._update_target = common.TargetUpdater(
        models=[self._critic_networks],
        target_models=[self._target_critic_networks],
        tau=target_update_tau,
        period=target_update_period)

    self._expectile = expectile
    self._max_exp_advantage = max_exp_advantage
def _make_networks(self, observation_spec, action_spec, reward_spec,
                   continuous_actor_network_cls, critic_network_cls,
                   v_network_cls):
    """Construct the actor, the parallel critics and the value network.

    Args:
        observation_spec: spec passed as ``input_tensor_spec`` to the actor
            and value network constructors, and as the first element of
            the critic's ``input_tensor_spec``.
        action_spec: spec passed as ``action_spec`` to the actor
            constructor and as the second element of the critic's
            ``input_tensor_spec``.
        reward_spec: only its ``numel`` is read; it scales the number of
            parallel critics.
        continuous_actor_network_cls (Callable): actor network constructor.
        critic_network_cls (Callable): critic network constructor.
        v_network_cls (Callable): value network constructor.

    Returns:
        tuple: ``(critic_networks, actor_network, v_network)``, where
        ``critic_networks`` is the result of ``make_parallel()`` on a
        single critic with ``num_critic_replicas * reward_spec.numel``
        replicas.
    """
    actor_network = continuous_actor_network_cls(
        input_tensor_spec=observation_spec, action_spec=action_spec)

    critic_network = critic_network_cls(
        input_tensor_spec=(observation_spec, action_spec))
    # One parallel replica per (critic replica, reward dimension) pair.
    critic_networks = critic_network.make_parallel(
        self._num_critic_replicas * reward_spec.numel)

    v_network = v_network_cls(input_tensor_spec=observation_spec)

    return critic_networks, actor_network, v_network
def _predict_action(self,
                    observation,
                    state: IqlActionState,
                    epsilon_greedy=None,
                    eps_greedy_sampling=False,
                    rollout=False):
    """Run the actor network and draw an action from its output.

    Args:
        observation: input forwarded to the actor network.
        state (IqlActionState): its ``actor_network`` field is passed to
            the actor network as state.
        epsilon_greedy: forwarded to ``dist_utils.epsilon_greedy_sample``
            when ``eps_greedy_sampling`` is True; otherwise ignored.
        eps_greedy_sampling (bool): if True, sample with
            ``dist_utils.epsilon_greedy_sample``; otherwise use
            ``dist_utils.rsample_action_distribution``.
        rollout (bool): accepted but not used by this implementation.

    Returns:
        tuple: ``(action_distribution, action, new_state)`` where
        ``new_state`` is an ``IqlActionState`` carrying the actor
        network's new state.
    """
    distribution, network_state = self._actor_network(
        observation, state=state.actor_network)

    if eps_greedy_sampling:
        sampled_action = dist_utils.epsilon_greedy_sample(
            distribution, epsilon_greedy)
    else:
        sampled_action = dist_utils.rsample_action_distribution(
            distribution)

    next_state = IqlActionState()._replace(actor_network=network_state)
    return distribution, sampled_action, next_state
def predict_step(self, inputs: TimeStep, state: IqlState):
    """Predict an action for evaluation.

    The action is drawn with epsilon-greedy sampling using the
    ``epsilon_greedy`` value stored at construction time.

    Args:
        inputs (TimeStep): only ``inputs.observation`` is read.
        state (IqlState): only ``state.action`` is read.

    Returns:
        AlgStep: ``output`` is the action; ``state`` is an ``IqlState``
        with only its ``action`` field set.
    """
    _unused_dist, chosen_action, next_action_state = self._predict_action(
        inputs.observation,
        state=state.action,
        epsilon_greedy=self._epsilon_greedy,
        eps_greedy_sampling=True)
    next_state = IqlState(action=next_action_state)
    return AlgStep(output=chosen_action, state=next_state)
def rollout_step(self, inputs: TimeStep, state: IqlState):
    """Predict an action for rollout and keep network states up to date.

    ``rollout_step()`` basically predicts actions like what is done by
    ``predict_step()`` (here with ``epsilon_greedy=1.0``). Additionally,
    if states are to be stored in a replay buffer
    (``need_full_rollout_state()`` is true), this function also calls
    ``_critic_networks`` and ``_target_critic_networks`` to maintain their
    states; otherwise the incoming actor and critic states are passed
    through unchanged.

    Args:
        inputs (TimeStep): only ``inputs.observation`` is read.
        state (IqlState): the current algorithm state.

    Returns:
        AlgStep: ``output`` is the action, ``state`` is the new
        ``IqlState`` and ``info`` is an ``IqlInfo`` with ``action`` and
        ``action_distribution`` filled in.
    """
    # ``_predict_action()`` returns exactly three values
    # ``(action_dist, action, new_state)``; unpacking four raised
    # ``ValueError`` on every call.
    action_dist, action, action_state = self._predict_action(
        inputs.observation,
        state=state.action,
        epsilon_greedy=1.0,
        eps_greedy_sampling=True,
        rollout=True)

    if self.need_full_rollout_state():
        _, critics_state = self._compute_critics(
            self._critic_networks, inputs.observation, action,
            state.critic.critics)
        _, target_critics_state = self._compute_critics(
            self._target_critic_networks, inputs.observation, action,
            state.critic.target_critics)
        critic_state = IqlCriticState(
            critics=critics_state, target_critics=target_critics_state)
        # The actor train step consumes critic-network state.
        actor_state = critics_state
    else:
        actor_state = state.actor
        critic_state = state.critic

    new_state = IqlState(
        action=action_state, actor=actor_state, critic=critic_state)
    return AlgStep(
        output=action,
        state=new_state,
        info=IqlInfo(action=action, action_distribution=action_dist))
def _compute_critics(self,
                     critic_net,
                     observation,
                     action,
                     critics_state,
                     replica_min=True,
                     apply_reward_weights=True):
    """Evaluate a (parallel) critic network on observation-action pairs.

    Args:
        critic_net (Callable): called as
            ``critic_net((observation, action), state=critics_state)`` and
            expected to return ``(critics, new_state)``.
        observation: first element of the critic input.
        action: second element of the critic input.
        critics_state: state passed to ``critic_net``.
        replica_min (bool): if True, reduce over the replica dimension
            (dim 1) with a minimum. For multi-dimensional rewards the
            minimum is taken after multiplying by the sign of
            ``reward_weights`` and the sign is multiplied back afterwards.
        apply_reward_weights (bool): if True and the reward is
            multi-dimensional, multiply by ``reward_weights`` and sum over
            the last dimension.

    Returns:
        tuple: ``(critics, new_state)``.
    """
    # The original author documents the raw output shape as
    # [B, replicas] (or [B, replicas * reward_dim]).
    q_values, new_state = critic_net((observation, action),
                                     state=critics_state)

    multidim = self.has_multidim_reward()

    if multidim:
        # [B, replicas * reward_dim, ...] -> [B, replicas, reward_dim, ...]
        trailing_shape = q_values.shape[2:]
        q_values = q_values.reshape(-1, self._num_critic_replicas,
                                    *self._reward_spec.shape,
                                    *trailing_shape)

    if replica_min and multidim:
        weight_sign = self.reward_weights.sign()
        q_values = (q_values * weight_sign).min(dim=1)[0] * weight_sign
    elif replica_min:
        q_values = q_values.min(dim=1)[0]

    if multidim and apply_reward_weights:
        q_values = (q_values * self.reward_weights).sum(dim=-1)

    return q_values, new_state
def _actor_train_step(self, inputs: TimeStep, state, action_distribution,
                      v_value, rollout_info):
    """Compute the advantage-weighted log-likelihood loss for the actor.

    Args:
        inputs (TimeStep): only ``inputs.observation`` is read.
        state: state passed to the target critic networks.
        action_distribution: distribution whose log-probability of the
            dataset action is evaluated.
        v_value: state value subtracted from the target-critic Q value.
        rollout_info: only ``rollout_info.action`` (the dataset action) is
            read.

    Returns:
        tuple: ``(critics_state, actor_info)`` where ``actor_info`` is a
        ``LossInfo`` whose ``loss`` and ``extra.actor_loss`` are the
        weighted negative log-probability.
    """
    dataset_action = rollout_info.action

    # The Q value comes from the target critic networks.
    q_value, critics_state = self._compute_critics(
        self._target_critic_networks, inputs.observation, dataset_action,
        state)

    # exp(advantage / temperature), clamped from above.
    advantage = q_value - v_value
    exp_advantage = torch.clamp(
        torch.exp(advantage / self._temperature),
        max=self._max_exp_advantage)

    # Log-probability of the dataset action under the current policy.
    log_pi_data = dist_utils.compute_log_probability(
        action_distribution, dataset_action)

    # The weight is detached, so gradients only flow through the
    # log-probability term.
    actor_loss = -exp_advantage.detach() * log_pi_data

    return critics_state, LossInfo(
        loss=actor_loss, extra=IqlActorInfo(actor_loss=actor_loss))
def _critic_train_step(self, inputs: TimeStep, state: IqlCriticState,
                       rollout_info: IqlInfo):
    """Evaluate critics, value network and target critics on dataset data.

    Args:
        inputs (TimeStep): only ``inputs.observation`` is read.
        state (IqlCriticState): states for the critics and target critics.
        rollout_info (IqlInfo): only ``rollout_info.action`` (the dataset
            action) is read.

    Returns:
        tuple: ``(new_state, info)`` where ``new_state`` is an
        ``IqlCriticState`` and ``info`` is an ``IqlCriticInfo`` holding
        the per-replica critic outputs, the target critics' output and
        the value network's output.
    """
    observation = inputs.observation
    dataset_action = rollout_info.action

    # Per-replica Q values for the dataset action (no min over replicas,
    # no reward weighting).
    q_values, running_state = self._compute_critics(
        self._critic_networks,
        observation,
        dataset_action,
        state.critics,
        replica_min=False,
        apply_reward_weights=False)

    # The state value is used in two places:
    # 1) as the target value for Q-learning;
    # 2) for training the value network with an expectile loss against the
    #    target critics' prediction.
    # There is no target value network and the value network is not
    # replicated.
    # NOTE(review): the critic networks' output state is fed into the value
    # network and the value network's output state is what ends up in
    # ``critics`` below -- confirm this is intended for stateful networks.
    v_out, running_state = self._v_network(
        observation, state=running_state)
    v_value = v_out.squeeze(-1)

    # Target critics on the same dataset state-action pair, reduced over
    # replicas by the default ``replica_min=True``.
    target_q, target_state = self._compute_critics(
        self._target_critic_networks,
        observation,
        dataset_action,
        state.target_critics,
        apply_reward_weights=False)

    new_state = IqlCriticState(
        critics=running_state, target_critics=target_state)
    critic_info = IqlCriticInfo(
        critics=q_values, target_value=target_q, value=v_value)
    return new_state, critic_info
def train_step(self, inputs: TimeStep, state: IqlState,
               rollout_info: IqlInfo):
    """Run one training step for the actor and the critics.

    Args:
        inputs (TimeStep): the training time step; its ``observation``,
            ``reward``, ``step_type`` and ``discount`` are read.
        state (IqlState): the current algorithm state.
        rollout_info (IqlInfo): its ``action`` field is used as the
            dataset action.

    Returns:
        AlgStep: the freshly sampled action as output, the new
        ``IqlState`` and an ``IqlInfo`` carrying everything
        ``calc_loss()`` needs.
    """
    self._training_started = True

    policy_dist, sampled_action, next_action_state = self._predict_action(
        inputs.observation, state=state.action)

    next_critic_state, critic_info = self._critic_train_step(
        inputs, state.critic, rollout_info)

    # The value estimate from the critic step is used as the baseline for
    # the actor's advantage weighting.
    next_actor_state, actor_loss_info = self._actor_train_step(
        inputs, state.actor, policy_dist, critic_info.value, rollout_info)

    next_state = IqlState(
        action=next_action_state,
        actor=next_actor_state,
        critic=next_critic_state)
    step_info = IqlInfo(
        reward=inputs.reward,
        step_type=inputs.step_type,
        discount=inputs.discount,
        action=rollout_info.action,
        action_distribution=policy_dist,
        actor=actor_loss_info,
        critic=critic_info)
    return AlgStep(sampled_action, next_state, step_info)
def after_update(self, root_inputs, info: IqlInfo):
    """Invoke the target-network updater after a parameter update.

    Args:
        root_inputs: unused.
        info (IqlInfo): unused.
    """
    refresh_targets = self._update_target
    refresh_targets()
def calc_loss(self, info: IqlInfo):
    """Combine the actor loss and the critic loss into one ``LossInfo``.

    Args:
        info (IqlInfo): info collected by ``train_step()``; ``info.actor``
            is the actor's ``LossInfo``.

    Returns:
        LossInfo: ``loss`` is the actor loss plus the critic loss (added
        with ``math_ops.add_ignore_empty``), ``priority`` comes from the
        critic loss and ``extra`` is an ``IqlLossInfo`` of both extras.
    """
    critic_part = self._calc_critic_loss(info)
    actor_part = info.actor
    total_loss = math_ops.add_ignore_empty(actor_part.loss,
                                           critic_part.loss)
    extras = IqlLossInfo(actor=actor_part.extra, critic=critic_part.extra)
    return LossInfo(
        loss=total_loss, priority=critic_part.priority, extra=extras)
def _calc_critic_loss(self, info: IqlInfo):
    """Compute the combined critic (Q) loss and value (V) loss.

    Each critic replica is trained with its own loss object, using the
    detached value-network output as ``target_value``. The value network
    is trained with an expectile loss on the difference between the
    detached target-critic output and its own prediction.

    Args:
        info (IqlInfo): info collected by ``train_step()``; ``info.critic``
            and (when priority is enabled) ``info.step_type`` are read.

    Returns:
        LossInfo: ``loss`` is the sum of the per-replica critic losses
        plus the value loss, ``priority`` is ``()`` unless
        ``calculate_priority`` was set, and ``extra`` is the loss divided
        by the number of critic replicas.
    """

    def _expectile_loss(delta, tau):
        # Squared error weighted by ``tau`` where ``delta`` is positive
        # and by ``1 - tau`` elsewhere.
        coef = torch.where(delta > 0, tau, (1 - tau))
        return coef * (delta**2)

    critic_info = info.critic

    # ``critic_info.value`` (detached) is used for constructing the
    # target value of every replica.
    replica_losses = [
        loss_fn(
            info=info,
            value=critic_info.critics[:, :, replica, ...],
            target_value=critic_info.value.detach()).loss
        for replica, loss_fn in enumerate(self._critic_losses)
    ]
    critic_loss = math_ops.add_n(replica_losses)

    # q_target_critic_min - v
    value_loss = _expectile_loss(
        critic_info.target_value.detach() - critic_info.value,
        self._expectile)

    if self._debug_summaries and alf.summary.should_record_summaries():
        with alf.summary.scope(self._name):
            alf.summary.scalar("value_loss", value_loss.mean())

    critic_loss = critic_loss + value_loss

    priority = ()
    if self._calculate_priority:
        # Average the loss over the steps that are not ``LAST`` (along
        # dim 0), guarding against a zero count.
        valid_masks = (info.step_type != StepType.LAST).to(torch.float32)
        valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
        masked_sum = (critic_loss * valid_masks).sum(dim=0)
        priority = (masked_sum / valid_n).sqrt()

    return LossInfo(
        loss=critic_loss,
        priority=priority,
        extra=critic_loss / float(self._num_critic_replicas))
def _trainable_attributes_to_ignore(self):
    """Return the attribute names to ignore as trainable attributes.

    Returns:
        list[str]: a list containing only the target critic networks'
        attribute name.
    """
    ignored = ['_target_critic_networks']
    return ignored