# Copyright (c) 2022 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vector Quantized Variational AutoEncoder Algorithm."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable
import alf
from alf.algorithms.algorithm import Algorithm
from alf.data_structures import AlgStep, LossInfo, namedtuple
from alf.networks import EncodingNetwork
# Container for the individual loss terms computed by ``Vqvae.train_step``,
# which stores an instance in ``LossInfo.extra``:
#   - quantization: codebook loss (pulls embeddings toward encoder outputs)
#   - commitment: commitment loss (pulls encoder outputs toward embeddings)
#   - reconstruction: decoder reconstruction loss
# ``default_value=()`` is passed to ALF's ``namedtuple`` helper -- presumably
# the default for every field; confirm against ``alf.data_structures``.
VqvaeLossInfo = namedtuple(
"VqvaeLossInfo", ["quantization", "commitment", "reconstruction"],
default_value=())
class Vqvae(Algorithm):
    r"""Vector Quantized Variational AutoEncoder (VQVAE) algorithm, described in:

    ::

        A van den Oord et al. "Neural Discrete Representation Learning", NeurIPS 2017.

    VQVAE differs from a standard VAE mainly in the following aspects:

    1. A discrete latent is used, instead of a continuous latent as in a
       standard VAE.
    2. A standard VAE uses a Gaussian prior and posterior. VQVAE can be viewed
       as using a deterministic form of posterior, which is a categorical
       distribution with onehot samples computed by nearest neighbor matching
       (Eq.1 of the paper). By using a uniform prior, the KL divergence is
       constant.
    """

    def __init__(self,
                 input_tensor_spec: alf.NestedTensorSpec,
                 num_embeddings: int,
                 embedding_dim: int,
                 encoder_ctor: Callable = EncodingNetwork,
                 decoder_ctor: Callable = EncodingNetwork,
                 optimizer: torch.optim.Optimizer = None,
                 commitment_loss_weight: float = 1.0,
                 checkpoint=None,
                 debug_summaries: bool = False,
                 name: str = "Vqvae"):
        """
        Args:
            input_tensor_spec (TensorSpec): the tensor spec of
                the input.
            num_embeddings (int): the number of embeddings (size of codebook)
            embedding_dim (int): the dimensionality of embedding vectors
            encoder_ctor (Callable): called as ``encoder_ctor(observation_spec)``
                to construct the encoding ``Network``. The network takes raw
                observation as input and outputs the latent representation.
            decoder_ctor (Callable): called as ``decoder_ctor(latent_spec)`` to
                construct the decoder.
            optimizer (Optimizer|None): if provided, it will be used to optimize
                the parameters of encoder_net, decoder_net and the embedding
                vectors.
            commitment_loss_weight (float): the weight for commitment loss.
            checkpoint (None|str): a string in the format of "prefix@path",
                where the "prefix" is the multi-step path to the contents in the
                checkpoint to be loaded. "path" is the full path to the
                checkpoint file saved by ALF. Refer to ``Algorithm`` for more
                details.
            debug_summaries (bool): forwarded to ``Algorithm``. ``train_step``
                checks ``self._debug_summaries`` to decide whether to record
                the codebook as an embedding summary.
            name (str): name of this algorithm, forwarded to ``Algorithm``.
        """
        super().__init__(
            checkpoint=checkpoint, debug_summaries=debug_summaries, name=name)

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        # Codebook of shape [n, d]. ``torch.empty`` follows the configured
        # default dtype/device (the legacy ``torch.FloatTensor`` constructor
        # does not); the values are overwritten by ``uniform_`` right below.
        self._embedding = torch.nn.Parameter(
            torch.empty(self._num_embeddings, self._embedding_dim))
        torch.nn.init.uniform_(
            self._embedding,
            a=-1 / self._num_embeddings,
            b=1 / self._num_embeddings)

        self._encoding_net = encoder_ctor(input_tensor_spec)
        # The decoder consumes latents with the same spec as the encoder output.
        self._decoding_net = decoder_ctor(self._encoding_net.output_spec)

        if optimizer is not None:
            self.add_optimizer(
                optimizer,
                [self._encoding_net, self._decoding_net, self._embedding])
        self._optimizer = optimizer
        self._commitment_loss_weight = commitment_loss_weight

    def _predict_step(self, inputs, state=()):
        """Encode ``inputs`` and quantize the result with the codebook.

        Args:
            inputs (tensor): with the shape the same as input_tensor_spec
            state: unused.

        Returns:
            tuple:
            - input_embedding: the encoder output, [B, d]
            - quantized: the nearest codebook vector for each sample, [B, d]
            - quantized_st: same values as ``quantized``, but gradients flow
              to ``input_embedding`` only (straight-through estimator).
        """
        # [B, d]
        input_embedding, _ = self._encoding_net(inputs)

        # Squared Euclidean distance between every encoder output and every
        # codebook vector, expanded as |x|^2 + |e|^2 - 2 * x.e
        # [B, 1] + [n] - [B, n]
        distances = (torch.sum(input_embedding**2, dim=1, keepdim=True) +
                     torch.sum(self._embedding**2, dim=1) -
                     2 * torch.matmul(input_embedding, self._embedding.t()))

        # Nearest-neighbor lookup in the codebook.
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self._embedding[encoding_indices]

        # Straight through: the forward value equals ``quantized`` while the
        # detached difference blocks gradients into the codebook.
        quantized_st = input_embedding + (quantized - input_embedding).detach()
        return input_embedding, quantized, quantized_st

    def predict_step(self, inputs, state=()):
        """Reconstruct ``inputs`` through the quantized latent.

        Args:
            inputs (tensor): with the shape the same as input_tensor_spec
            state: returned unchanged in the ``AlgStep``.

        Returns:
            AlgStep:
            - output: the reconstruction produced by the decoder
            - state: the ``state`` argument
            - info: the straight-through quantized latent
        """
        _, _, quantized_st = self._predict_step(inputs)
        rec = self._decoding_net(quantized_st)[0]
        return AlgStep(output=rec, state=state, info=quantized_st)

    def train_step(self, inputs, state=()):
        """Compute the per-sample VQVAE training loss.

        Args:
            inputs (tensor): with the shape the same as input_tensor_spec
            state: returned unchanged in the ``AlgStep``.

        Returns:
            AlgStep:
            - output: the reconstruction produced by the decoder
            - state: the ``state`` argument
            - info (LossInfo): ``loss`` is the sum of the encoding loss and
              the reconstruction loss, one value per sample; ``extra`` is a
              ``VqvaeLossInfo`` with the individual terms.
        """
        input_embedding, quantized, quantized_st = self._predict_step(inputs)

        # Commitment term: gradients reach the encoder only.
        e_latent_loss = F.mse_loss(
            quantized.detach(), input_embedding, reduction="none")
        # Quantization term: gradients reach the codebook only.
        q_latent_loss = F.mse_loss(
            quantized, input_embedding.detach(), reduction="none")

        # encoding loss, [B]
        enc_loss = (q_latent_loss +
                    self._commitment_loss_weight * e_latent_loss).mean(dim=1)

        # decoding loss, [B]
        rec = self._decoding_net(quantized_st)[0]
        # Average over *all* non-batch dimensions so that inputs of rank > 2
        # (e.g. images) still give one value per sample that can be added to
        # ``enc_loss``. For [B, D] inputs this equals ``.mean(dim=1)``.
        recon_loss = F.mse_loss(
            rec, inputs, reduction="none").flatten(start_dim=1).mean(dim=1)

        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):
                alf.summary.embedding("vq_embedding", self._embedding.detach())

        loss = enc_loss + recon_loss
        info = VqvaeLossInfo(
            quantization=q_latent_loss.mean(1),
            commitment=e_latent_loss.mean(1),
            reconstruction=recon_loss)
        loss_info = LossInfo(loss=loss, extra=info)
        return AlgStep(output=rec, state=state, info=loss_info)