Source code for alf.algorithms.particle_vi_algorithm

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A generic generator."""

import numpy as np
import torch

import alf
from alf.algorithms.algorithm import Algorithm
from alf.algorithms.mi_estimator import MIEstimator
from alf.algorithms.generator import CriticAlgorithm
from alf.data_structures import AlgStep, LossInfo, namedtuple
import alf.nest as nest
from alf.networks import Network, EncodingNetwork
from alf.tensor_specs import TensorSpec
from alf.utils import common, math_ops
from alf.utils.averager import AdaptiveAverager


@alf.configurable
class ParVIAlgorithm(Algorithm):
    """ParVIAlgorithm

    ParVIAlgorithm maintains a set of particles that keep chasing some target
    distribution. The following particle-based variational inference (par_vi)
    methods are implemented:

        1. Stein Variational Gradient Descent (SVGD):

           Liu, Qiang, and Dilin Wang. "Stein Variational Gradient Descent:
           A General Purpose Bayesian Inference Algorithm." NIPS. 2016.

        2. Wasserstein Particle-based VI with Smooth Functions (GFSF):

           Liu, Chang, et al. "Understanding and accelerating particle-based
           variational inference." International Conference on Machine
           Learning. 2019.

        3. ``minmax``: minmax SVGD (Fisher Neural Sampler), which trains a
           critic network to produce the particle update direction.

    When ``par_vi`` is ``None`` the particles simply follow the gradient of
    the loss, without any repulsive term.
    """

    def __init__(self,
                 particle_dim,
                 num_particles=10,
                 entropy_regularization=1.,
                 par_vi="gfsf",
                 critic_input_dim=None,
                 critic_hidden_layers=(100, 100),
                 critic_l2_weight=10.,
                 critic_iter_num=2,
                 critic_use_bn=True,
                 critic_optimizer=None,
                 optimizer=None,
                 debug_summaries=False,
                 name="ParVIAlgorithm"):
        r"""Create a ParVIAlgorithm.

        Args:
            particle_dim (int): dimension of the particles.
            num_particles (int): number of particles.
            entropy_regularization (float): weight of the repulsive term in
                par_vi.
            par_vi (string): par_vi methods, options are
                [``svgd``, ``gfsf``, ``minmax``, ``None``],

                * svgd: empirical expectation of SVGD is evaluated by reusing
                  the same batch of particles.
                * gfsf: wasserstein gradient flow with smoothed functions. It
                  involves a kernel matrix inversion, so computationally more
                  expensive, but in some cases the convergence seems faster
                  than svgd approaches.
                * minmax: the update direction is given by a critic network
                  that is trained for ``critic_iter_num`` steps in every
                  ``train_step``.
                * None: particles follow the plain gradient of the loss; the
                  repulsive term is not used.
            critic_input_dim (int): dimension of critic input, used for
                ``minmax``. Defaults to ``particle_dim`` if None.
            critic_hidden_layers (tuple): sizes of hidden layers of the
                critic, used for ``minmax``.
            critic_l2_weight (float): weight of L2 regularization in training
                the critic, used for ``minmax``.
            critic_iter_num (int): number of critic updates for each generator
                train_step, used for ``minmax``.
            critic_use_bn (bool): whether use batch norm for each layers of
                the critic, used for ``minmax``.
            critic_optimizer (torch.optim.Optimizer): Optimizer for training
                the critic, used for ``minmax``. If None, an Adam optimizer
                with learning rate 1e-3 is created.
            optimizer (torch.optim.Optimizer): (optional) optimizer for
                training
            debug_summaries (bool): forwarded to the base ``Algorithm``.
            name (str): name of this algorithm

        Raises:
            ValueError: if ``par_vi`` is not one of the supported options.
        """
        super().__init__(
            optimizer=optimizer,
            debug_summaries=debug_summaries,
            name=name)
        self._particle_dim = particle_dim
        self._num_particles = num_particles
        self._entropy_regularization = entropy_regularization
        self._particles = None
        self._par_vi = par_vi
        if par_vi == 'gfsf':
            self._grad_func = self._gfsf_grad
        elif par_vi == 'svgd':
            self._grad_func = self._svgd_grad
        elif par_vi == 'minmax':
            self._grad_func = self._minmax_grad
            if critic_input_dim is None:
                critic_input_dim = particle_dim
            self._critic_iter_num = critic_iter_num
            self._critic_l2_weight = critic_l2_weight
            if critic_optimizer is None:
                critic_optimizer = alf.optimizers.Adam(lr=1e-3)
            self._critic = CriticAlgorithm(
                TensorSpec(shape=(critic_input_dim, )),
                hidden_layers=critic_hidden_layers,
                use_bn=critic_use_bn,
                optimizer=critic_optimizer)
        elif par_vi is None:
            self._grad_func = self._ml_grad
        else:
            raise ValueError("Unsupported par_vi method: %s" % par_vi)

        self._kernel_width_averager = AdaptiveAverager(
            tensor_spec=TensorSpec(shape=()))

        # The particles are the trainable parameters of this algorithm.
        self._particles = torch.nn.Parameter(
            torch.randn(num_particles, particle_dim, requires_grad=True))

    @property
    def num_particles(self):
        """Number of particles given at construction."""
        return self._num_particles

    @property
    def particles(self):
        """The particle parameter of shape ``[num_particles, particle_dim]``."""
        return self._particles

    def predict_step(self, state=None):
        """Return the current particles.

        Args:
            state: not used

        Returns:
            AlgStep:
            - output (Tensor): the particles, with shape
              ``[num_particles, particle_dim]``
            - state: not used
        """
        return AlgStep(output=self.particles, state=(), info=())

    def train_step(self,
                   loss_func,
                   transform_func=None,
                   entropy_regularization=None,
                   loss_mask=None,
                   state=None):
        """Compute the surrogate loss whose gradient is the par_vi update.

        Args:
            loss_func (Callable): loss_func(loss_inputs) returns a Tensor or
                namedtuple of tensors with field `loss`, which is a Tensor of
                shape [num_particles] a loss term for optimizing the generator.
            transform_func (Callable): transform function on particles. Used
                in function value based par_vi, where each particle represents
                parameters of a neural network function. It is called by
                transform_func(particles) which returns the following,

                * outputs: outputs of network parameterized by particles
                  evaluated on predefined training batch.
                * extra_outputs: outputs of network parameterized by particles
                  evaluated on additional sampled data.
            entropy_regularization (float): weight of the repulsive term in
                par_vi. If None, use self._entropy_regularization.
            loss_mask (Tensor): mask indicating which samples are valid for
                loss propagation. If given, it is multiplied into the
                propagated loss.
            state: not used

        Returns:
            AlgStep:
            - output (Tensor): the particles, shape is ``[num_particles, dim]``
            - state: not used
            - info (LossInfo): ``loss`` is the propagated (surrogate) loss and
              ``extra`` is the raw return value of ``loss_func``.
        """
        if entropy_regularization is None:
            entropy_regularization = self._entropy_regularization
        loss, loss_propagated = self._grad_func(
            self.particles, loss_func, entropy_regularization, transform_func)
        if loss_mask is not None:
            loss_propagated = loss_propagated * loss_mask

        return AlgStep(
            output=self.particles,
            state=(),
            info=LossInfo(loss=loss_propagated, extra=loss))

    @staticmethod
    def _neglogp_from_loss(loss):
        """Return the loss tensor, reading ``.loss`` if ``loss`` is a tuple."""
        if isinstance(loss, tuple):
            return loss.loss
        return loss

    @staticmethod
    def _median_bandwidth(dist_sq, num_particles):
        """Kernel bandwidth: median of ``dist_sq`` over ``max(log(n), 1)``."""
        h, _ = torch.median(dist_sq.view(-1), dim=0)
        if h == 0.:
            # All particles coincide; use 1 to avoid dividing by zero below.
            return torch.ones_like(h)
        return h / max(np.log(num_particles), 1.)

    def _kernel_width(self, dist):
        """Update kernel_width averager and get latest kernel_width.

        Args:
            dist (Tensor): 1-D tensor, or a 2-D tensor that is summed over
                its last dimension first.

        Returns:
            the current value of the kernel width averager.
        """
        if dist.ndim > 1:
            dist = torch.sum(dist, dim=-1)
            assert dist.ndim == 1, "dist must have dimension 1 or 2."
        width, _ = torch.median(dist, dim=0)
        width = width / np.log(len(dist))
        self._kernel_width_averager.update(width)

        return self._kernel_width_averager.get()

    def _rbf_func(self, x, y=None):
        r"""
        Compute the rbf kernel and its gradient w.r.t. first entry
        :math:`K(x, y), \nabla_x K(x, y)`, used by svgd_grad.

        Args:
            x (Tensor): set of N particles, shape (Nx x W), where W is the
                dimension of each particle
            y (Tensor): set of N particles, shape (Ny x W), where W is the
                dimension of each particle. If y is None, treat y=x.

        Returns:
            :math:`K(x, y)` (Tensor): the RBF kernel of shape (Nx x Ny)
            :math:`\nabla_x K(x, y)` (Tensor): the derivative of RBF kernel
                of shape (Nx x Ny x W)
        """
        Nx, Dx = x.shape
        if y is None:
            y = x
        else:
            _, Dy = y.shape
            assert Dx == Dy
        diff = x.unsqueeze(1) - y.unsqueeze(0)  # [Nx, Ny, W]
        dist_sq = torch.sum(diff**2, -1)  # [Nx, Ny]
        h = self._median_bandwidth(dist_sq, Nx)

        kappa = torch.exp(-dist_sq / h)  # [Nx, Ny]
        kappa_grad = -2 * kappa.unsqueeze(-1) * diff / h  # [Nx, Ny, W]
        return kappa, kappa_grad

    def _score_func(self, x, alpha=1e-5):
        r"""
        Compute the stein estimator of the score function
        :math:`\nabla\log q = -(K + \alpha I)^{-1}\nabla K`,
        used by gfsf_grad.

        Args:
            x (Tensor): set of N particles, shape (N x D), where D is the
                dimension of each particle
            alpha (float): weight of regularization for inverse kernel
                this parameter turns out to be crucial for convergence.

        Returns:
            :math:`\nabla\log q` (Tensor): the score function of shape (N x D)
        """
        N = x.shape[0]
        diff = x.unsqueeze(1) - x.unsqueeze(0)  # [N, N, D]
        dist_sq = torch.sum(diff**2, -1)  # [N, N]
        h = self._median_bandwidth(dist_sq, N)

        kappa = torch.exp(-dist_sq / h)  # [N, N]
        # Build the identity on x's device/dtype so this also works when the
        # particles do not live on the default device.
        identity = torch.eye(N, dtype=x.dtype, device=x.device)
        kappa_inv = torch.inverse(kappa + alpha * identity)  # [N, N]
        kappa_grad = -2 * kappa.unsqueeze(-1) * diff / h  # [N, N, D]
        kappa_grad = kappa_grad.sum(0)  # [N, D]

        return kappa_inv @ kappa_grad

    def _ml_grad(self,
                 particles,
                 loss_func,
                 entropy_regularization=None,
                 transform_func=None):
        """Compute particle gradients from the loss only (no repulsive term).

        Args:
            particles (Tensor): the particles.
            loss_func (Callable): see ``train_step``.
            entropy_regularization: not used.
            transform_func (Callable): see ``train_step``. Only its first
                output is used here; any further outputs are ignored.

        Returns:
            tuple:
            - the raw return value of ``loss_func``
            - loss_propagated (Tensor): surrogate loss whose gradient w.r.t.
              the particles is the update direction.
        """
        if transform_func is not None:
            # Tolerate any number of extra outputs: only the first is needed.
            particles, *_ = transform_func(particles)
        loss = loss_func(particles)
        neglogp = self._neglogp_from_loss(loss)
        grad = torch.autograd.grad(neglogp.sum(), particles)[0]
        loss_propagated = torch.sum(grad.detach() * particles, dim=-1)

        return loss, loss_propagated

    def _svgd_grad(self,
                   particles,
                   loss_func,
                   entropy_regularization,
                   transform_func=None):
        """
        Compute particle gradients via SVGD, empirical expectation
        evaluated using the all particles.

        Args:
            particles (Tensor): the particles.
            loss_func (Callable): see ``train_step``.
            entropy_regularization (float): weight of the repulsive term.
            transform_func (Callable): see ``train_step``. If given, the
                kernel is evaluated on the concatenation of its two outputs.

        Returns:
            tuple:
            - the raw return value of ``loss_func``
            - loss_propagated (Tensor): surrogate loss whose gradient w.r.t.
              the particles is the SVGD update direction.
        """
        if transform_func is not None:
            particles, extra_particles = transform_func(particles)
            aug_particles = torch.cat([particles, extra_particles], dim=-1)
        else:
            aug_particles = particles
        loss = loss_func(particles)
        neglogp = self._neglogp_from_loss(loss)
        loss_grad = torch.autograd.grad(neglogp.sum(), particles)[0]  # [N, D]

        # [N, N], [N, N, D]
        kernel_weight, kernel_grad = self._rbf_func(aug_particles.detach())
        kernel_logp = torch.matmul(kernel_weight,
                                   loss_grad) / (self.num_particles)  # [N, D]

        # Gradients are detached and multiplied by the particles so that
        # backprop through the returned loss yields exactly these gradients.
        loss_prop_kernel_logp = torch.sum(
            kernel_logp.detach() * particles, dim=-1)
        loss_prop_kernel_grad = torch.sum(
            -entropy_regularization * kernel_grad.mean(0).detach() *
            aug_particles,
            dim=-1)
        loss_propagated = loss_prop_kernel_logp + loss_prop_kernel_grad

        return loss, loss_propagated

    def _gfsf_grad(self,
                   particles,
                   loss_func,
                   entropy_regularization,
                   transform_func=None):
        """Compute particle gradients via GFSF (Stein estimator).

        Args:
            particles (Tensor): the particles.
            loss_func (Callable): see ``train_step``.
            entropy_regularization (float): weight of the repulsive term.
            transform_func (Callable): see ``train_step``. If given, the
                score function is evaluated on the concatenation of its two
                outputs.

        Returns:
            tuple:
            - the raw return value of ``loss_func``
            - loss_propagated (Tensor): surrogate loss whose gradient w.r.t.
              the particles is the GFSF update direction.
        """
        if transform_func is not None:
            particles, extra_particles = transform_func(particles)
            aug_particles = torch.cat([particles, extra_particles], dim=-1)
        else:
            aug_particles = particles
        score_inputs = aug_particles.detach()
        loss = loss_func(particles)
        neglogp = self._neglogp_from_loss(loss)
        loss_grad = torch.autograd.grad(neglogp.sum(), particles)[0]  # [N, D]
        logq_grad = self._score_func(score_inputs) * entropy_regularization

        loss_prop_neglogp = torch.sum(loss_grad.detach() * particles, dim=-1)
        loss_prop_logq = torch.sum(
            -logq_grad.detach() * aug_particles, dim=-1)
        loss_propagated = loss_prop_neglogp + loss_prop_logq

        return loss, loss_propagated

    def _jacobian_trace(self, fx, x):
        """Hutchinson's trace Jacobian estimator O(1) call to autograd,
        used by ``minmax`` method.

        Args:
            fx (Tensor): function values computed from ``x``; its last
                dimension must match that of ``x``.
            x (Tensor): inputs that ``fx`` was computed from.

        Returns:
            Tensor: one (random) trace estimate per row of ``x``.
        """
        assert fx.shape[-1] == x.shape[-1], (
            "Jacobian is not square, no trace defined.")
        eps = torch.randn_like(fx)
        # create_graph=True keeps the estimate differentiable so it can be
        # used inside the critic loss.
        jvp = torch.autograd.grad(
            fx, x, grad_outputs=eps, retain_graph=True, create_graph=True)[0]
        tr_jvp = torch.einsum('bi,bi->b', jvp, eps)

        return tr_jvp

    def _critic_train_step(self, inputs, loss_func, entropy_regularization=1.):
        """
        Compute the loss for critic training.

        Args:
            inputs (Tensor): critic inputs; must require grad because both the
                loss and the critic outputs are differentiated w.r.t. them.
            loss_func (Callable): see ``train_step``.
            entropy_regularization (float): weight of the Jacobian trace term.

        Returns:
            Tensor: scalar critic loss (mean Stein term plus L2 penalty on
            the critic outputs).
        """
        loss = loss_func(inputs)
        neglogp = self._neglogp_from_loss(loss)
        loss_grad = torch.autograd.grad(neglogp.sum(), inputs)[0]  # [N, D]
        outputs = self._critic.predict_step(inputs).output
        tr_gradf = self._jacobian_trace(outputs, inputs)  # [N]

        f_loss_grad = (loss_grad.detach() * outputs).sum(1)  # [N]
        loss_stein = f_loss_grad - entropy_regularization * tr_gradf  # [N]

        l2_penalty = (outputs * outputs).sum(1).mean() * self._critic_l2_weight
        critic_loss = loss_stein.mean() + l2_penalty

        return critic_loss

    def _minmax_grad(self,
                     particles,
                     loss_func,
                     entropy_regularization,
                     transform_func=None):
        """
        Compute particle gradients via minmax svgd (Fisher Neural Sampler).

        Args:
            particles (Tensor): the particles.
            loss_func (Callable): see ``train_step``.
            entropy_regularization (float): weight of the Jacobian trace term
                in the critic loss.
            transform_func (Callable): see ``train_step``. Only its first
                output is used here.

        Returns:
            tuple:
            - the raw return value of ``loss_func`` on the detached inputs
            - loss_propagated (Tensor): surrogate loss whose gradient w.r.t.
              the particles is the negated critic output.
        """
        if transform_func is not None:
            aug_particles, _ = transform_func(particles)
        else:
            aug_particles = particles

        # Train the critic on fresh detached copies of the particles.
        for _ in range(self._critic_iter_num):
            critic_inputs = aug_particles.detach().clone()
            critic_inputs.requires_grad = True
            critic_loss = self._critic_train_step(critic_inputs, loss_func,
                                                  entropy_regularization)
            self._critic.update_with_gradient(LossInfo(loss=critic_loss))

        loss = loss_func(aug_particles.detach())
        critic_outputs = self._critic.predict_step(
            aug_particles.detach()).output
        loss_propagated = torch.sum(
            -critic_outputs.detach() * aug_particles, dim=-1)

        return loss, loss_propagated