# Source code for alf.optimizers.utils

# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from typing import Any

import alf
from alf.utils.averager import ScalarEMAverager
from alf.utils.tensor_utils import global_norm, clip_by_global_norm


def get_opt_arg(p: nn.Parameter, argname: str, default: Any = None):
    """Look up an optimizer argument attached to a single parameter.

    Per-parameter overrides are read from the ``opt_args`` attribute of
    ``p``, which is queried with ``.get()`` (i.e. it is used like a dict).

    Args:
        p: the parameter to query
        argname: name of the optimizer argument to look up
        default: the value to fall back to

    Returns:
        The value stored under ``argname`` in ``p.opt_args``. ``default`` is
        returned instead when ``p`` has no ``opt_args`` attribute, when that
        attribute is None, or when the entry is missing or None.
    """
    per_param_args = getattr(p, 'opt_args', None)
    override = (None if per_param_args is None else per_param_args.get(
        argname, None))
    # Only ``None`` counts as "not set"; falsy overrides such as 0 are kept.
    return override if override is not None else default
@alf.configurable
class GradientNoiseScaleEstimator(nn.Module):
    r"""Estimator of the simple Gradient Noise Scale (GNS).

    This follows Appendix A of "An Empirical Model of Large-Batch Training",
    McCandlish et al., `arXiv <https://arxiv.org/pdf/1812.06162.pdf>`_, 2018.
    The simplified GNS is defined as:

    .. math::

        B_{simple} = \frac{tr(\Sigma(\theta))}{|G(\theta)|^2},

    where :math:`\Sigma` is the per-sample covariance matrix

    .. math::

        \Sigma(\theta) = cov_{x\sim p} (\nabla_{\theta} L_x(\theta)),

    and :math:`G(\theta)` is the true gradient given the entire data
    distribution.

    Generally, GNS indicates the noise-to-signal value of SGD. The authors
    suggest choosing a batch size close to GNS in order to average out the
    noise in the gradient. In other words, GNS is positively correlated with
    the current gradient descent difficulty: a high GNS is expected for a
    difficult learning task, especially when different training samples
    generate opposite gradient directions.

    .. note::

        You can turn on this estimator in ``TrainerConfig``. However, this
        will increase the back-propagation overhead.

    Note that the *expectation* of the estimated GNS is in theory independent
    of the batch size, but it does depend on the learning rate. A good
    practice when using this estimator with a given learning rate is to make
    sure that:

    1. the learning rate is reasonable. If it is too large, GNS is unstable.
    2. the batch size is large enough (smaller variance), and
    3. the batch data can represent samples from the true data distribution.
       For example, if your batch is too large but the replay buffer is too
       small, then the estimate won't make sense (consider increasing the
       ``initial_collect_steps``).

    An alternative way of estimating GNS is also provided. Given the
    gradients of two sampled batches :math:`G_{est1}` and :math:`G_{est2}`,
    each computed from :math:`B` samples, we have

    .. math::

        \begin{array}{l}
            \alpha\triangleq \mathbb{E}[\langle G_{est1}, G_{est2}\rangle] = |G|^2 \\
            \beta\triangleq\mathbb{E}[\frac{|G_{est1}|^2 + |G_{est2}|^2}{2}]
                = \frac{1}{B}tr[\Sigma] + |G|^2 \\
        \end{array}

    Then we can maintain moving averages :math:`\bar{\alpha}` and
    :math:`\bar{\beta}`, and use
    :math:`(\frac{\bar{\beta}}{\bar{\alpha}}-1)B` as the estimated GNS.
    """

    def __init__(self,
                 batch_size_ratio: float = 0.1,
                 update_rate: float = 0.001,
                 gradient_norm_clip: float = None,
                 mode: str = "alternative",
                 name: str = "GNSEstimator"):
        """
        Args:
            batch_size_ratio: the portion of a batch to be used as a "smaller"
                batch. In theory, another smaller batch should be sampled
                *independently* from the data distribution. However, for
                simplicity, this estimator carves the smaller batch out of the
                given batch and uses the remainder as the larger batch. So
                this ratio should be small (<0.5). If the ratio is too small,
                the calculated smaller batch size is clipped at 1. This value
                is only used in "paper" mode; in "alternative" mode it is
                overridden with 0.5.
            update_rate: the update rate for computing moving averages of the
                quantities needed by GNS. Generally, a smaller value (slower
                update) makes the estimated GNS more biased (because
                quantities at different training steps are averaged) while a
                larger value (quicker update) gives it more variance.
            gradient_norm_clip: a clipping value for the global gradient norm.
                If None, no clipping is performed. Usually, a clipping value
                is required for a stable GNS estimate. Depending on how stable
                the estimated GNS is, this value could also suggest a clipping
                norm for the optimizer.
            mode: either "paper" or "alternative". "paper" uses the
                calculation in the paper. "alternative" is the default mode as
                its calculation is easier to understand.
            name: name of this estimator.
        """
        super().__init__()
        self._name = name
        assert mode in ("paper", "alternative")
        if mode != "paper":
            # The alternative estimator compares two equally sized halves.
            batch_size_ratio = 0.5
        else:
            assert 0 < batch_size_ratio < 0.5
        self._mode = mode
        self._batch_size_ratio = batch_size_ratio
        self._grad_norm_clip = gradient_norm_clip
        # Two independent moving averages: one for the (estimated) squared
        # gradient norm and one for the variance-trace-related quantity.
        self._gradient_norm_averager = ScalarEMAverager(
            update_rate=update_rate)
        self._var_trace_averager = ScalarEMAverager(update_rate=update_rate)
        # Fallback value returned whenever the current estimate is negative.
        self.register_buffer('_last_valid_gns', torch.zeros(()))

    def _calculate_gradient_norm(self, loss: torch.Tensor,
                                 tensors: alf.nest.NestedTensor):
        """Compute the gradient of the mean loss w.r.t. ``tensors``.

        Args:
            loss: per-sample losses; their mean is differentiated.
            tensors: a nest of tensors whose gradients are considered.

        Returns:
            tuple:
            - the squared global norm of the (optionally clipped) gradients
            - all (optionally clipped) gradients flattened and concatenated
              into a single rank-1 tensor
        """
        # The graph is retained because the caller differentiates a second
        # slice of the same loss afterwards.
        grad_nest = alf.nest.utils.grad(
            tensors, loss.mean(), retain_graph=True)
        clip_value = self._grad_norm_clip
        if clip_value is not None:
            grad_nest, _unused = clip_by_global_norm(grad_nest, clip_value)
        squared_norm = global_norm(grad_nest)**2
        flat_grads = torch.cat(
            [g.reshape(-1) for g in alf.nest.flatten(grad_nest)])
        return squared_norm, flat_grads

    def forward(self, loss: torch.Tensor, tensors: alf.nest.NestedTensor):
        """Given a loss tensor and a nest of tensors, return the estimated GNS.

        Args:
            loss: a loss tensor *before* taking the mean. Each entry of the
                tensor represents an individual loss on a single training
                sample. Ideally, the samples used for computing these losses
                should be sampled *with* replacement independently. The loss
                can have a shape of either ``[T,B]`` or ``[B]``. The estimate
                will be more stable if ``B`` is large and the batch represents
                samples from the data distribution well.
            tensors: a nest of tensors whose gradients are considered

        Returns:
            gns: the estimated gradient noise scale (a scalar). A smaller
            value means more effective grad steps.
        """
        assert loss.ndim in [1, 2], "loss must be a rank-1 or -2 tensor!"
        total = loss.shape[-1]
        # Shuffle along the batch dimension before splitting into two parts.
        permuted = loss[..., torch.randperm(total)]
        small = max(1, int(total * self._batch_size_ratio))
        large = total - small

        large_norm2, large_grads = self._calculate_gradient_norm(
            permuted[..., small:], tensors)
        small_norm2, small_grads = self._calculate_gradient_norm(
            permuted[..., :small], tensors)

        if self._mode != "paper":
            assert large == small, "Check if the batch size is even!"
            var_trace = (large_norm2 + small_norm2) / 2.
            # Inner product of the two half-batch gradients.
            gradient_norm = (large_grads * small_grads).sum()
        else:
            gradient_norm = (large * large_norm2 - small * small_norm2) / (
                large - small)
            var_trace = (small_norm2 - large_norm2) / (1. / small - 1. / large)

        avg_grad_norm = self._gradient_norm_averager.average(gradient_norm)
        avg_var_trace = self._var_trace_averager.average(var_trace)

        gns = avg_var_trace / (avg_grad_norm + 1e-8)
        if self._mode == "alternative":
            gns = (gns - 1) * large

        if not (gns < 0):
            self._last_valid_gns = torch.clone(gns)
        else:
            # In theory GNS should be non-negative. If the current estimate is
            # negative, then we simply reuse the last estimate.
            gns = self._last_valid_gns
        return gns.detach()