Source code for alf.optimizers.trusted_updater

# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TrustedUpdater."""

from typing import Callable

import torch
import torch.nn as nn

import alf
from alf.utils import math_ops

nest_map = alf.nest.map_structure


[docs]class TrustedUpdater(object): """Adjust variables based on the change calculated by `change_f()` The motivation is that if some quatity changes too much after an SGD update, the SGD step might be too big. We want to shink that step so that the concerned quatity does not change too much. We can also monitor multiple quantities to make sure none of them has sudden big jump. It adjusts variables provided at `__init__` if the change calculated by `change_f` is too big: ``` change = change_f() if change > max_change: var <= old_var + 0.9 * (max_change/change) * (var - old_var) ``` The above procedure is repeated until `change` is not bigger than `max_change`. Note that `change` and `max_change` can be nests of scalars. In this case, the inequality is understood as if any one of the changes is greater than its corresponding max_change. """ def __init__(self, parameters): """Create a TrustedUpdater instance. Args: parameters (list[Parameter]): parameters to be monitored. """ self._variables = parameters assert len(self._variables) > 0 self._prev_variables = [ nn.Parameter(v.clone(), requires_grad=False) for v in parameters ]
[docs] def adjust_step(self, change_f: Callable, max_change): """Adjust `parameters` based change calculated by change_f This function will copy the new values of the variables to a backup to be used for the next call of adjust_step. Args: change_f (Callable): a function calculate a (nested) change based on current variable. max_change (float): (nested) max change allowed. Returns: the initial change before variables are adjusted the number of steps to adjust variables. 0 for no adjustment """ def _adjust_step(ratio): # 0.9 is to prevent infinite loop when `actual_change` is close to # `max_change` r = 0.9 / ratio for var, prev_var in zip(self._variables, self._prev_variables): var.data.copy_(prev_var + r * (var - prev_var)) steps = 0 change0 = change_f() change = change0 ratio = nest_map(lambda c, m: c.abs() / m, change, max_change) ratio = math_ops.max_n(alf.nest.flatten(ratio)) while ratio > 1. and steps < 100: _adjust_step(ratio) change = change_f() ratio = nest_map(lambda c, m: c.abs() / m, change, max_change) ratio = math_ops.max_n(alf.nest.flatten(ratio)) steps += 1 # This suggests something wrong. change cannot be reduced by making # the step smaller. assert steps < 100, ("Something is wrong. change cannot be reduced by " "making the step smaller.") for var, prev_var in zip(self._variables, self._prev_variables): prev_var.data.copy_(var) return change0, steps