Source code for alf.utils.spec_utils

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Collection of spec utility functions."""

import numpy as np
import torch
from typing import Iterable

import alf.nest as nest
from alf.nest.utils import get_outer_rank
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from . import dist_utils
from alf.utils.tensor_utils import BatchSquash


def spec_means_and_magnitudes(spec: BoundedTensorSpec):
    """Get the center and magnitude of the ranges for the input spec.

    Each bound is halved *before* the two are combined. Dividing by ``2.0``
    promotes NumPy/Python integer bounds to floating point, so the sum and
    difference can no longer wrap around in a narrow integer dtype (e.g.
    ``uint8`` bounds 100 and 200), and full-range floating point bounds no
    longer overflow to ``inf`` when they are added or subtracted.

    Args:
        spec (BoundedTensorSpec): the spec used to compute mean and magnitudes.
            Only its ``minimum``, ``maximum`` and ``dtype`` attributes are read.

    Returns:
        tuple:
        - spec_means (Tensor): the mean value of the spec bound, i.e.
          ``(maximum + minimum) / 2``, cast to ``spec.dtype``.
        - spec_magnitudes (Tensor): the magnitude of the spec bound, i.e.
          ``(maximum - minimum) / 2``, cast to ``spec.dtype``.
    """
    # Halving first keeps every intermediate value within range.
    half_max = spec.maximum / 2.0
    half_min = spec.minimum / 2.0
    spec_means = half_max + half_min
    spec_magnitudes = half_max - half_min
    return (torch.as_tensor(spec_means).to(spec.dtype),
            torch.as_tensor(spec_magnitudes).to(spec.dtype))
def scale_to_spec(tensor, spec: BoundedTensorSpec):
    """Shapes and scales a batch into the given spec bounds.

    The outer dimensions of ``tensor`` (as counted by ``get_outer_rank``) are
    flattened with ``BatchSquash``, the affine map
    ``means + magnitudes * tensor`` is applied, and the outer dimensions are
    restored afterwards.

    Args:
        tensor: A tensor with values in the range of [-1, 1].
        spec: (BoundedTensorSpec) to use for scaling the input tensor.

    Returns:
        A batch scaled to the given spec bounds.
    """
    squasher = BatchSquash(get_outer_rank(tensor, spec))
    flat = squasher.flatten(tensor)
    center, half_range = spec_means_and_magnitudes(spec)
    return squasher.unflatten(center + half_range * flat)
def clip_to_spec(value, spec: BoundedTensorSpec):
    """Clips value to a given bounded tensor spec.

    The bounds are materialized on ``value``'s device so that clipping also
    works when ``value`` does not live on the CPU. For floating point values
    the bounds are additionally cast to ``value``'s dtype, so the result keeps
    that dtype instead of being promoted (or rejected) because the bounds were
    stored with a different precision. Non-floating values keep the default
    type promotion so that fractional bounds are not truncated.

    Args:
        value: (tensor) value to be clipped.
        spec: (BoundedTensorSpec) spec containing min and max values for
            clipping. Only ``minimum`` and ``maximum`` are read.

    Returns:
        clipped_value: (tensor) `value` clipped to be compatible with `spec`.
    """
    upper = torch.as_tensor(spec.maximum, device=value.device)
    lower = torch.as_tensor(spec.minimum, device=value.device)
    if value.is_floating_point():
        upper = upper.to(value.dtype)
        lower = lower.to(value.dtype)
    return torch.max(torch.min(value, upper), lower)
def zeros_from_spec(nested_spec, batch_size):
    """Create nested zero Tensors or Distributions.

    A zero tensor with the leading shape given by ``batch_size`` is created for
    each TensorSpec, and a distribution with all of its parameters as zero
    Tensors is created for each DistributionSpec.

    Args:
        nested_spec (nested TensorSpec or DistributionSpec): the spec(s) to
            create zero values for.
        batch_size (int|tuple|list): batch size/shape added as the first
            dimension(s) to the shapes in TensorSpec. An iterable is passed to
            ``spec.zeros`` as is; anything else is wrapped in a one-element
            list.

    Returns:
        nested Tensor or Distribution
    """
    outer_shape = (batch_size
                   if isinstance(batch_size, Iterable) else [batch_size])
    param_spec = dist_utils.to_distribution_param_spec(nested_spec)
    zero_params = nest.map_structure(
        lambda spec: spec.zeros(outer_shape), param_spec)
    return dist_utils.params_to_distributions(zero_params, nested_spec)
def is_same_spec(spec1, spec2):
    """Whether two nested specs are same.

    The two nests are first checked with ``nest.assert_same_structure``; if
    that check raises, the specs are reported as different. Otherwise the
    specs are compared element-wise with ``==`` and all comparisons must hold.

    Args:
        spec1 (nested TensorSpec): the first spec
        spec2 (nested TensorSpec): the second spec

    Returns:
        bool
    """
    try:
        nest.assert_same_structure(spec1, spec2)
    except Exception:
        # Any ordinary error from the structure check means "not the same".
        # ``Exception`` (rather than a bare ``except``) lets KeyboardInterrupt
        # and SystemExit propagate instead of being turned into ``False``.
        return False
    same = nest.map_structure(lambda s1, s2: s1 == s2, spec1, spec2)
    return all(nest.flatten(same))