Source code for alf.algorithms.vae

# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Variational auto encoder."""

from typing import Callable

import numpy as np

import torch
import torch.distributions as td
import torch.nn as nn

import alf
from alf.algorithms.algorithm import Algorithm
from alf.data_structures import AlgStep, LossInfo, namedtuple
from alf.layers import FC
from alf.networks import EncodingNetwork
from alf.tensor_specs import BoundedTensorSpec
from alf.utils import math_ops, dist_utils
from alf.utils.tensor_utils import tensor_extend_new_dim
from alf.utils.schedulers import ConstantScheduler, Scheduler

# Per-step training info reported by the VAE algorithms in this module.
#   kld:       KL divergence between the posterior and the prior
#   z_std:     standard deviation of the posterior (when available)
#   loss:      beta-weighted KLD, plus ``beta_loss`` when beta is being tuned
#   beta_loss: loss used to tune beta towards the target KLD
#   beta:      the current value of beta
VAEInfo = namedtuple(
    "VAEInfo",
    [
        "kld",
        "z_std",
        "loss",
        "beta_loss",
        "beta",
    ],
    default_value=(),
)

# Output of a VAE forward pass.
#   z:      a sample from the posterior
#   z_mode: the mode of the posterior
#   z_std:  standard deviation of the posterior (when available)
VAEOutput = namedtuple(
    "VAEOutput",
    [
        "z",
        "z_mode",
        "z_std",
    ],
    default_value=(),
)


@alf.configurable
class VariationalAutoEncoder(Algorithm):
    r"""VariationalAutoEncoder encodes data into diagonal multivariate gaussian,
    performs sampling with reparametrization trick, and returns KL divergence
    between posterior and prior.

    Mathematically:

    :math:`\log p(x) >= E_z \log P(x|z) - \beta KL(q(z|x) || prior(z))`

    ``train_step()`` method returns sampled z and KLD, it is up to the user of
    this class to use the returned z to decode and compute reconstructive loss
    to combine with kl loss returned here to optimize the whole network.

    See vae_test.py for example usages to train vanilla vae, conditional vae
    and vae with prior network on mnist dataset.
    """

    def __init__(self,
                 z_dim: int,
                 input_tensor_spec: alf.NestedTensorSpec = None,
                 preprocess_network: EncodingNetwork = None,
                 z_prior_network: EncodingNetwork = None,
                 beta: float = 1.0,
                 target_kld_per_dim: float = None,
                 beta_optimizer: torch.optim.Optimizer = None,
                 checkpoint=None,
                 name: str = "VariationalAutoEncoder"):
        """
        Args:
            z_dim: dimension of latent vector ``z``, namely, the dimension for
                generating ``z_mean`` and ``z_log_var``.
            input_tensor_spec: the input spec which can be a nest. If
                ``preprocess_network`` is None, then it must be provided.
            preprocess_network: an encoding network to preprocess input data
                before projecting it into (mean, log_var). If
                ``z_prior_network`` is None, this network must handle input
                with spec ``input_tensor_spec``. If ``z_prior_network`` is not
                None, this network must handle input with spec
                ``(z_prior_network.input_tensor_spec, input_tensor_spec,
                z_prior_network.output_spec)``. If this is None, an MLP of
                hidden sizes ``(z_dim*2, z_dim*2)`` will be used.
            z_prior_network: an encoding network that outputs concatenation of
                a prior mean and prior log var given the prior input. The
                network shouldn't activate its output.
            beta: the weight for KL-divergence
            target_kld_per_dim: if not None, then this will be used as the
                target KLD per dim to automatically tune beta.
            beta_optimizer: if not None, will be used to train beta.
            checkpoint (None|str): a string in the format of "prefix@path",
                where the "prefix" is the multi-step path to the contents in
                the checkpoint to be loaded. "path" is the full path to the
                checkpoint file saved by ALF. Refer to ``Algorithm`` for more
                details.
            name (str):
        """
        super().__init__(checkpoint=checkpoint, name=name)

        self._preprocess_network = preprocess_network
        if preprocess_network is None:
            # The default encoder is built from ``input_tensor_spec``, so it
            # cannot be omitted in this case.
            assert input_tensor_spec is not None, (
                "input_tensor_spec must be provided when preprocess_network "
                "is None")
            # according to appendix 2.4-2.5 in paper: https://arxiv.org/pdf/1803.10760.pdf
            if z_prior_network is None:
                preproc_input_spec = input_tensor_spec
            else:
                preproc_input_spec = (z_prior_network.input_tensor_spec,
                                      input_tensor_spec,
                                      z_prior_network.output_spec)
            self._preprocess_network = EncodingNetwork(
                input_tensor_spec=preproc_input_spec,
                preprocessing_combiner=alf.nest.utils.NestConcat(),
                fc_layer_params=(2 * z_dim, 2 * z_dim),
                activation=torch.tanh,
            )
        self._z_prior_network = z_prior_network

        # Two linear heads on top of the encoder features: one for the mean and
        # one for the log variance of the diagonal Gaussian posterior.
        size = self._preprocess_network.output_spec.shape[0]
        self._z_mean = FC(input_size=size, output_size=z_dim)
        self._z_log_var = FC(input_size=size, output_size=z_dim)

        # beta is stored in log space so that ``exp(log_beta)`` stays positive
        # while it is being tuned.
        self._log_beta = nn.Parameter(torch.tensor(beta).log())
        self._target_kld = None
        if target_kld_per_dim is not None:
            self._target_kld = target_kld_per_dim * z_dim
        self._z_dim = z_dim

        if beta_optimizer is not None:
            self.add_optimizer(beta_optimizer, [self._log_beta])

    def _sampling_forward(self, inputs):
        """Encode the data into latent space then do sampling.

        Args:
            inputs (nested Tensor): if a prior network is provided, this is a
                tuple of ``(prior_input, new_observation)``.

        Returns:
            tuple:
            - output (VAEOutput): ``z`` is the sampled latent, ``z_mode`` is
              the posterior mean and ``z_std`` is the posterior std; each is a
              tensor of shape (``B``, ``z_dim``).
            - kl_loss (Tensor): ``kl_loss`` is a tensor of shape (``B``,).
        """
        has_prior = self._z_prior_network is not None
        if has_prior:
            prior_input, new_obs = inputs
            prior_z_mean_and_log_var, _ = self._z_prior_network(prior_input)
            # The prior network output is ``[mean, log_var]`` concatenated
            # along the last dim.
            prior_z_mean = prior_z_mean_and_log_var[..., :self._z_dim]
            prior_z_log_var = prior_z_mean_and_log_var[..., self._z_dim:]
            inputs = (prior_input, new_obs, prior_z_mean_and_log_var)

        latents, _ = self._preprocess_network(inputs)
        z_mean = self._z_mean(latents)
        z_log_var = self._z_log_var(latents)

        if has_prior:
            # With a prior, the heads predict residuals relative to the prior
            # (mean offset and log-variance offset). The KL between the
            # resulting posterior and the prior then only depends on these
            # residuals and the prior variance.
            kl_div_loss = (
                math_ops.square(z_mean) / torch.exp(prior_z_log_var) +
                torch.exp(z_log_var) - z_log_var - 1.0)
            z_mean = z_mean + prior_z_mean
            z_log_var = z_log_var + prior_z_log_var
        else:
            # KL against a standard normal prior.
            kl_div_loss = (math_ops.square(z_mean) + torch.exp(z_log_var) -
                           1.0 - z_log_var)

        kl_div_loss = 0.5 * torch.sum(kl_div_loss, dim=-1)

        # reparameterization sampling: z = u + var ** 0.5 * eps
        # ``randn_like`` keeps the noise on the same device/dtype as ``z_mean``.
        eps = torch.randn_like(z_mean)
        z_std = torch.exp(z_log_var * 0.5)
        z = z_mean + z_std * eps
        output = VAEOutput(z=z, z_std=z_std, z_mode=z_mean)
        return output, kl_div_loss

    def train_step(self, inputs, state=()):
        """
        Args:
            inputs (nested Tensor): data to be encoded. If there is a prior
                network, then ``inputs`` is a tuple of
                ``(prior_input, new_obs)``.
            state (Tensor): empty tuple ()

        Returns:
            AlgStep:
            - output (VAEOutput):
            - state: empty tuple ()
            - info (VAEInfo):
        """
        output, kld_loss = self._sampling_forward(inputs)
        # beta is detached here so that the KLD term does not train it; it is
        # only updated through ``beta_loss`` below.
        beta = self._log_beta.exp().detach()
        info = VAEInfo(loss=beta * kld_loss, kld=kld_loss, z_std=output.z_std)
        if self._target_kld is not None:
            beta_loss = self._beta_train_step(kld_loss)
            info = info._replace(
                beta_loss=beta_loss,
                loss=info.loss + beta_loss,
                # broadcast the scalar beta along the batch dim for reporting
                beta=tensor_extend_new_dim(beta, 0, beta_loss.shape[0]))
        return AlgStep(output=output, state=state, info=info)

    def _beta_train_step(self, kld_loss):
        """Compute the loss for tuning ``log_beta`` towards the target KLD.

        The KLD gap is detached, so the gradient only flows into
        ``self._log_beta``: minimizing this loss increases beta when the KLD
        exceeds the target and decreases it otherwise.

        Args:
            kld_loss (Tensor): the KL divergence returned by
                ``_sampling_forward()``.

        Returns:
            Tensor: the beta loss, with the same shape as ``kld_loss``.
        """
        beta_loss = self._log_beta * (self._target_kld - kld_loss).detach()
        return beta_loss
@alf.configurable
class DiscreteVAE(VariationalAutoEncoder):
    r"""VAE with a discrete posterior distribution. The latent ``z`` might be
    a single categorical variable or a vector of categoricals.

    Because the re-parameterization trick can no longer be applied to the
    discrete distribution, we instead use the straight-through (ST) gradient
    estimator to train the encoder.

    ::

        Bengio et al., "Estimating or Propagating Gradients Through Stochastic
        Neurons for Conditional Computation", 2013.

    In short, we can re-parameterize the one-hot latent embedding :math:`z` as

    .. math::

        \hat{z} = z + z_{prob} - SG(z_{prob})

    Because :math:`z` is a sampled discrete variable, it has no gradient. So
    the parameter gradient is

    .. math::

        \frac{\partial L}{\partial \hat{z}}\frac{\partial \hat{z}}{\partial \theta}
        = \frac{\partial L}{\partial \hat{z}}\frac{\partial z_{prob}}{\partial \theta}

    Alternatively, we provide the option of ST Gumbel Softmax gradient
    estimator.

    ::

        Jang et al., "CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX", 2017.

    Which applies the above ST trick to the Gumbel-softmax distribution that
    uses the Gumbel trick to reparameterize the categorical sampling process.
    The paper claims that ST Gumbel-softmax gradient estimator has a lower
    variance than the plain ST estimator.
    """

    def __init__(self,
                 z_spec: BoundedTensorSpec,
                 input_tensor_spec: alf.NestedTensorSpec = None,
                 z_network_cls: Callable = EncodingNetwork,
                 prior_input_tensor_spec: alf.NestedTensorSpec = None,
                 prior_z_network_cls: Callable = None,
                 mode: str = "st",
                 gumbel_temp_scheduler: Scheduler = ConstantScheduler(1.),
                 beta: float = 1.,
                 target_kld_per_categorical: float = None,
                 beta_optimizer: torch.optim.Optimizer = None,
                 name: str = "DiscreteVAE"):
        """
        Args:
            z_spec: a tensor spec for the discrete posterior. It has to be
                rank-one, representing a vector of discrete variables. The
                value bound of each variable must be identical and the lower
                bound has to be 0.
            input_tensor_spec: the input spec.
            z_network_cls: an encoding network to encode input data into a
                vector of logits. If ``prior_z_network_cls`` is None, this
                network must handle input with spec ``input_tensor_spec``. If
                ``prior_z_network_cls`` is not None, this network must handle
                input with spec ``(prior_input_tensor_spec, input_tensor_spec,
                prior_z_network.output_spec)``.
            prior_input_tensor_spec: the input spec for ``prior_z_network``.
            prior_z_network_cls: an encoding network that outputs a vector of
                logits representing a prior ``z`` distribution given the prior
                input.
            mode: either 'st' or 'st-gumbel'.
            gumbel_temp_scheduler: the temperature scheduler for
                gumbel-softmax. Only used when ``mode=='st-gumbel'``.
            beta: the weight for KL-divergence
            target_kld_per_categorical: if not None, then this will be used as
                the target KLD *per Categorical* to automatically tune beta.
            beta_optimizer: if not None, will be used to train beta.
            name (str):
        """
        # Deliberately bypass ``VariationalAutoEncoder.__init__``: the Gaussian
        # encoder heads it builds are not used by the discrete model.
        Algorithm.__init__(self, name=name)

        assert (z_spec.is_discrete and z_spec.ndim == 1
                and z_spec.minimum == 0), (
                    "z_spec must be a rank-one discrete spec with minimum 0")
        self._n_categories = int(z_spec.maximum + 1)

        # Each of the ``z_spec.numel`` categoricals needs ``n_categories``
        # logits; the networks output them as one flat vector.
        n_logits = z_spec.numel * self._n_categories

        prior_z_network = None
        if prior_z_network_cls is not None:
            prior_z_network = prior_z_network_cls(
                input_tensor_spec=prior_input_tensor_spec,
                last_layer_size=n_logits,
                last_activation=alf.math.identity)
            # With a prior, the posterior network also sees the prior input
            # and the prior logits.
            input_tensor_spec = (prior_input_tensor_spec, input_tensor_spec,
                                 prior_z_network.output_spec)
        self._prior_z_network = prior_z_network

        self._z_network = z_network_cls(
            input_tensor_spec=input_tensor_spec,
            last_layer_size=n_logits,
            last_activation=alf.math.identity)

        self._z_spec = z_spec
        assert mode in ['st', 'st-gumbel'], f"Wrong mode {mode}"
        self._mode = mode
        self._gumbel_temp_scheduler = gumbel_temp_scheduler

        # beta is stored in log space so that ``exp(log_beta)`` stays positive
        # while it is being tuned.
        self._log_beta = nn.Parameter(torch.tensor(beta).log())
        self._target_kld = None
        if target_kld_per_categorical is not None:
            self._target_kld = target_kld_per_categorical * z_spec.numel

        if beta_optimizer is not None:
            self.add_optimizer(beta_optimizer, [self._log_beta])

    @property
    def output_spec(self):
        """Because the output is a floating one-hot vector, the shape is
        rank-two.
        """
        return BoundedTensorSpec(
            shape=self._z_spec.shape + (self._n_categories, ),
            minimum=0.,
            maximum=1.,
            dtype=torch.float32)

    def _kl_divergence(self, logits1, logits2=None):
        """Compute ``KL(softmax(logits1) || softmax(logits2))``.

        Args:
            logits1 (Tensor): logits of the first (posterior) distribution,
                of shape ``[B,L,K]``.
            logits2 (Tensor): logits of the second (prior) distribution, with
                the same shape as ``logits1``. If None, a uniform distribution
                is assumed.

        Returns:
            Tensor: the KL divergence summed over dims 1 and 2, of shape
            ``[B]``.
        """
        if logits2 is None:
            logits2 = torch.zeros_like(logits1)  # assume uniform
        log_p1 = torch.nn.functional.log_softmax(logits1, dim=-1)
        log_p2 = torch.nn.functional.log_softmax(logits2, dim=-1)
        # The expectation is over the target distribution
        kld = torch.nn.functional.kl_div(
            input=log_p2, target=log_p1, reduction='none', log_target=True)
        return kld.sum(dim=(1, 2))  # [B,L,K] -> [B]

    def _sampling_forward(self, inputs):
        """Encode the data into latent space then do sampling.

        Args:
            inputs: if a prior network is provided, this is a tuple of
                ``(prior_input, new_observation)``.

        Returns:
            tuple:
            - output (VAEOutput): ``z`` is a sample from the posterior and
              ``z_mode`` is its mode; ``z_std`` is left unset.
            - kl_loss (Tensor): the KL divergence between the posterior and
              the prior (uniform if there is no prior network).
        """
        logits_shape = (-1, ) + self._z_spec.shape + (self._n_categories, )
        has_prior = self._prior_z_network is not None

        if has_prior:
            prior_input, new_obs = inputs
            prior_z_logits, _ = self._prior_z_network(prior_input)
            # The posterior network consumes the flat prior logits; they are
            # reshaped only afterwards.
            inputs = (prior_input, new_obs, prior_z_logits)
            prior_z_logits = prior_z_logits.reshape(logits_shape)

        z_logits, _ = self._z_network(inputs)
        z_logits = z_logits.reshape(logits_shape)

        if has_prior:
            # The posterior network predicts a residual on top of the prior
            # logits. The add is out-of-place so that the network's output
            # tensor is not mutated.
            z_logits = z_logits + prior_z_logits
            kl_div_loss = self._kl_divergence(z_logits, prior_z_logits)
        else:
            kl_div_loss = self._kl_divergence(z_logits)

        if self._mode == 'st':
            z_dist = dist_utils.OneHotCategoricalStraightThrough(
                logits=z_logits)
        else:
            z_dist = dist_utils.OneHotCategoricalGumbelSoftmax(
                hard_sample=True,
                tau=self._gumbel_temp_scheduler(),
                logits=z_logits)

        output = VAEOutput(z=z_dist.rsample(), z_mode=z_dist.mode)
        return output, kl_div_loss