# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import functools
import torch
import alf
from alf.algorithms.rl_algorithm import RLAlgorithm
from alf.data_structures import (TimeStep, Experience, LossInfo, namedtuple,
AlgStep, StepType)
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
import alf.utils.common as common
GoalState = namedtuple("GoalState", ["goal"], default_value=())
GoalInfo = namedtuple("GoalInfo", ["goal", "loss"], default_value=())
[docs]@alf.configurable
class RandomCategoricalGoalGenerator(RLAlgorithm):
"""Random Goal Generation Module.
This module generates a random categorical goal for the agent
in the beginning of every episode.
"""
def __init__(self,
observation_spec,
num_of_goals,
name="RandomCategoricalGoalGenerator"):
"""
Args:
observation_spec (nested TensorSpec): representing the observations.
num_of_goals (int): total number of goals the agent can sample from.
name (str): name of the algorithm.
"""
goal_spec = TensorSpec((num_of_goals, ))
train_state_spec = GoalState(goal=goal_spec)
super().__init__(
observation_spec=observation_spec,
action_spec=BoundedTensorSpec(
shape=(num_of_goals, ),
dtype='float32',
minimum=0.,
maximum=1.),
train_state_spec=train_state_spec,
name=name)
self._num_of_goals = num_of_goals
def _generate_goal(self, observation, state):
"""Generate new goals.
Args:
observation (nested Tensor): the observation at the current time step.
state (nested Tensor): state of this goal generator.
Returns:
Tensor: a batch of one-hot goal tensors.
"""
batch_size = alf.nest.get_nest_batch_size(observation)
goals = torch.randint(
high=self._num_of_goals, size=(batch_size, ), dtype=torch.int64)
goals_onehot = torch.nn.functional.one_hot(
goals, self._num_of_goals).to(torch.float32)
return goals_onehot
def _update_goal(self, observation, state, step_type):
"""Update the goal if the episode just beginned; otherwise keep using
the goal in ``state``.
Args:
observation (nested Tensor): the observation at the current time step
state (nested Tensor): state of this goal generator
step_type (StepTyp):
Returns:
Tensor: a batch of one-hot tensors representing the updated goals.
"""
new_goal_mask = torch.unsqueeze((step_type == StepType.FIRST), dim=-1)
generated_goal = self._generate_goal(observation, state)
new_goal = torch.where(new_goal_mask, generated_goal, state.goal)
return new_goal
def _step(self, time_step: TimeStep, state):
"""Perform one step of rollout or prediction.
Note that as ``RandomCategoricalGoalGenerator`` is a non-trainable module,
and it will randomly generate goals for episode beginnings.
Args:
time_step (TimeStep): input time_step data.
state (nested Tensor): consistent with ``train_state_spec``.
Returns:
AlgStep:
- output (Tensor); one-hot goal vectors.
- state (nested Tensor):
- info (GoalInfo): storing any info that will be put into a replay
buffer (if off-policy training is used.
"""
observation = time_step.observation
step_type = time_step.step_type
new_goal = self._update_goal(observation, state, step_type)
return AlgStep(
output=(new_goal, ()),
state=GoalState(goal=new_goal),
info=GoalInfo(goal=new_goal))
[docs] def rollout_step(self, inputs: TimeStep, state):
return self._step(inputs, state)
[docs] def predict_step(self, inputs: TimeStep, state):
return self._step(inputs, state)
[docs] def train_step(self, inputs: TimeStep, state, rollout_info):
"""For off-policy training, the current output goal should be taken from
the goal in ``rollout_info`` (historical goals generated during rollout).
Note that we cannot take the goal from ``state`` and pass it down because
the first state might be a zero vector. And we also cannot resample
the goal online because that might be inconsistent with the sampled
experience trajectory.
Args:
inputs (TimeStep): the experience data.
state (nested Tensor):
rollout_info (GoalInfo):
Returns:
AlgStep:
- output (Tensor); one-hot goal vectors
- state (nested Tensor):
- info (GoalInfo): for training.
"""
goal = rollout_info.goal
return AlgStep(
output=(goal, ()), state=state, info=GoalInfo(goal=goal))
[docs] def calc_loss(self, info: GoalInfo):
return LossInfo()