# Source code for alf.algorithms.ppg.ppg_utils

# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import annotations to enable type hints of PPGTrainInfo inside PPGTrainInfo
from __future__ import annotations
from typing import Optional

from alf.data_structures import namedtuple, TimeStep, AlgStep
from alf.utils import common, dist_utils

from .disjoint_policy_value_network import DisjointPolicyValueNetwork

# Data structure that stores the information produced by the agent while
# interacting with the environment. The field order matters: PPGTrainInfo
# reuses ``PPGRolloutInfo._fields`` as the prefix of its own field list.
# NOTE(review): ``default_value=()`` presumably makes every unset field an
# empty tuple -- confirm against ``alf.data_structures.namedtuple``.
PPGRolloutInfo = namedtuple(
    'PPGRolloutInfo',
    [
        # Action distribution produced by the policy head
        'action_distribution',
        # Action sampled from the action distribution produced by the policy
        # head
        'action',
        # Log probability of the action at the rollout time
        'log_prob',
        # Value function estimated by the value head
        'value',
        # Value function estimated by the auxiliary value head
        'aux',
        # ``step_type``, ``discount`` and ``reward`` are copied from the input
        # time step by ``ppg_network_forward()``.
        'step_type',
        'discount',
        'reward',
        # Set to ``()`` by ``ppg_network_forward()``.
        'reward_weights',
    ],
    default_value=())


class PPGTrainInfo(
        namedtuple(
            'PPGTrainInfo',
            PPGRolloutInfo._fields + (
                'rollout_action_distribution',
                'rollout_value',
                'rollout_log_prob',
            ),
            default_value=())):
    """Rollout information extended with fields derived during training.

    On top of all the fields of ``PPGRolloutInfo``, this structure carries
    ``rollout_action_distribution``, ``rollout_value`` and
    ``rollout_log_prob``. This extra information is derived during training
    updates and used across calls to ``train_step()``.

    It is kept as a separate class (as opposed to being merged into
    ``PPGRolloutInfo``) because we want to be explicit about which pieces of
    information are derived during training and which come directly from
    the rollout.
    """

    def absorbed(self, rollout_info: PPGRolloutInfo) -> PPGTrainInfo:
        """Return a copy of ``self`` refreshed with values from ``rollout_info``.

        ``self`` is not modified; a new ``PPGTrainInfo`` is created instead.

        In ``train_step()`` we want to keep the derived information stored in
        the ``PPGTrainInfo`` while updating most of the fields it shares with
        ``PPGRolloutInfo`` from an evaluation of the updated network. The
        fields taken from ``rollout_info`` are ``step_type``, ``reward``,
        ``discount``, ``action_distribution``, ``value``, ``aux`` and
        ``reward_weights``. All other fields (``action``, ``log_prob`` and
        the ``rollout_*`` fields) keep the values they have in ``self``.

        Args:
            rollout_info (PPGRolloutInfo): the result of rollout or
                evaluation that needs to be combined with ``self``

        Returns:
            A new PPGTrainInfo that combines the useful part from both
            parties.
        """
        refreshed_fields = (
            'step_type',
            'reward',
            'discount',
            'action_distribution',
            'value',
            'aux',
            'reward_weights',
        )
        updates = {
            name: getattr(rollout_info, name)
            for name in refreshed_fields
        }
        return self._replace(**updates)
def ppg_network_forward(network: DisjointPolicyValueNetwork,
                        inputs: TimeStep,
                        state,
                        require_aux: bool = True,
                        epsilon_greedy: Optional[float] = None) -> AlgStep:
    """Run the forward pass of the network for rollout or training.

    The signature mimics ``rollout_step()`` of ``Algorithm`` completely.

    Args:
        network: the network whose forward pass is to be performed.
        inputs: carries the observation that is needed as input to the
            network. Its ``step_type``, ``discount`` and ``reward`` are
            copied into the returned info.
        state (nested Tensor): carries the state for RNN-based network
        require_aux: whether to compute and return auxiliary estimation.
            See DisjointPolicyValueNetwork.forward() for details.
        epsilon_greedy: if set to None, the action will be sampled strictly
            based on the action distribution. If set to a value in [0, 1],
            epsilon-greedy sampling will be used to sample the action from
            the action distribution, and the float value determines the
            chance of action sampling instead of taking argmax.

    Returns:
        An ``AlgStep`` whose ``output`` is the selected action, whose
        ``state`` is the state returned by the network, and whose ``info``
        is a ``PPGRolloutInfo``. In the info, ``action`` and ``log_prob``
        are passed through ``common.detach()``, ``log_prob`` is ``()`` when
        epsilon-greedy sampling is used, and ``reward_weights`` is ``()``.
    """
    network_output, next_state = network(
        inputs.observation, state=state, require_aux=require_aux)
    action_distribution, value, aux = network_output

    if epsilon_greedy is None:
        # Plain sampling: the log probability comes with the sample.
        action, log_prob = dist_utils.sample_action_distribution(
            action_distribution, return_log_prob=True)
    else:
        # Epsilon-greedy sampling does not provide a log probability, so an
        # empty placeholder is stored instead.
        action = dist_utils.epsilon_greedy_sample(action_distribution,
                                                  epsilon_greedy)
        log_prob = ()

    rollout_info = PPGRolloutInfo(
        action_distribution=action_distribution,
        action=common.detach(action),
        log_prob=common.detach(log_prob),
        value=value,
        aux=aux,
        step_type=inputs.step_type,
        discount=inputs.discount,
        reward=inputs.reward,
        reward_weights=())
    return AlgStep(output=action, state=next_state, info=rollout_info)