# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mutual Information Estimator."""
import math
import torch
import torch.distributions as td
import torch.nn.functional as F
import alf
from alf.algorithms.algorithm import Algorithm, AlgStep, LossInfo
from alf.layers import BatchSquash
from alf.networks import EncodingNetwork
from alf.nest import get_nest_batch_size
from alf.nest.utils import get_outer_rank, NestConcat
from alf.utils.averager import EMAverager, ScalarAdaptiveAverager
from alf.utils.data_buffer import DataBuffer
from alf.utils import common, math_ops
from alf.utils.dist_utils import DiagMultivariateNormal
class MIEstimator(Algorithm):
    r"""Mutual Information Estimator.

    Implements several mutual information estimators from
    Belghazi et al `Mutual Information Neural Estimation
    <http://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf>`_ and
    Hjelm et al `Learning Deep Representations by Mutual Information Estimation
    and Maximization <https://arxiv.org/pdf/1808.06670.pdf>`_.

    Currently, 4 types of estimators are implemented, which are based on the
    following objectives:

    * *DV*: :math:`\sup_T E_P(T) - \log E_Q(\exp(T))`
    * *KLD*: :math:`\sup_T E_P(T) - E_Q(\exp(T)) + 1`
    * *JSD*: :math:`\sup_T -E_P(softplus(-T)) - E_Q(softplus(T)) + \log(4)`
    * *ML*: :math:`\sup_q E_P(\log(q(y|x)) - \log(P(y)))`

    where P is the joint distribution of X and Y, and Q is the product of the
    marginal distributions of P. Both DV and KLD are lower bounds for
    :math:`KLD(P||Q)=MI(X, Y)`. However, *JSD* is not a lower bound for mutual
    information; it is a lower bound for :math:`JSD(P||Q)`, which is closely
    correlated with MI as pointed out in Hjelm et al.

    For *ML*, :math:`P(y)` is the marginal distribution of y, and it needs to
    be provided. The current implementation uses a normal distribution with
    diagonal variance for :math:`q(y|x)`, so it only supports continuous
    ``y``. If :math:`P(y|x)` can be reasonably approximated as a diagonal
    normal distribution and :math:`P(y)` is known, then 'ML' may give a better
    estimation of the mutual information.

    Assuming the function class of T is rich enough to represent any function,
    for *KLD* and *JSD*, T will converge to :math:`\log(\frac{P}{Q})` and hence
    :math:`E_P(T)` can also be used as an estimator of :math:`KLD(P||Q)=MI(X,Y)`.
    For *DV*, :math:`T` will converge to :math:`\log(\frac{P}{Q}) + c`, where
    :math:`c=\log E_Q(\exp(T))`.

    Among *DV*, *KLD* and *JSD*, *DV* and *KLD* seem to give a better estimation
    of PMI than *JSD*. But *JSD* might be numerically more stable than *DV* and
    *KLD* because of the use of softplus instead of exp. And *DV* is more stable
    than *KLD* because of the logarithm.

    Several strategies are implemented in order to estimate :math:`E_Q(\cdot)`:

    * 'buffer': store :math:`y` in a buffer and randomly retrieve samples from
      the buffer.
    * 'double_buffer': store both :math:`x` and :math:`y` in buffers and
      randomly retrieve samples from the two buffers.
    * 'shuffle': randomly shuffle batch :math:`y`.
    * 'shift': shift batch :math:`y` by one sample, i.e.
      ``torch.cat([y[-1:, ...], y[0:-1, ...]], dim=0)``.
    * direct sampling: you can also provide the marginal distribution of
      :math:`y` to ``train_step()``. In this case, the sampler is ignored and
      samples of :math:`y` for estimating :math:`E_Q(.)` are sampled from
      ``y_distribution``.

    If you need the gradient of :math:`y`, you should use sampler 'shift' or
    'shuffle'.

    Among these, 'buffer' and 'shift' seem to perform better and 'shuffle'
    performs the worst. 'buffer' incurs additional storage cost. 'shift' has
    the assumption that y samples from one batch are independent. If the
    additional memory is not a concern, we recommend the 'buffer' sampler so
    that there is no need to worry about the assumption of independence.

    ``MIEstimator`` can also be used to estimate conditional mutual information
    :math:`MI(X,Y|Z)` using *KLD*, *JSD* or *ML*. In this case, you should let
    ``x`` represent :math:`X` and :math:`Z`, and ``y`` represent :math:`Y`.
    And when calling ``train_step()``, you need to provide ``y_distribution``
    which is the distribution :math:`P(Y|z)`. Note that *DV* cannot be used for
    estimating conditional mutual information. See ``mi_estimator_test.py`` for
    an example.
    """

    def __init__(self,
                 x_spec,
                 y_spec,
                 model=None,
                 fc_layers=(256, ),
                 sampler='buffer',
                 buffer_size=65536,
                 optimizer: torch.optim.Optimizer = None,
                 estimator_type='DV',
                 averager: EMAverager = None,
                 name="MIEstimator"):
        """
        Args:
            x_spec (nested TensorSpec): spec of ``x``
            y_spec (nested TensorSpec): spec of ``y``
            model (Network): for 'DV', 'KLD' and 'JSD' it is called as
                ``model([x, y])`` and the first element of its return value
                should be a Tensor with ``shape=[batch_size, 1]``. For 'ML' it
                is called as ``model(x)`` and the first element of its return
                value is a hidden feature whose last dimension is
                ``model.output_spec.shape[-1]``. If None, a default MLP with
                ``fc_layers`` will be created.
            fc_layers (tuple[int]): size of hidden layers. Only used if model is
                None.
            sampler (str): type of sampler used to get samples from marginal
                distribution, should be one of ``['buffer', 'double_buffer',
                'shuffle', 'shift']``.
            buffer_size (int): capacity of the buffer(s) used by sampler
                'buffer' (stores y) and 'double_buffer' (stores x and y).
            optimizer (torch.optim.Optimizer): optimizer
            estimator_type (str): one of 'ML', 'DV', 'KLD' or 'JSD'
            averager (EMAverager): averager used to maintain a moving average
                of :math:`exp(T)`. Only used for 'DV' estimator. If None,
                a ScalarAdaptiveAverager will be created.
            name (str): name of this estimator
        Raises:
            TypeError: if ``sampler`` is not a supported sampler name.
        """
        assert estimator_type in ['ML', 'DV', 'KLD', 'JSD'
                                  ], "Wrong estimator_type %s" % estimator_type
        super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
        self._x_spec = x_spec
        self._y_spec = y_spec
        if model is None:
            if estimator_type == 'ML':
                # Feature extractor for x; the output heads for q(y|x) are
                # the FC layers created below.
                model = EncodingNetwork(
                    name="MIEstimator",
                    input_tensor_spec=x_spec,
                    fc_layer_params=fc_layers,
                    preprocessing_combiner=NestConcat(dim=-1))
            else:
                # Scalar critic T(x, y) with a linear (identity) output.
                model = EncodingNetwork(
                    name="MIEstimator",
                    input_tensor_spec=[x_spec, y_spec],
                    preprocessing_combiner=NestConcat(dim=-1),
                    fc_layer_params=fc_layers,
                    last_layer_size=1,
                    last_activation=math_ops.identity)
        self._model = model
        self._type = estimator_type

        if sampler == 'buffer':
            self._y_buffer = DataBuffer(y_spec, capacity=buffer_size)
            self._sampler = self._buffer_sampler
        elif sampler == 'double_buffer':
            self._x_buffer = DataBuffer(x_spec, capacity=buffer_size)
            self._y_buffer = DataBuffer(y_spec, capacity=buffer_size)
            self._sampler = self._double_buffer_sampler
        elif sampler == 'shuffle':
            self._sampler = self._shuffle_sampler
        elif sampler == 'shift':
            self._sampler = self._shift_sampler
        else:
            raise TypeError("Wrong type for sampler %s" % sampler)

        if estimator_type == 'DV':
            if averager is None:
                averager = ScalarAdaptiveAverager()
            self._mean_averager = averager

        if estimator_type == 'ML':
            assert isinstance(
                y_spec,
                alf.TensorSpec), ("Currently, 'ML' does "
                                  "not support nested y_spec: %s" % y_spec)
            assert y_spec.is_continuous, (
                "Currently, 'ML' does "
                "not support discreted y_spec: %s" % y_spec)
            hidden_size = self._model.output_spec.shape[-1]
            # Zero kernel and zero bias: the loc offset starts at 0.
            self._delta_loc_layer = alf.layers.FC(
                hidden_size,
                y_spec.shape[-1],
                kernel_initializer=torch.nn.init.zeros_,
                bias_init_value=0.0)
            # Zero kernel and bias log(e - 1): softplus(log(e - 1)) == 1, so
            # the scale multiplier starts at 1 and q(y|x) initially has the
            # same mean/stddev as y_distribution.
            self._delta_scale_layer = alf.layers.FC(
                hidden_size,
                y_spec.shape[-1],
                kernel_initializer=torch.nn.init.zeros_,
                bias_init_value=math.log(math.e - 1))

    def _buffer_sampler(self, x, y):
        """Pair ``x`` with (detached) ``y`` samples drawn from the y buffer.

        When the buffer already holds enough samples, it is sampled before the
        current batch is added, so the current ``y`` is not reused.
        """
        batch_size = get_nest_batch_size(y)
        if self._y_buffer.current_size >= batch_size:
            y1 = self._y_buffer.get_batch(batch_size)
            self._y_buffer.add_batch(y)
        else:
            self._y_buffer.add_batch(y)
            y1 = self._y_buffer.get_batch(batch_size)
        return x, common.detach(y1)

    def _double_buffer_sampler(self, x, y):
        """Draw independent (detached) ``x`` and ``y`` samples from two buffers."""
        batch_size = get_nest_batch_size(y)
        self._x_buffer.add_batch(x)
        x1 = self._x_buffer.get_batch(batch_size)
        self._y_buffer.add_batch(y)
        y1 = self._y_buffer.get_batch(batch_size)
        # Detach like ``_buffer_sampler``: buffered samples are not meant to
        # carry gradients (use 'shift' or 'shuffle' for that).
        return common.detach(x1), common.detach(y1)

    def _shuffle_sampler(self, x, y):
        """Pair ``x`` with a random permutation of ``y`` within the batch."""
        return x, math_ops.shuffle(y)

    def _shift_sampler(self, x, y):
        """Pair ``x`` with ``y`` rotated by one position along dim 0."""

        def _shift(y):
            return torch.cat([y[-1:, ...], y[0:-1, ...]], dim=0)

        return x, alf.nest.map_structure(_shift, y)

    def train_step(self, inputs, y_distribution=None, state=None):
        """Perform training on one batch of inputs.

        Args:
            inputs (tuple(nested Tensor, nested Tensor)): tuple of ``x`` and ``y``
            y_distribution (nested td.Distribution): distribution
                for the marginal distribution of ``y``. If None, will use the
                sampling method ``sampler`` provided at constructor to generate
                the samples for the marginal distribution of :math:`Y`.
                Required for the 'ML' estimator.
            state: not used
        Returns:
            AlgStep:
            - output (Tensor): shape is the outer (batch) shape of ``x``; its
              mean is the estimated MI for estimator 'ML', 'DV' and 'KLD', and
              Jensen-Shannon divergence for estimator 'JSD'
            - state: not used
            - info (LossInfo): ``info.loss`` is the loss
        """
        x, y = inputs

        if self._type == 'ML':
            return self._ml_step(x, y, y_distribution)

        # Collapse all outer dims into a single batch dim so that the samplers
        # and the model operate on ``[batch_size, ...]`` tensors.
        num_outer_dims = get_outer_rank(x, self._x_spec)
        batch_squash = BatchSquash(num_outer_dims)
        x = batch_squash.flatten(x)
        y = batch_squash.flatten(y)

        # (x1, y1) approximate samples from Q, the product of the marginals.
        if y_distribution is None:
            x1, y1 = self._sampler(x, y)
        else:
            x1 = x
            y1 = batch_squash.flatten(y_distribution.sample())

        # T evaluated on joint samples (P) and on marginal samples (Q).
        log_ratio = self._model([x, y])[0].squeeze(-1)
        t1 = self._model([x1, y1])[0].squeeze(-1)

        if self._type == 'DV':
            # Clamp before exp() to avoid overflow.
            ratio = t1.clamp(max=20.).exp()
            mean = ratio.mean().detach()
            if self._mean_averager is not None:
                # A moving average of E_Q(exp(T)) is used in the loss
                # denominator instead of the single-batch mean.
                self._mean_averager.update(mean)
                unbiased_mean = self._mean_averager.get().detach()
            else:
                unbiased_mean = mean
            # estimated MI = reduce_mean(mi)
            # ratio/mean-1 does not contribute to the final estimated MI, since
            # mean(ratio/mean-1) = 0. We add it so that we can have an estimation
            # of the variance of the MI estimator
            mi = log_ratio - (mean.log() + ratio / mean - 1)
            loss = ratio / unbiased_mean - log_ratio
        elif self._type == 'KLD':
            # Clamp before exp() to avoid overflow.
            ratio = t1.clamp(max=20.).exp()
            mi = log_ratio - ratio + 1
            loss = -mi
        elif self._type == 'JSD':
            mi = -F.softplus(-log_ratio) - F.softplus(t1) + math.log(4)
            loss = -mi

        mi = batch_squash.unflatten(mi)
        loss = batch_squash.unflatten(loss)

        return AlgStep(output=mi, state=(), info=LossInfo(loss, extra=()))

    def _ml_pmi(self, x, y, y_distribution):
        """Compute ``log q(y|x) - log P(y)`` for the 'ML' estimator.

        ``q(y|x)`` is a diagonal normal whose loc/scale are
        ``y_distribution``'s mean/stddev shifted and scaled by outputs of the
        FC heads applied to ``model(x)`` features.
        """
        num_outer_dims = get_outer_rank(x, self._x_spec)
        hidden = self._model(x)[0]
        batch_squash = BatchSquash(num_outer_dims)
        hidden = batch_squash.flatten(hidden)
        delta_loc = self._delta_loc_layer(hidden)
        delta_scale = F.softplus(self._delta_scale_layer(hidden))
        delta_loc = batch_squash.unflatten(delta_loc)
        delta_scale = batch_squash.unflatten(delta_scale)
        y_given_x_dist = DiagMultivariateNormal(
            loc=y_distribution.mean + delta_loc,
            scale=y_distribution.stddev * delta_scale)
        # log P(y) is detached so no gradient flows into y_distribution.
        pmi = y_given_x_dist.log_prob(y) - y_distribution.log_prob(y).detach()
        return pmi

    def _ml_step(self, x, y, y_distribution):
        """Training step for the 'ML' estimator: maximize the estimated PMI."""
        assert y_distribution is not None, (
            "y_distribution must be provided for the 'ML' estimator")
        pmi = self._ml_pmi(x, y, y_distribution)
        return AlgStep(output=pmi, state=(), info=LossInfo(loss=-pmi))

    def calc_pmi(self, x, y, y_distribution=None):
        r"""Return estimated pointwise mutual information.

        The pointwise mutual information is defined as:

        .. math::

            \log \frac{P(x|y)}{P(x)} = \log \frac{P(y|x)}{P(y)}

        Args:
            x (Tensor): x
            y (Tensor): y
            y_distribution (DiagMultivariateNormal): needs to be provided for
                'ML' estimator.
        Returns:
            Tensor: pointwise mutual information between ``x`` and ``y``.
        """
        if self._type == 'ML':
            assert isinstance(y_distribution, DiagMultivariateNormal), (
                "y_distribution should be a DiagMultivariateNormal")
            return self._ml_pmi(x, y, y_distribution)

        log_ratio = self._model([x, y])[0].squeeze(-1)
        if self._type == 'DV':
            # For DV, T converges to log(P/Q) + c with c = log E_Q(exp(T));
            # subtract the moving-average estimate of c. Out-of-place to avoid
            # mutating a view of the network output needed by autograd.
            log_ratio = log_ratio - self._mean_averager.get().log()
        return log_ratio