Source code for alf.algorithms.decoding_algorithm

# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoding algorithm."""

import torch

import alf
from alf.algorithms.algorithm import Algorithm
from alf.data_structures import AlgStep, LossInfo
from alf.networks import Network
from alf.utils.math_ops import sum_to_leftmost


[docs]@alf.configurable class DecodingAlgorithm(Algorithm): """Generic decoding algorithm.""" def __init__(self, decoder: Network, loss=torch.nn.MSELoss(reduction='none'), loss_weight=1.0, name="DecodingAlgorithm"): """ Args: decoder (Network): network for decoding target from input. loss (Callable): loss function with signature ``loss(y_pred, y_true)``. Note that it should not reduce to a scalar. It should at least keep the batch dimension in the returned loss. loss_weight (float): weight for the loss. """ super(DecodingAlgorithm, self).__init__( train_state_spec=decoder.state_spec, name=name) self._decoder = decoder self._loss = loss self._loss_weight = loss_weight
[docs] def train_step(self, inputs, state=(), rollout_info=None): """Train one step. Args: inputs (tuple): tuple of (input, target) state (nested Tensor): network state for ``decoder`` Returns: AlgStep: - output: decoding result - state: rnn state from ``decoder`` - info: loss of decoding """ input, target = inputs pred, state = self._decoder(input, state=state) assert pred.shape == target.shape loss = self._loss(pred, target) assert loss.ndim > 0, "`loss` should return a tensor with batch dimension" # reduce to (B,) loss = sum_to_leftmost(loss, 1) return AlgStep( output=pred, state=state, info=LossInfo(loss=self._loss_weight * loss))