Source code for alf.algorithms.generator

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A generic generator."""

import functools
import numpy as np
import torch
from torch.autograd.functional import jacobian

import alf
from alf.algorithms.algorithm import Algorithm
from alf.algorithms.mi_estimator import MIEstimator
from alf.data_structures import AlgStep, LossInfo, namedtuple
import alf.nest as nest
from alf.networks import Network, EncodingNetwork, ReluMLP
from alf.tensor_specs import TensorSpec
from alf.utils import common, math_ops
from alf.utils.averager import AdaptiveAverager

GeneratorLossInfo = namedtuple("GeneratorLossInfo",
                               ["generator", "mi_estimator", "inverse_mvp"])


[docs]@alf.configurable class CriticAlgorithm(Algorithm): """ Wrap a critic network as an Algorithm for flexible gradient updates called by the Generator when par_vi is 'minmax'. """ def __init__(self, input_tensor_spec, output_dim=None, hidden_layers=(3, 3), activation=torch.relu_, net: Network = None, use_relu_mlp=False, use_bn=True, optimizer=None, name="CriticAlgorithm"): """Create a CriticAlgorithm. Args: input_tensor_spec (TensorSpec): spec of inputs. output_dim (int): dimension of output, default value is input_dim. hidden_layers (tuple): size of hidden layers. activation (Callable): activation used for all critic layers. net (Network): network for predicting outputs from inputs. If None, a default one with hidden_layers will be created use_relu_mlp (bool): whether use ReluMLP as default net constrctor. Diagonals of Jacobian can be explicitly computed for ReluMLP. use_bn (bool): whether use batch norm for each critic layers. optimizer (torch.optim.Optimizer): (optional) optimizer for training. name (str): name of this CriticAlgorithm. """ if optimizer is None: optimizer = alf.optimizers.Adam(lr=1e-3) super().__init__(train_state_spec=(), optimizer=optimizer, name=name) self._use_relu_mlp = use_relu_mlp self._output_dim = output_dim if output_dim is None: self._output_dim = input_tensor_spec.shape[0] if net is None: if use_relu_mlp: net = ReluMLP( input_tensor_spec=input_tensor_spec, hidden_layers=hidden_layers) else: net = EncodingNetwork( input_tensor_spec=input_tensor_spec, fc_layer_params=hidden_layers, use_fc_bn=use_bn, activation=activation, last_layer_size=self._output_dim, last_activation=math_ops.identity, last_use_fc_bn=use_bn, name='Critic') self._net = net
[docs] def reset_net_parameters(self): for fc in self._net._fc_layers: fc.reset_parameters()
[docs] def predict_step(self, inputs, state=None, requires_jac_diag=False): """Predict for one step of inputs. Args: inputs (Tensor): inputs for prediction. state: not used. requires_jac_trace (bool): whether outputs diagonals of Jacobian. Returns: AlgStep: - output (Tensor): predictions or (predictions, diag_jacobian) if requires_jac_diag is True. - state: not used. """ if self._use_relu_mlp: outputs = self._net(inputs, requires_jac_diag=requires_jac_diag)[0] else: outputs = self._net(inputs)[0] return AlgStep(output=outputs, state=(), info=())
[docs]@alf.configurable class InverseMVPAlgorithm(Algorithm): r"""InverseMVP network Algorithm Maintain an encoding network that takes (z, vec) as input and predicts a matrix-vector product (mvp) of the form :math:`y=J^{-1}(z)*vec`, where :math:`J^{-1}(z)` is the inverse of the Jacobian matrix of some function :math:`f(z)`, and ``vec`` is a vector. This network is used in GPVI in computing the ``functional_gradient`` of the generator, where :math:`J^{-1}` is the inverse of the Jacobian of the generator function w.r.t. input noise :math:`z'`, and ``vec`` is the gradient of the kernel :math:`\nabla_{z'}k(z', z)`. Training of this network is done outside of the algorithm, where the network is trained to predict :math:`y` that minimize the objective :math:`||Jy - vec||^2. """ def __init__(self, input_dim, output_dim, hidden_size=100, num_hidden_layers=1, activation=torch.relu_, optimizer=None, name="InverseMVPAlgorithm"): r"""Create a InverseMVPAlgorithm. Args: input_dim (int): dimension of input z output_dim (int): output dimension, i.e., dimension of the mvp hidden_size (int): width of hidden layers num_hidden_layers (int): number of hidden layers after activation (Callable): activation used for all hidden layers. optimizer (torch.optim.Optimizer): (optional) optimizer for training. name (str): name of this Algorithm. """ assert input_dim <= output_dim if optimizer is None: optimizer = alf.optimizers.Adam(lr=1e-3) super().__init__(train_state_spec=(), optimizer=optimizer, name=name) kernel_initializer = functools.partial( alf.initializers.variance_scaling_init, gain=1.0 / 2.0, mode='fan_in', distribution='truncated_normal', nonlinearity=math_ops.identity) self._z_dim = input_dim self._vec_dim = output_dim z_spec = TensorSpec(shape=(self._z_dim, )) vec_spec = TensorSpec(shape=(self._vec_dim, )) self._net = EncodingNetwork( (z_spec, vec_spec), input_preprocessors=(torch.nn.Linear(self._z_dim, hidden_size), torch.nn.Linear(self._vec_dim, hidden_size)), preprocessing_combiner=alf.layers.NestConcat(), fc_layer_params=(2 * hidden_size, ) * num_hidden_layers, activation=activation, kernel_initializer=kernel_initializer, last_layer_size=output_dim, last_activation=math_ops.identity, name='InverseMVPNetwork')
[docs] def predict_step(self, inputs, state=None): """Predict for one step of inputs. Args: inputs (tuple of Tensors): inputs (z, vec) for prediction. - z (Tensor): of size [N2, K] or [N2, D], representing :math:`z'`, where K is self._z_dim and D is self._vec_dim. - vec (Tensor): of size [N2, D] or [N2, N, D], representing :math:`\nabla_{z'}k(z', z)` in GPVI. state: not used. Returns: AlgStep: - output (tuple of Tensors): predictions of InverseMVP network and the z_inputs, which is [:, :K] of z. - state: not used. """ z_inputs, vec = inputs assert z_inputs.ndim == 2 and z_inputs.shape[-1] >= self._z_dim assert vec.shape[-1] >= self._vec_dim assert z_inputs.shape[0] == vec.shape[0] if z_inputs.shape[-1] > self._z_dim: z_inputs = z_inputs[:, :self._z_dim] # [N2, K] if vec.ndim == 2: vec_inputs = vec elif vec.ndim == 3: # [N2, N, D] z_inputs = torch.repeat_interleave( z_inputs, vec.shape[1], dim=0) # [N2*N, K] vec_inputs = vec.reshape(vec.shape[0] * vec.shape[1], -1) # [N2*N, D] else: raise ValueError( "vec must be dimension 2 or 3, got dimension {}".format( vec.ndim)) if vec_inputs.shape[-1] > self._vec_dim: vec_inputs = vec_inputs[:, :self._vec_dim] outputs = (self._net((z_inputs, vec_inputs))[0], z_inputs) return AlgStep(output=outputs, state=(), info=())
[docs]@alf.configurable class Generator(Algorithm): r"""Generator Generator generates outputs given `inputs` (can be None) by transforming a random noise and input using `net`: .. code-block:: python outputs = net([noise, input]) if input is not None else net(noise) The generator is trained to minimize the following objective: :math:`E(loss\_func(net([noise, input]))) - entropy\_regulariztion \cdot H(P)` where P is the (conditional) distribution of outputs given the inputs implied by `net` and H(P) is the (conditional) entropy of P. If the loss is the (unnormalized) negative log probability of some distribution Q and the ``entropy_regularization`` is 1, this objective is equivalent to minimizing :math:`KL(P||Q)`. It uses two different ways to optimize `net` depending on ``entropy_regularization``: * ``entropy_regularization`` = 0: the minimization is achieved by simply minimizing loss_func(net([noise, inputs])) * entropy_regularization > 0: the minimization is achieved using amortized particle-based variational inference (ParVI), in particular, four ParVI methods are implemented: 1. amortized Stein Variational Gradient Descent (SVGD): Feng et al "Learning to Draw Samples with Amortized Stein Variational Gradient Descent" https://arxiv.org/pdf/1707.06626.pdf 2. amortized Wasserstein ParVI with Smooth Functions (GFSF): Liu, Chang, et al. "Understanding and accelerating particle-based variational inference." International Conference on Machine Learning. 2019. 3. amortized Fisher Neural Sampler with Hutchinson's estimator (MINMAX): Hu et at. "Stein Neural Sampler." https://arxiv.org/abs/1810.03545, 2018. 4. generative particle-based variational inference (GPVI) If ``functional_gradient`` is set to True, then GPVI is used. Ratzlaff, Bai, et al. "Generative Particle Variational Inference via Estimation of Functional Gradients." International Conference on Machine Learning. 2021. It also supports an additional optional objective of maximizing the mutual information between [noise, inputs] and outputs by using mi_estimator to prevent mode collapse. This might be useful for ``entropy_regulariztion`` = 0 as suggested in section 5.1 of the following paper: Hjelm et al `Learning Deep Representations by Mutual Information Estimation and Maximization <https://arxiv.org/pdf/1808.06670.pdf>` """ def __init__(self, output_dim, noise_dim=32, input_tensor_spec=None, hidden_layers=(256, ), net: Network = None, net_moving_average_rate=None, entropy_regularization=0., mi_weight=None, mi_estimator_cls=MIEstimator, par_vi=None, use_kernel_averager=False, functional_gradient=False, init_lambda=1., lambda_trainable=False, block_inverse_mvp=False, direct_jac_inverse=False, inverse_mvp_solve_iters=1, inverse_mvp_hidden_size=100, inverse_mvp_hidden_layers=1, critic_input_dim=None, critic_hidden_layers=(100, 100), critic_l2_weight=10., critic_iter_num=2, critic_relu_mlp=False, critic_use_bn=True, minmax_resample=True, critic_optimizer=None, inverse_mvp_optimizer=None, optimizer=None, lambda_optimizer=None, name="Generator"): r"""Create a Generator. Args: output_dim (int): dimension of output noise_dim (int): dimension of noise input_tensor_spec (nested TensorSpec): spec of inputs. If there is no inputs, this should be None. hidden_layers (tuple): sizes of hidden layers. net (Network): network for generating outputs from [noise, inputs] or noise (if inputs is None). If None, a default one with hidden_layers will be created net_moving_average_rate (float): If provided, use a moving average version of net to do prediction. This has been shown to be effective for GAN training (arXiv:1907.02544, arXiv:1812.04948). entropy_regularization (float): weight of entropy regularization. mi_weight (float): weight of mutual information loss. mi_estimator_cls (type): the class of mutual information estimator for maximizing the mutual information between [noise, inputs] and [outputs, inputs]. par_vi (string): ParVI methods, options are [``svgd``, ``svgd2``, ``svgd3``, ``gfsf``, ``minmax``], * svgd: empirical expectation of SVGD is evaluated by a single resampled particle. The main benefit of this choice is it supports conditional case, while all other options do not. * svgd2: empirical expectation of SVGD is evaluated by splitting half of the sampled batch. It is a trade-off between computational efficiency and convergence speed. * svgd3: empirical expectation of SVGD is evaluated by resampled particles of the same batch size. It has better convergence but involves resampling, so less efficient computaionally comparing with svgd2. * gfsf: wasserstein gradient flow with smoothed functions. It involves a kernel matrix inversion, so computationally most expensive, but in some case the convergence seems faster than svgd approaches. * minmax: Fisher Neural Sampler, optimal descent direction of the Stein discrepancy is solved by an inner optimization procedure in the space of L2 neural networks. use_kernel_averager (bool): whether or not to use a running average of the kernel bandwith for ParVI methods. functional_gradient (bool): whether or not to optimize the generator with GPVI. When True, the dimension of the jacobian of the generator function needs to be square -- therefore invertible. When the generator is not sqaure, we ensure this by sampling an input noise vector of the same size as the output, and only forwarding the first ``noise_dim`` components. We then add the full noise vector to the output, multiplied by the ``fullrank_diag_weight``. init_lambda (float): weight on direct input-output link added to the generator output. Only used for GPVI and GPVI_Plus when forcing full rank Jacobian. lambda_trainable (bool): whether to train ``lambda``. block_inverse_mvp(bool): whether to use the more efficient block form for inverse_mvp when ``functional_gradient`` is True. This option is recommended only when ``noise_dim`` < ``output_dim``. as it is equivalent to the default form when ``noise_dim`` is equal to ``output_dim``. inverse_mvp_solve_iters (int): number of iterations of inverse_mvp network training per single iteration of generator training. inverse_mvp_hidden_size (int): width of hidden layers in inverse_mvp network. inverse_mvp_hidden_layers (int): number of hidden layers in inverse_mvp network. critic_input_dim (int): dimension of critic input, used for ``minmax``. critic_hidden_layers (tuple): sizes of hidden layers of the critic, used for ``minmax``. critic_l2_weight (float): weight of L2 regularization in training the critic, used for ``minmax``. critic_iter_num (int): number of critic updates for each generator train_step, used for ``minmax``. critic_relu_mlp (bool): whether use ReluMLP as the critic constructor, used for ``minmax``. critic_use_bn (book): whether use batch norm for each layers of the critic, used for ``minmax``. minmax_resample (bool): whether resample the generator for each critic update, used for ``minmax``. critic_optimizer (torch.optim.Optimizer): Optimizer for training the critic, used for ``minmax``. inverse_mvp_optimizer (torch.optim.Optimizer): Optimizer for training the inverse_mvp network, used when ``functional_gradient`` is True. optimizer (torch.optim.Optimizer): (optional) optimizer for training lambda_optimizer (torch.optim.Optimizer): Optimizer for training the ``lambda``, used for GPVI and GPVI_Plus when ``lambda_trainable`` is True. name (str): name of this generator """ super().__init__(train_state_spec=(), optimizer=optimizer, name=name) self._output_dim = output_dim self._noise_dim = noise_dim self._entropy_regularization = entropy_regularization self._functional_gradient = functional_gradient self._par_vi = par_vi self._direct_jac_inverse = direct_jac_inverse if entropy_regularization == 0: self._grad_func = self._ml_grad else: if par_vi == 'gfsf': self._grad_func = self._gfsf_grad elif par_vi == 'svgd': self._grad_func = self._svgd_grad elif par_vi == 'svgd2': self._grad_func = self._svgd_grad2 elif par_vi == 'svgd3': self._grad_func = self._svgd_grad3 elif par_vi == 'minmax': if critic_input_dim is None: critic_input_dim = output_dim self._grad_func = self._minmax_grad self._critic_iter_num = critic_iter_num self._critic_l2_weight = critic_l2_weight self._critic_relu_mlp = critic_relu_mlp self._minmax_resample = minmax_resample self._critic = CriticAlgorithm( TensorSpec(shape=(critic_input_dim, )), hidden_layers=critic_hidden_layers, use_relu_mlp=critic_relu_mlp, use_bn=critic_use_bn, optimizer=critic_optimizer) else: raise ValueError("Unsupported par_vi method: %s" % par_vi) if functional_gradient: if net is not None: assert isinstance(net, ReluMLP), ( "only ReluMLP generator is supported for functional_gradient." ) if noise_dim == output_dim: force_fullrank = False block_inverse_mvp = False else: assert noise_dim < output_dim force_fullrank = True self._grad_func = self._rkhs_func_grad self._force_fullrank = force_fullrank init_lambda = float(init_lambda) assert init_lambda > 0, "init_lambda has to be positive!" if lambda_trainable: self._log_lambda = torch.nn.Parameter( torch.tensor(np.log(init_lambda))) if lambda_optimizer is None: lambda_optimizer = alf.optimizers.Adam(lr=1e-3) self.add_optimizer(lambda_optimizer, nest.flatten(self._log_lambda)) else: self._fixed_lambda = init_lambda self._lambda_trainable = lambda_trainable self._block_inverse_mvp = block_inverse_mvp if not direct_jac_inverse: self._inverse_mvp_solve_iters = inverse_mvp_solve_iters if inverse_mvp_optimizer is None: inverse_mvp_optimizer = alf.optimizers.Adam( lr=1e-4, weight_decay=1e-5) if block_inverse_mvp: inverse_mvp_output_dim = noise_dim else: inverse_mvp_output_dim = output_dim self._inverse_mvp = InverseMVPAlgorithm( noise_dim, inverse_mvp_output_dim, hidden_size=inverse_mvp_hidden_size, num_hidden_layers=inverse_mvp_hidden_layers, optimizer=inverse_mvp_optimizer) if use_kernel_averager: self._kernel_width_averager = AdaptiveAverager( tensor_spec=TensorSpec(shape=())) else: self._kernel_width_averager = None noise_spec = TensorSpec(shape=(noise_dim, )) if net is None: net_input_spec = noise_spec if functional_gradient: net = ReluMLP( net_input_spec, output_size=output_dim, hidden_layers=hidden_layers, name='Generator') else: if input_tensor_spec is not None: net_input_spec = [net_input_spec, input_tensor_spec] net = EncodingNetwork( input_tensor_spec=net_input_spec, fc_layer_params=hidden_layers, last_layer_size=output_dim, last_activation=math_ops.identity, name="Generator") self._mi_estimator = None self._input_tensor_spec = input_tensor_spec if mi_weight is not None: x_spec = noise_spec y_spec = TensorSpec((output_dim, )) if input_tensor_spec is not None: x_spec = [x_spec, input_tensor_spec] self._mi_estimator = mi_estimator_cls( x_spec, y_spec, sampler='shift') self._mi_weight = mi_weight self._net = net self._predict_net = None self._net_moving_average_rate = net_moving_average_rate if net_moving_average_rate: self._predict_net = net.copy(name="Generator_average") self._predict_net_updater = common.TargetUpdater( self._net, self._predict_net, tau=net_moving_average_rate) def _trainable_attributes_to_ignore(self): return ["_predict_net", "_critic"] @property def noise_dim(self): return self._noise_dim
[docs] def get_lambda(self, training=False): if self._lambda_trainable: cur_lambda = torch.exp(self._log_lambda) if not training: cur_lambda = cur_lambda.detach() return cur_lambda else: return self._fixed_lambda
def _predict(self, inputs=None, noise=None, batch_size=None, training=True): if inputs is None: assert self._input_tensor_spec is None if noise is None: assert batch_size is not None noise = torch.randn(batch_size, self._noise_dim) gen_inputs = noise else: nest.assert_same_structure(inputs, self._input_tensor_spec) batch_size = nest.get_nest_batch_size(inputs) if noise is None: noise = torch.randn(batch_size, self._noise_dim) else: assert noise.shape[0] == batch_size assert noise.shape[1] == self._noise_dim gen_inputs = [noise, inputs] if self._predict_net and not training: outputs = self._predict_net(gen_inputs)[0] else: if self._functional_gradient: if self._force_fullrank: fullrank_diag_weight = self.get_lambda(training=training) extra_noise = torch.randn( noise.shape[0], self._output_dim - self._noise_dim) outputs = self._net(gen_inputs)[0] # [B, D] gen_inputs = torch.cat((gen_inputs, extra_noise), dim=-1) # [B, D] outputs = outputs + fullrank_diag_weight * gen_inputs else: outputs = self._net(gen_inputs)[0] else: outputs = self._net(gen_inputs)[0] return outputs, gen_inputs
[docs] def predict_step(self, inputs=None, noise=None, batch_size=None, training=False, state=None): """Generate outputs given inputs. Args: inputs (nested Tensor): if None, the outputs is generated only from noise. noise (Tensor): input to the generator. batch_size (int): batch_size. Must be provided if inputs is None. Its is ignored if inputs is not None training (bool): whether train the generator. state: not used Returns: AlgStep: - output (Tensor): predictions with shape ``[batch_size, output_dim]`` - state: not used. """ outputs, _ = self._predict( inputs=inputs, noise=noise, batch_size=batch_size, training=training) return AlgStep(output=outputs, state=(), info=())
[docs] def train_step(self, inputs, loss_func, batch_size=None, transform_func=None, entropy_regularization=None, state=None): """ Args: inputs (nested Tensor): if None, the outputs is generated only from noise. loss_func (Callable): loss_func([outputs, inputs]) (loss_func(outputs) if inputs is None) returns a Tensor or namedtuple of tensors with field `loss`, which is a Tensor of shape [batch_size] a loss term for optimizing the generator. batch_size (int): batch_size. Must be provided if inputs is None. Its is ignored if inputs is not None. transform_func (Callable): transform function on generator's outputs. Used in function value based par_vi (currently supported by [``svgd2``, ``svgd3``, ``gfsf``]) for evaluating the network(s) parameterized by the generator's outputs (given by self._predict) on the training batch (predefined with transform_func). It can be called in two ways - transform_func(params): params is a tensor of parameters for a network, of shape ``[D]`` or ``[B, D]`` - ``B``: batch size - ``D``: length of network parameters In this case, transform_func first samples additional data besides the predefined training batch and then evaluate the network(s) parameterized by ``params`` on the training batch plus additional sampled data. - transform_func((params, extra_samples)): params is the same as above case and extra_samples is the tensor of additional sampled data. In this case, transform_func evaluates the network(s) parameterized by ``params`` on predefined training batch plus ``extra_samples``. It returns three tensors: - outputs: outputs of network parameterized by params evaluated on predined training batch. - density_outputs: outputs of network parameterized by params evaluated on additional sampled data. - extra_samples: additional sampled data, same as input extra_samples if called as transform_func((params, extra_samples)) entropy_regularization (float): weight of entropy regularization. state: not used Returns: AlgStep: - output (Tensor): predictions with shape ``[batch_size, output_dim]`` - info (LossInfo): loss """ outputs, gen_inputs = self._predict(inputs, batch_size=batch_size) if self._functional_gradient: outputs = (outputs, gen_inputs) if entropy_regularization is None: entropy_regularization = self._entropy_regularization loss, loss_propagated = self._grad_func( inputs, outputs, loss_func, entropy_regularization, transform_func) mi_loss = () if self._mi_estimator is not None: mi_step = self._mi_estimator.train_step([gen_inputs, outputs]) mi_loss = mi_step.info.loss loss_propagated = loss_propagated + self._mi_weight * mi_loss if self._functional_gradient: loss, inverse_mvp_loss = loss else: inverse_mvp_loss = () return AlgStep( output=outputs, state=(), info=LossInfo( loss=loss_propagated, extra=GeneratorLossInfo( generator=loss, mi_estimator=mi_loss, inverse_mvp=inverse_mvp_loss)))
def _ml_grad(self, inputs, outputs, loss_func, entropy_regularization=None, transform_func=None): assert transform_func is None, ( "function value based vi is not supported for ml_grad") loss_inputs = outputs if inputs is None else [outputs, inputs] loss = loss_func(loss_inputs) grad = torch.autograd.grad(loss.sum(), outputs)[0] loss_propagated = torch.sum(grad.detach() * outputs, dim=-1) return loss, loss_propagated def _kernel_width(self, dist): """Update kernel_width averager and get latest kernel_width. """ if dist.ndim > 1: dist = torch.sum(dist, dim=-1) assert dist.ndim == 1, "dist must have dimension 1 or 2." width, _ = torch.median(dist, dim=0) width = width / np.log(len(dist)) if self._kernel_width_averager is not None: self._kernel_width_averager.update(width) width = self._kernel_width_averager.get() return width def _rbf_func(self, x, y): """Compute RBF kernel, used by svgd_grad. """ d = (x - y)**2 d = torch.sum(d, -1) h = self._kernel_width(d) w = torch.exp(-d / h) return w def _rbf_func2(self, x, y): r""" Compute the rbf kernel and its gradient w.r.t. first entry :math:`K(x, y), \nabla_x K(x, y)`, used by svgd_grad2 and svgd_grad3. Args: x (Tensor): set of N particles, shape (Nx, ...), where Nx is the number of particles. y (Tensor): set of N particles, shape (Ny, ...), where Ny is the number of particles. Returns: :math:`K(x, y)` (Tensor): the RBF kernel of shape (Nx x Ny) :math:`\nabla_x K(x, y)` (Tensor): the derivative of RBF kernel of shape (Nx x Ny x D) """ Nx = x.shape[0] Ny = y.shape[0] x = x.view(Nx, -1) y = y.view(Ny, -1) Dx = x.shape[1] Dy = y.shape[1] assert Dx == Dy diff = x.unsqueeze(1) - y.unsqueeze(0) # [Nx, Ny, D] dist_sq = torch.sum(diff**2, -1) # [Nx, Ny] h = self._kernel_width(dist_sq.view(-1)) kappa = torch.exp(-dist_sq / h) # [Nx, Nx] kappa_grad = kappa.unsqueeze(-1) * (-2 * diff / h) # [Nx, Ny, D] return kappa, kappa_grad def _score_func(self, x, alpha=1e-5): r""" Compute the stein estimator of the score function :math:`\nabla\log q = -(K + \alpha I)^{-1}\nabla K`, used by gfsf_grad. Args: x (Tensor): set of N particles, shape (N x D), where D is the dimenseion of each particle alpha (float): weight of regularization for inverse kernel this parameter turns out to be crucial for convergence. Returns: :math:`\nabla\log q` (Tensor): the score function of shape (N x D) """ N, D = x.shape diff = x.unsqueeze(1) - x.unsqueeze(0) # [N, N, D] dist_sq = torch.sum(diff**2, -1) # [N, N] h, _ = torch.median(dist_sq.view(-1), dim=0) h = h / np.log(N) kappa = torch.exp(-dist_sq / h) # [N, N] kappa_inv = torch.inverse(kappa + alpha * torch.eye(N)) # [N, N] kappa_grad = -2 * kappa.unsqueeze(-1) * diff / h # [N, N, D] kappa_grad = kappa_grad.sum(0) # [N, D] return -kappa_inv @ kappa_grad def _svgd_grad(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """ Compute particle gradients via SVGD, empirical expectation evaluated by a single resampled particle. """ outputs2, _ = self._predict(inputs, batch_size=outputs.shape[0]) assert transform_func is None, ( "function value based vi is not supported for svgd_grad") kernel_weight = self._rbf_func(outputs, outputs2) weight_sum = entropy_regularization * kernel_weight.sum() kernel_grad = torch.autograd.grad(weight_sum, outputs2)[0] loss_inputs = outputs2 if inputs is None else [outputs2, inputs] loss = loss_func(loss_inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss weighted_loss = kernel_weight.detach() * neglogp loss_grad = torch.autograd.grad(weighted_loss.sum(), outputs2)[0] grad = loss_grad - kernel_grad loss_propagated = torch.sum(grad.detach() * outputs, dim=-1) return loss, loss_propagated def _svgd_grad2(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """ Compute particle gradients via SVGD, empirical expectation evaluated by splitting half of the sampled batch. """ assert inputs is None, '"svgd2" does not support conditional generator' if transform_func is not None: outputs, extra_outputs, _ = transform_func(outputs) aug_outputs = torch.cat([outputs, extra_outputs], dim=-1) else: aug_outputs = outputs num_particles = outputs.shape[0] // 2 outputs_i, outputs_j = torch.split(outputs, num_particles, dim=0) aug_outputs_i, aug_outputs_j = torch.split( aug_outputs, num_particles, dim=0) loss_inputs = outputs_j loss = loss_func(loss_inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss loss_grad = torch.autograd.grad(neglogp.sum(), loss_inputs)[0] # [Nj, D] # [Nj, Ni], [Nj, Ni, D'] kernel_weight, kernel_grad = self._rbf_func2(aug_outputs_j.detach(), aug_outputs_i.detach()) kernel_logp = torch.matmul(kernel_weight.t(), loss_grad) / num_particles # [Ni, D] loss_prop_kernel_logp = torch.sum( kernel_logp.detach() * outputs_i, dim=-1) loss_prop_kernel_grad = torch.sum( -entropy_regularization * kernel_grad.mean(0).detach() * aug_outputs_i, dim=-1) loss_propagated = loss_prop_kernel_logp + loss_prop_kernel_grad return loss, loss_propagated def _svgd_grad3(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """ Compute particle gradients via SVGD, empirical expectation evaluated by resampled particles of the same batch size. """ assert inputs is None, '"svgd3" does not support conditional generator' num_particles = outputs.shape[0] outputs2, _ = self._predict(inputs, batch_size=num_particles) if transform_func is not None: outputs, extra_outputs, samples = transform_func(outputs) outputs2, extra_outputs2, _ = transform_func((outputs2, samples)) aug_outputs = torch.cat([outputs, extra_outputs], dim=-1) aug_outputs2 = torch.cat([outputs2, extra_outputs2], dim=-1) else: aug_outputs = outputs # [N, D'] aug_outputs2 = outputs2 # [N2, D'] loss_inputs = outputs2 loss = loss_func(loss_inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss loss_grad = torch.autograd.grad(neglogp.sum(), loss_inputs)[0] # [N2, D] # [N2, N], [N2, N, D'] kernel_weight, kernel_grad = self._rbf_func2(aug_outputs2.detach(), aug_outputs.detach()) kernel_logp = torch.matmul(kernel_weight.t(), loss_grad) / num_particles # [N, D] loss_prop_kernel_logp = torch.sum( kernel_logp.detach() * outputs, dim=-1) loss_prop_kernel_grad = torch.sum( -entropy_regularization * kernel_grad.mean(0).detach() * aug_outputs, dim=-1) loss_propagated = loss_prop_kernel_logp + loss_prop_kernel_grad return loss, loss_propagated def _gfsf_grad(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """Compute particle gradients via GFSF (Stein estimator). """ assert inputs is None, '"gfsf" does not support conditional generator' if transform_func is not None: outputs, extra_outputs, _ = transform_func(outputs) aug_outputs = torch.cat([outputs, extra_outputs], dim=-1) else: aug_outputs = outputs score_inputs = aug_outputs.detach() loss_inputs = outputs loss = loss_func(loss_inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss loss_grad = torch.autograd.grad(neglogp.sum(), loss_inputs)[0] # [N2, D] logq_grad = self._score_func(score_inputs) * entropy_regularization loss_prop_neglogp = torch.sum(loss_grad.detach() * outputs, dim=-1) loss_prop_logq = torch.sum(logq_grad.detach() * aug_outputs, dim=-1) loss_propagated = loss_prop_neglogp + loss_prop_logq return loss, loss_propagated def _jacobian_trace(self, fx, x): """Hutchinson's trace Jacobian estimator O(1) call to autograd, used by "\"minmax\" method""" assert fx.shape[-1] == x.shape[-1], ( "Jacobian is not square, no trace defined.") eps = torch.randn_like(fx) jvp = torch.autograd.grad( fx, x, grad_outputs=eps, retain_graph=True, create_graph=True)[0] tr_jvp = torch.einsum('bi,bi->b', jvp, eps) return tr_jvp def _critic_train_step(self, inputs, loss_func, entropy_regularization=1.): """ Compute the loss for critic training. """ loss = loss_func(inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss loss_grad = torch.autograd.grad(neglogp.sum(), inputs)[0] # [N, D] if self._critic_relu_mlp: critic_step = self._critic.predict_step( inputs, requires_jac_diag=True) outputs, jac_diag = critic_step.output tr_gradf = jac_diag.sum(-1) # [N] else: outputs = self._critic.predict_step(inputs).output tr_gradf = self._jacobian_trace(outputs, inputs) # [N] f_loss_grad = (loss_grad.detach() * outputs).sum(1) # [N] loss_stein = f_loss_grad - entropy_regularization * tr_gradf # [N] l2_penalty = (outputs * outputs).sum(1).mean() * self._critic_l2_weight critic_loss = loss_stein.mean() + l2_penalty return critic_loss def _minmax_grad(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """ Compute particle gradients via minmax svgd (Fisher Neural Sampler). """ assert inputs is None, '"minmax" does not support conditional generator' # optimize the critic using resampled particles assert transform_func is None, ( "function value based vi is not supported for minmax_grad") num_particles = outputs.shape[0] for i in range(self._critic_iter_num): if self._minmax_resample: critic_inputs, _ = self._predict( inputs, batch_size=num_particles) else: critic_inputs = outputs.detach().clone() critic_inputs.requires_grad = True critic_loss = self._critic_train_step(critic_inputs, loss_func, entropy_regularization) self._critic.update_with_gradient(LossInfo(loss=critic_loss)) # compute amortized svgd loss = loss_func(outputs.detach()) critic_outputs = self._critic.predict_step(outputs.detach()).output loss_propagated = torch.sum(-critic_outputs.detach() * outputs, dim=-1) return loss, loss_propagated def _get_vec_for_jac_inv_vec_prod(self, z, vec): r""" Construct a vecor as input to the helper network for Jacobian-inverse vector product estimation, used for GPVI_Plus. Args: z (Tensor): of size [N2, K], input noise to the self._net vec (Tensor): of size [N2, N, D], representing :math:`\nabla_{z'}k(z', z)`. Returns: reshaped vec (Tensor): of shape [N2*N, K] z_repeat (Tensor): of shape [N2*N, K] """ vec_1 = vec[:, :, :self._noise_dim] # [N2, N, K] vec_2 = vec[:, :, self._noise_dim:] # [N2, N, D-K] z_repeat = torch.repeat_interleave(z, vec.shape[1], dim=0) # [N2*N, K] vjp, _ = self._net.compute_vjp( z_repeat, vec_2.reshape(-1, vec_2.shape[-1]), output_partial_idx=torch.arange( start=self._noise_dim, end=self._output_dim)) # [N2*N, K] vec = vec_1.reshape( -1, self._noise_dim) - vjp / self.get_lambda() # [N2*N, K] return vec, z_repeat # [N2*N, K] def _inverse_mvp_train_step(self, z, vec): r"""Compute the loss for inverse_mvp training. self._inverse_mvp solves an inverse problem for the amortized functional gradient vi method GPVI. For GPVI, it takes :math:`z'^{(1:k)}` and :math:`v=\nabla_{z'}K(z', z)` as input and outputs :math:`v^T(\partial f / \partial z')^{-1}`. For GPVI_plus, it takes :math:`z'^{(1:k)}` and :math:`v` as inputs and outputs :math:`v^T(\partial f^{(1:k)} / \partial z'^{(1:k)})^{-1}`, where :math:`v` can be :math:`\nabla_{z'^{(1:k)}}K(z', z)` or :math:`(\nabla_{z'^{(k:d)}}K(z', z))^T(\partial f^{(k:d)} / \partial z')^{-1}`. The training loss is given by :math:`\|(\partical f / \partial z')^T y - v\|^2`, where :math`y` denotes the output of self._inverse_mvp, and the first term is computed by vector-jacobian product (vjp) between the generator :math:`f` and :math`y`. Args: z (Tensor): of size [N2, D], representing :math:`z'` vec (Tensor): of size [N2, N, D], representing :math:`\nabla_{z'}k(z', z)`. Returns: inverse_mvp_loss (float) """ if self._force_fullrank and self._block_inverse_mvp: vec, z_repeat = self._get_vec_for_jac_inv_vec_prod( z[:, :self._noise_dim], vec) y, z_inputs = self._inverse_mvp.predict_step((z_repeat, vec)).output else: # [N2*N, D] or [N2*N, K] y, z_inputs = self._inverse_mvp.predict_step((z, vec)).output vec = vec.reshape(-1, self._output_dim) # [N2*N, D] if self._block_inverse_mvp: partial_idx = torch.arange(self._noise_dim) else: partial_idx = None jac_y, _ = self._net.compute_vjp( z_inputs, y, output_partial_idx=partial_idx) # [N2*N, D] or [N2*N, K] if self._force_fullrank: if not self._block_inverse_mvp: jac_y = torch.cat([ jac_y, torch.zeros(jac_y.shape[0], self._output_dim - self._noise_dim) ], dim=-1) jac_y += self.get_lambda() * y # [N2*N, D] loss = torch.nn.functional.mse_loss(jac_y, vec.detach()) return loss def _rkhs_func_grad(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """ Compute the amortized functional gradient of generator, functional gradient represented in an RKHS. Empirical expectation evaluated by a resampling from the z space of the same batch size. Args: inputs: None outputs (tuple of Tensors): (outputs, gen_inputs) of size [N, D] and [N, K] respectively, where N being the sample size, D being the output dim of ReluMLP and K being the input dim of the generator. loss_func (callable) entropy_regularization (float): tradeoff parameter transform_func (callable): not used """ assert inputs is None, ( 'rkhs_func_grad does not support conditional generator') assert transform_func is None, ( "function value based vi is not supported for rkhs_func_grad") outputs, gen_inputs = outputs # [N, D], [N, D] num_particles = outputs.shape[0] outputs2, gen_inputs2 = self._predict( batch_size=num_particles) # [N2, D] # [N2, N], [N2, N, D] kernel_weight, kernel_grad = self._rbf_func2(gen_inputs2, gen_inputs) z_inputs = gen_inputs2[:, :self._noise_dim] # [N2, K] if self._direct_jac_inverse: # direct jac inverse, no inverse_mvp needed. J_inv_kernel_grad = self._direct_jac_inverse_vec_prod( z_inputs.detach(), kernel_grad.detach()) inverse_mvp_loss = () else: # train inverse_mvp for i in range(self._inverse_mvp_solve_iters): inverse_mvp_loss = self._inverse_mvp_train_step( gen_inputs2.detach(), kernel_grad.detach()) self._inverse_mvp.update_with_gradient( LossInfo(loss=inverse_mvp_loss)) # construct functional gradient via inverse_mvp if self._block_inverse_mvp: # [N2*N, K] vec, z_repeat = self._get_vec_for_jac_inv_vec_prod( z_inputs.detach(), kernel_grad.detach()) J_inv_kernel_grad_1, _ = self._inverse_mvp.predict_step( (z_repeat, vec)).output # [N2*N, K] J_inv_kernel_grad_1 = J_inv_kernel_grad_1.reshape( num_particles, num_particles, -1) # [N2, N, K] J_inv_kernel_grad = torch.cat( [J_inv_kernel_grad_1, kernel_grad[:, :, self._noise_dim:] \ / self.get_lambda()], dim=-1) # [N2, N, D] else: J_inv_kernel_grad, _ = self._inverse_mvp.predict_step( (gen_inputs2, kernel_grad)).output # [N2*N, D] J_inv_kernel_grad = J_inv_kernel_grad.reshape( num_particles, num_particles, -1) # [N2, N2, D] loss_inputs = outputs2 loss = loss_func(loss_inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss loss_grad = torch.autograd.grad(neglogp.sum(), loss_inputs)[0] # [N2, D] kernel_logp = torch.matmul(kernel_weight.t(), loss_grad) / num_particles # [N, D] grad = kernel_logp - entropy_regularization * J_inv_kernel_grad.mean(0) loss_propagated = torch.sum(grad.detach() * outputs, dim=1) return (loss, inverse_mvp_loss), loss_propagated def _direct_jac_inverse_vec_prod(self, z, vec): r""" Compute Jacobian-inverse vector product through direct Jacobian Inversion, used for GPVI and GPVI_Plus. Args: z (Tensor): of size [N2, K], input noise to the self._net vec (Tensor): of size [N2, N, D], representing :math:`\nabla_{z'}k(z', z)`. Returns: J_inv_vec (Tensor): of shape [N2, N, D] """ fullrank_diag_weight = self.get_lambda() N2, N = vec.shape[:2] if self._block_inverse_mvp: partial_idx = torch.arange(self._noise_dim) else: partial_idx = None jac = self._net.compute_jac(z, output_partial_idx=partial_idx) if self._force_fullrank: if self._block_inverse_mvp: eye_dim = self._noise_dim else: eye_dim = self._output_dim jac = torch.cat([ jac, torch.zeros(*jac.shape[:-1], self._output_dim - self._noise_dim) ], dim=-1) jac += fullrank_diag_weight * torch.eye(eye_dim) jac_inv = torch.inverse(jac) # [N2, D, D] or [N2, K, K] if self._force_fullrank and self._block_inverse_mvp: vec_1 = vec[:, :, :self._noise_dim] J_inv_vec_1 = torch.einsum('bij,bai->baj', jac_inv, vec_1) # [N2, N, K] vec_2 = vec[:, :, self._noise_dim:] # [N2, N, D-K] z_repeat = torch.repeat_interleave(z, N, dim=0) # [N2*N, K] vjp, _ = self._net.compute_vjp( z_repeat, vec_2.reshape(-1, vec_2.shape[-1]), output_partial_idx=torch.arange( start=self._noise_dim, end=self._output_dim)) vjp = vjp.reshape(N2, N, -1) # [N2, N, K] J_inv_vec_1 = J_inv_vec_1 - vjp / fullrank_diag_weight J_inv_vec = torch.cat([J_inv_vec_1, vec_2 / fullrank_diag_weight], dim=-1) # [N2, N, D] else: J_inv_vec = torch.einsum('bij,bai->baj', jac_inv, vec) # [N2, N, D] return J_inv_vec
[docs] def after_update(self, training_info): if self._predict_net: self._predict_net_updater()