# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various functions and classes related to loss computation."""
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Callable
from scipy.optimize import linear_sum_assignment
import alf
from alf.utils.math_ops import InvertibleTransform, binary_neg_entropy
from alf.utils import summary_utils
@alf.configurable
def element_wise_huber_loss(x, y):
    """Compute the Huber (smooth L1) loss per element, without any reduction.

    Args:
        x (Tensor): label
        y (Tensor): prediction
    Returns:
        loss (Tensor): the unreduced loss
    """
    label, prediction = x, y
    return F.smooth_l1_loss(input=prediction, target=label, reduction="none")
@alf.configurable
def element_wise_squared_loss(x, y):
    """Compute the squared error per element, without any reduction.

    Args:
        x (Tensor): label
        y (Tensor): prediction
    Returns:
        loss (Tensor): the unreduced loss
    """
    label, prediction = x, y
    return F.mse_loss(input=prediction, target=label, reduction="none")
@alf.configurable
def huber_function(x: torch.Tensor, delta: float = 1.0):
    """Evaluate the Huber function element-wise.

    The result is ``0.5 * x**2`` where ``|x| <= delta`` and
    ``delta * (|x| - 0.5 * delta)`` elsewhere.

    Args:
        x: difference between the observed and predicted values
        delta: the threshold at which to change between delta-scaled
            L1 and L2 loss, must be positive. Default value is 1.0
    Returns:
        Huber function (Tensor)
    """
    magnitude = x.abs()
    quadratic_part = 0.5 * x**2
    linear_part = delta * (magnitude - 0.5 * delta)
    return torch.where(magnitude <= delta, quadratic_part, linear_part)
@alf.configurable
def multi_quantile_huber_loss(quantiles: torch.Tensor,
                              target: torch.Tensor,
                              delta: float = 0.1) -> torch.Tensor:
    """Multi-quantile Huber loss

    The loss for simultaneous multiple quantile regression. The number of quantiles
    n is ``quantiles.shape[-1]``. ``quantiles[..., k]`` is the quantile value
    estimation for quantile :math:`(k + 0.5) / n`. For each prediction, there
    can be one or multiple target values.

    This loss is described in the following paper:

    `Dabney et. al. Distributional Reinforcement Learning with Quantile Regression
    <https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/viewFile/17184/16590>`_

    Args:
        quantiles: batch_shape + [num_quantiles,]
        target: batch_shape or batch_shape + [num_targets, ]
        delta: the smoothness parameter for huber loss (larger means smoother).
            Note that the quantile estimation with delta > 0 is biased. You should
            use a small value for ``delta`` if you want the quantile estimation
            to be less biased (so that the mean of the quantile will be close
            to mean of the samples). ``delta == 0`` selects the plain
            (non-smoothed) quantile regression loss.
    Returns:
        loss of batch_shape
    """
    num_quantiles = quantiles.shape[-1]
    # Quantile levels (k + 0.5) / n, k = 0, ..., n - 1. They are created on the
    # device of ``quantiles`` so that the loss also works when that device
    # differs from the default one.
    t = torch.arange(
        0.5 / num_quantiles,
        1.,
        1. / num_quantiles,
        device=quantiles.device)
    if target.ndim == quantiles.ndim - 1:
        # A single target per prediction: add the num_targets dimension.
        target = target.unsqueeze(-1)
    assert quantiles.shape[:-1] == target.shape[:-1], (
        "batch shape mismatch: quantiles %s vs. target %s" %
        (tuple(quantiles.shape), tuple(target.shape)))
    # [B, num_targets, num_quantiles]; ``t`` broadcasts along the last dim.
    d = target[..., :, None] - quantiles[..., None, :]
    below = (d < 0).float()
    if delta == 0.0:
        loss = (t - below) * d
    else:
        c = (t - below).abs()
        d_abs = d.abs()
        loss = c * torch.where(d_abs < delta,
                               (0.5 / delta) * d**2, d_abs - 0.5 * delta)
    # Average over both the targets and the quantiles.
    return loss.mean(dim=(-2, -1))
class ScalarPredictionLoss(object):
    """Interface of a loss for predicting a scalar target.

    ``__call__`` and ``calc_expectation`` raise ``NotImplementedError`` here
    and have to be overridden. ``initialize_bias`` defaults to zeroing the bias.
    """

    def __call__(self, pred: torch.Tensor, target: torch.Tensor):
        """Calculate the loss given ``pred`` and ``target``.

        Args:
            pred: raw prediction
            target: target value
        Returns:
            loss with the same shape as target
        """
        raise NotImplementedError

    def calc_expectation(self, pred: torch.Tensor):
        """Calculate the expected prediction in the untransformed domain from ``pred``."""
        raise NotImplementedError

    def initialize_bias(self, bias: torch.Tensor):
        """Initialize the bias of the last FC layer for the prediction properly.

        This function can be passed to FC as bias_initializer.

        For some losses (e.g. OrderedDiscreteRegressionLoss), initializing the
        bias to zero can give very bad initial predictions. So we provide an
        interface for doing loss specific initializations. Note that the weight
        of the last FC should be initialized to zero in general.

        Args:
            bias: the bias parameter to be initialized.
        """
        # In-place zero fill performed without recording gradients.
        nn.init.zeros_(bias)
@alf.repr_wrapper
class SquareLoss(ScalarPredictionLoss):
    """Square loss for predicting scalar target.

    Args:
        transform: the transformation applied to target. If it is provided,
            the regression target will be transformed.
    """

    def __init__(self, transform: Optional[InvertibleTransform] = None):
        super().__init__()
        self._transform = transform

    def __call__(self, pred: torch.Tensor, target: torch.Tensor):
        """Calculate the loss.

        Args:
            pred: shape is [B]
            target: the shape is [B]
        Returns:
            loss with the same shape as target
        """
        assert pred.shape == target.shape
        if self._transform is None:
            regression_target = target
        else:
            # The regression is done in the transformed domain.
            regression_target = self._transform.transform(target)
        error = pred - regression_target
        return error**2

    def calc_expectation(self, pred: torch.Tensor):
        """Calculate the expected prediction in the untransformed domain from ``pred``.

        Args:
            pred: raw model prediction
        """
        if self._transform is None:
            return pred
        return self._transform.inverse_transform(pred)
def _get_indexer(shape: Tuple[int]):
"""Return a tuple of Tensors which can be used to index a Tensor.
The purpose of this function can be better illustrated by an example.
Suppose ``shape`` is ``[n0, n1, n2]``. Then shape of the three returned
Tensors will be: [n0, 1, 1], [n1, 1], [n2]. And each of them has elements
ranging from 0 to n0-1, n1-1 and n2-1 respectively. The returned tuple ``B``
can be combined with another int64 Tensor ``I`` of shape ``[n0, n1, n2]``
to access the element of a Tensor ``X`` with shape``[n0, n1, n2, n]`` as
``Y=X[B + (I,)]`` so that ``Y[i,j,k] = X[i, j, k, I[i,j,k]]``
Args:
shape: The shape of the tensor to be accessed exclusing the last dimension.
Returns:
the tuple of index for accessing the tensor.
"""
ndim = len(shape)
ones = [1] * ndim
B = tuple(
torch.arange(d).reshape(d, *ones[i + 1:]) for i, d in enumerate(shape))
return B
class _DiscreteRegressionLossBase(ScalarPredictionLoss):
    """The base class for DiscreteRegressionLoss and OrderedDiscreteRegressionLoss.

    Args:
        transform: optional transformation applied to the target.
        inverse_after_mean: only honored when ``transform`` is given; without a
            transform the flag is forced to True.
    """

    def __init__(self,
                 transform: Optional[InvertibleTransform] = None,
                 inverse_after_mean=False):
        super().__init__()
        self._transform = transform
        self._inverse_after_mean = (inverse_after_mean
                                    if transform is not None else True)
        # Cache of the last support computed by ``_calc_support``.
        self._support = None

    def _calc_support(self, n: int):
        """Return the ``n`` support values ``-(n-1)//2, ..., n//2``.

        The values are mapped through the inverse transform when a transform is
        set and ``inverse_after_mean`` is False. The result is cached and
        reused as long as ``n`` stays the same.
        """
        cached = self._support
        if cached is not None and cached.shape[0] == n:
            return cached
        support = torch.arange(
            -((n - 1) // 2), n // 2 + 1, dtype=torch.float32)
        if not (self._transform is None or self._inverse_after_mean):
            support = self._transform.inverse_transform(support)
        self._support = support
        return support

    def _calc_bin(self, logits, target):
        """Discretize ``target`` such that:

            bin1 <= transform(target) - lower_bound < bin2

        and w2 is the weight assigned to bin2. Hence 1 - w2 is the weight
        assigned to bin1.

        If inverse_after_mean is False, w2 is chosen so that the expectation
        will be equal to target.

        If inverse_after_mean is True, w2 is simply
        ``transform(target) - lower_bound - bin1``

        Returns:
            tuple of (bin1, bin2, w2)
        """
        assert logits.shape[:-1] == target.shape
        num_bins = logits.shape[-1]
        lower_bound = -((num_bins - 1) // 2)
        upper_bound = num_bins // 2
        untransformed_target = target
        if self._transform is not None:
            target = self._transform.transform(target)
        target = target.clamp(min=lower_bound, max=upper_bound)
        low = target.floor()
        bin1 = low.to(torch.int64) - lower_bound
        # bin2 coincides with bin1 when the target sits at the upper bound.
        bin2 = (bin1 + 1).clamp(max=num_bins - 1)
        if self._inverse_after_mean:
            w2 = target - low
        else:
            inverse = self._transform.inverse_transform
            low_value = inverse(low)
            high_value = inverse(low + 1)
            w2 = (untransformed_target - low_value) / (high_value - low_value)
            # Due to limited numerical precision, w2 may be slightly out of the
            # range of [0, 1]. So we clamp it to the right range.
            w2 = w2.clamp(0, 1)
        return bin1, bin2, w2
@alf.repr_wrapper
class DiscreteRegressionLoss(_DiscreteRegressionLossBase):
    r"""A loss for predicting the distribution of a scalar.

    The target is assumed to be in the range ``[-(n-1)//2, n//2]``, where ``n=logits.shape[-1]``.
    The logits are used to calculate the probabilities of being one of the ``n``
    values. If a target value y is not an integer, it is treated as having
    probability mass of :math:`y- \lfloor y \rfloor` at :math:`\lfloor y \rfloor + 1`
    and probability mass of :math:`1 + \lfloor y \rfloor - y` at :math:`\lfloor y \rfloor`.
    Then cross entropy loss is applied.

    More specifically, the ``logits`` passed to ``__call__`` represents the following:
    P = softmax(logits) and P[i] means the probability that the (transformed)
    ``target`` is equal to ``i - (n-1)//2``

    Note: ``DiscreteRegressionLoss(SqrtLinearTransform(0.001), inverse_after_mean=True)``
    is the loss used by MuZero paper.

    Args:
        transform: the transformation applied to target. If it is provided,
            the regression target will be transformed.
        inverse_after_mean: when calculating the expected prediction, whether to
            do the inverse transformation after calculating the expectation
            in the transformed space. Note that using ``inverse_after_mean=True``
            will make the expectation biased in general. This is because
            :math:`f^{-1}(E(x)) \le E(f^{-1}(x))` (Jensen inequality) if
            :math:`f^{-1}` is convex. In our case, :math:`f^{-1}` is convex for
            :math:`x \ge 0`.
    """

    def __call__(self, logits: torch.Tensor, target: torch.Tensor):
        """Calculate the loss.

        Args:
            logits: shape is batch_shape + [n], e.g. [B, n]
            target: shape is batch_shape, e.g. [B]
        Returns:
            loss with the same shape as target
        """
        bin1, bin2, w2 = self._calc_bin(logits, target)
        w1 = 1 - w2
        nlp = -F.log_softmax(logits, dim=-1)
        B = _get_indexer(logits.shape[:-1])
        # Cross entropy against the two-bin soft target.
        loss = w1 * nlp[B + (bin1, )] + w2 * nlp[B + (bin2, )]
        # Adding the negative entropy of the soft target turns the cross
        # entropy into a KL divergence; relu removes tiny negative values.
        neg_entropy = w1.xlogy(w1) + w2.xlogy(w2)
        return (loss + neg_entropy).relu()

    def calc_expectation(self, logits):
        """Calculate the expected prediction in the untransformed domain.

        Args:
            logits: raw model prediction of shape batch_shape + [n]
        Returns:
            the expectation, of shape batch_shape
        """
        support = self._calc_support(logits.shape[-1])
        # ``matmul`` (unlike ``mv``) accepts any number of batch dimensions,
        # matching what ``__call__`` supports. For 2-D logits it is the same
        # matrix-vector product.
        ret = torch.matmul(logits.softmax(dim=-1), support)
        if self._inverse_after_mean and self._transform is not None:
            ret = self._transform.inverse_transform(ret)
        return ret

    def initialize_bias(self, bias: torch.Tensor):
        r"""Initialize the bias of the last FC layer for the prediction properly.

        This function sets the bias so that the initial distribution of the prediction
        in the original domain of target is approximately Cauchy: :math:`p(x) \propto \frac{1}{1+x^2}`

        Args:
            bias: the bias parameter to be initialized.
        """
        assert bias.ndim == 1
        n = bias.shape[0]
        upper_bound = n // 2
        lower_bound = -((n - 1) // 2)
        x = torch.arange(lower_bound, upper_bound + 1, dtype=torch.float32)
        # Edges of the unit-width bin centered on each support value.
        x1 = x - 0.5
        x2 = x + 0.5
        if self._transform is not None:
            x = self._transform.inverse_transform(x)
            x1 = self._transform.inverse_transform(x1)
            x2 = self._transform.inverse_transform(x2)
        # Bin width in the original domain times the Cauchy density.
        probs = (x2 - x1) / (x**2 + 1)
        probs = probs / probs.sum()
        with torch.no_grad():
            bias.copy_(probs.log())
@alf.repr_wrapper
class OrderedDiscreteRegressionLoss(_DiscreteRegressionLossBase):
    r"""A loss for predicting the distribution of a scalar.

    The target is assumed to be in the range ``[-(n-1)//2, n//2]``, where ``n=logits.shape[-1]``.
    The logits are used to calculate the probabilities of being greater than or
    equal to each of these ``n`` values. If a target value y is not an integer,
    it is treated as having probability mass of :math:`y- \lfloor y \rfloor` at
    :math:`\lfloor y \rfloor + 1` and probability mass of :math:`1 + \lfloor y \rfloor - y`
    at :math:`\lfloor y \rfloor`. Then binary cross entropy loss is applied.

    More specifically, the ``logits`` passed to ``__call__`` represents the following:
    P = sigmoid(logits) and P[i] means the probability that the (transformed)
    ``target`` is greater than or equal to ``i - (n-1)//2``

    Args:
        transform: the transformation applied to target. If it is provided,
            the regression target will be transformed.
        inverse_after_mean: when calculating the expected prediction, whether to
            do the inverse transformation after calculating the expectation
            in the transformed space. Note that using ``inverse_after_mean=True``
            will make the expectation biased in general. This is because
            :math:`f^{-1}(E(x)) \le E(f^{-1}(x))` (Jensen inequality) if
            :math:`f^{-1}` is convex. In our case, :math:`f^{-1}` is convex for
            :math:`x \ge 0`.
    """

    def __call__(self, logits: torch.Tensor, target: torch.Tensor):
        """Calculate the loss.

        Args:
            logits: shape is batch_shape + [n], e.g. [B, n]
            target: shape is batch_shape, e.g. [B]
        Returns:
            loss with the same shape as target
        """
        n = logits.shape[-1]
        bin1, bin2, w2 = self._calc_bin(logits, target)
        # Build the target for P(value >= i): 1 for i < bin1, 0 afterwards ...
        w = F.one_hot(bin1, num_classes=n).to(logits.dtype)
        w = 1 - w.cumsum(dim=-1)
        B = _get_indexer(target.shape)
        # ... then w2 at bin2 and 1 at bin1. bin1 is written last so that it
        # wins when bin2 == bin1 (target at the upper bound).
        w[B + (bin2, )] = w2
        w[B + (bin1, )] = 1
        cross_entropy = F.binary_cross_entropy_with_logits(
            logits, w, reduction='none')
        kld = cross_entropy + binary_neg_entropy(w)
        return kld.relu().sum(dim=-1)

    def calc_expectation(self, logits: torch.Tensor):
        """Calculate the expected prediction in the untransformed domain.

        Args:
            logits: raw model prediction of shape batch_shape + [n]
        Returns:
            the expectation, of shape batch_shape
        """
        n = logits.shape[-1]
        lower_bound = -((n - 1) // 2)
        # The running minimum makes the logits non-increasing along the bins.
        logits = logits.cummin(dim=-1).values
        probs = logits.sigmoid()
        if self._inverse_after_mean:
            pred = probs.sum(dim=-1) + (lower_bound - 1)
            if self._transform is not None:
                pred = self._transform.inverse_transform(pred)
        else:
            # Differences of adjacent ">=" probabilities give per-bin masses.
            probs = torch.cat(
                [probs[..., :-1] - probs[..., 1:], probs[..., -1:]], dim=-1)
            support = self._calc_support(n)
            # ``matmul`` (unlike ``mv``) accepts any number of batch
            # dimensions, like the other branch. For 2-D input it is the same
            # matrix-vector product.
            pred = torch.matmul(probs, support)
        return pred

    def initialize_bias(self, bias: torch.Tensor):
        r"""Initialize the bias of the last FC layer for the prediction properly.

        This function sets the bias so that the initial distribution of the prediction
        in the original domain of target is approximately Cauchy: :math:`p(x) \propto \frac{1}{1+x^2}`

        Args:
            bias: the bias parameter to be initialized.
        """
        assert bias.ndim == 1
        n = bias.shape[0]
        upper_bound = n // 2
        lower_bound = -((n - 1) // 2)
        # Use float64 to prevent precision loss due to cumsum
        x = torch.arange(lower_bound, upper_bound + 1, dtype=torch.float64)
        x1 = x - 0.5
        x2 = x + 0.5
        if self._transform is not None:
            x = self._transform.inverse_transform(x)
            x1 = self._transform.inverse_transform(x1)
            x2 = self._transform.inverse_transform(x2)
        probs = (x2 - x1) / (x**2 + 1)
        probs = probs / probs.sum()
        probs = probs.cumsum(dim=0)
        # Shift by one bin so that entry i is the mass strictly below bin i.
        # 1e-20 replaces the exact zero of the first bin to keep the log finite.
        probs = torch.cat(
            [torch.tensor([1e-20], dtype=torch.float64), probs[:-1]], dim=0)
        with torch.no_grad():
            # logit of (1 - probs), i.e. of the mass at or above each bin.
            bias.copy_(((1 - probs) / probs).log().to(torch.float32))
@alf.repr_wrapper
class QuantileRegressionLoss(ScalarPredictionLoss):
    r"""Multi-quantile Huber loss

    The loss for simultaneous multiple quantile regression. The number of quantiles
    n is ``quantiles.shape[-1]``. ``quantiles[..., k]`` is the quantile value
    estimation for quantile :math:`(k + 0.5) / n`. For each prediction, there
    can be one or multiple target values.

    This loss is described in the following paper:

    `Dabney et. al. Distributional Reinforcement Learning with Quantile Regression
    <https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/viewFile/17184/16590>`_

    Args:
        transform: the transformation applied to target. If it is provided,
            the regression target will be transformed.
        inverse_after_mean: when calculating the expected prediction, whether to
            do the inverse transformation after calculating the expectation
            in the transformed space. Note that using ``inverse_after_mean=True``
            will make the expectation biased in general. This is because
            :math:`f^{-1}(E(x)) \le E(f^{-1}(x))` (Jensen inequality) if
            :math:`f^{-1}` is convex. In our case, :math:`f^{-1}` is convex for
            :math:`x \ge 0`.
        delta: the smoothness parameter for huber loss (larger means smoother).
            Note that the quantile estimation with delta > 0 is biased. You should
            use a small value for ``delta`` if you want the quantile estimation
            to be less biased (so that the mean of the quantile will be close
            to mean of the samples).
    """

    def __init__(self,
                 transform: Optional[InvertibleTransform] = None,
                 inverse_after_mean: bool = False,
                 delta: float = 0.0):
        super().__init__()
        self._transform = transform
        self._delta = delta
        self._inverse_after_mean = inverse_after_mean

    def __call__(self, quantiles: torch.Tensor, target: torch.Tensor):
        """Calculate the loss.

        Args:
            quantiles: batch_shape + [num_quantiles,]
            target: batch_shape or batch_shape + [num_targets, ]
        Returns:
            loss whose shape is batch_shape
        """
        batch_shape = quantiles.shape[:-1]
        # Accept both documented layouts of ``target``. The previous check only
        # allowed ``batch_shape`` and rejected the multi-target layout.
        single_target = target.shape == batch_shape
        multi_target = (target.ndim == quantiles.ndim
                        and target.shape[:-1] == batch_shape)
        assert single_target or multi_target, (
            "target shape %s is incompatible with quantiles shape %s" %
            (tuple(target.shape), tuple(quantiles.shape)))
        if self._transform is not None:
            target = self._transform.transform(target)
        return multi_quantile_huber_loss(quantiles, target, delta=self._delta)

    def calc_expectation(self, quantiles: torch.Tensor):
        """Calculate the expected prediction in the untransformed domain.

        Args:
            quantiles: predicted quantile values in the transformed space.
        Returns:
            the mean over the quantile dimension, mapped back through the
            inverse transform (before or after averaging, depending on
            ``inverse_after_mean``) when a transform is set.
        """
        if self._transform is None:
            return quantiles.mean(dim=-1)
        if self._inverse_after_mean:
            return self._transform.inverse_transform(quantiles.mean(dim=-1))
        return self._transform.inverse_transform(quantiles).mean(dim=-1)
@alf.repr_wrapper
class AsymmetricSimSiamLoss(nn.Module):
    """The siamese loss proposed in:

    Chen Xinlei et. al. "Exploring Simple Siamese Representation Learning" CVPR 2021

    The loss is ``1 - cosine(pred(proj(x)), detach(proj(y)))``, where x is the
    predicted representation, y is the target representation, and proj and pred
    are computed using ``proj_net`` and ``pred_net`` respectively.

    Args:
        proj_net: if not provided, a default MLP with two hidden layers and
            output size as ``output_size`` will be created.
        pred_net: if not provided, a default MLP with one hidden layer will
            be created.
        input_size: input size of ``proj_net``. Must be given if ``proj_net``
            is not provided.
        proj_hidden_size: the size of the hidden layers of proj_net. Only useful
            if ``proj_net`` is not provided.
        pred_hidden_size: the size of the hidden layer of pred_net. Only useful
            if ``pred_net`` is not provided.
        output_size: output size of the default ``proj_net``. Only useful
            if ``proj_net`` is not provided.
        proj_last_use_bn: whether to use batch norm for the output layer of
            proj_net. Only useful if ``proj_net`` is not provided
        eps: the ``eps`` for calling ``F.normalize()`` when calculating the
            normalized vector in order to calculate cosine.
        fixed_weight_norm: whether to fix the norm of the weight parameter of
            the FC layers.
        lr: learning rate. If None, the default learning rate will be used.
        debug_summaries: whether to write debug summaries
        name: name of this loss
    """

    def __init__(self,
                 proj_net: Optional[alf.nn.Network] = None,
                 pred_net: Optional[alf.nn.Network] = None,
                 input_size: Optional[int] = None,
                 proj_hidden_size: int = 256,
                 pred_hidden_size: int = 128,
                 output_size: int = 256,
                 proj_last_use_bn: bool = False,
                 eps: float = 1e-5,
                 fixed_weight_norm: bool = False,
                 lr: Optional[float] = None,
                 debug_summaries: bool = True,
                 name: str = "SimSiamLoss"):
        super().__init__()

        def _proj_hidden_fc(in_size):
            # Hidden layer of the default projection net: FC + BN + ReLU.
            return alf.layers.FC(
                in_size,
                proj_hidden_size,
                activation=torch.relu_,
                use_bn=True,
                weight_opt_args=dict(fixed_norm=fixed_weight_norm, lr=lr))

        if proj_net is None:
            assert input_size is not None, "input_size must be provided if proj_net is not given"
            # Layers are created in the same order as they are applied.
            proj_layers = [
                alf.layers.Reshape(-1),
                _proj_hidden_fc(input_size),
                _proj_hidden_fc(proj_hidden_size),
                alf.layers.FC(
                    proj_hidden_size,
                    output_size,
                    use_bn=proj_last_use_bn,
                    weight_opt_args=dict(
                        lr=lr,
                        fixed_norm=fixed_weight_norm and proj_last_use_bn)),
            ]
            proj_net = alf.nn.Sequential(
                *proj_layers,
                input_tensor_spec=alf.TensorSpec((input_size, )))
        # The predictor maps the projection back onto a vector of equal size.
        output_size = proj_net.output_spec.numel
        if pred_net is None:
            pred_hidden = alf.layers.FC(
                output_size,
                pred_hidden_size,
                activation=torch.relu_,
                use_bn=True,
                weight_opt_args=dict(lr=lr, fixed_norm=fixed_weight_norm))
            pred_out = alf.layers.FC(
                pred_hidden_size,
                output_size,
                weight_opt_args=dict(lr=lr, fixed_norm=False))
            pred_net = alf.nn.Sequential(pred_hidden, pred_out)
        self._proj_net = proj_net
        self._pred_net = pred_net
        self._eps = eps
        self._debug_summaries = debug_summaries
        self._name = name

    @alf.summary.enter_summary_scope
    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Calculate the loss.

        Args:
            pred: predicted representation of shape [B, T, ...]
            target: target representation of shape [B, T, ...]
        Returns:
            loss of shape [B, T]
        """
        assert pred.shape == target.shape
        B, T = pred.shape[:2]
        # Merge the two leading dims so that the nets see a flat batch.
        flat_shape = (B * T, ) + tuple(pred.shape[2:])
        target = target.reshape(flat_shape)
        pred = pred.reshape(flat_shape)
        record = (self._debug_summaries
                  and alf.summary.should_record_summaries())
        if record:
            pred = summary_utils.summarize_tensor_gradients(
                "pred_grad", pred, clone=True)
        # Target branch: no gradient is propagated through it.
        with torch.no_grad():
            target_proj = self._proj_net(target.to(pred.dtype))[0]
        target_unit = F.normalize(target_proj.detach(), dim=1, eps=self._eps)
        # Prediction branch: projection followed by the predictor.
        pred_proj = self._proj_net(pred)[0]
        pred_out = self._pred_net(pred_proj)[0]
        pred_unit = F.normalize(pred_out, dim=1, eps=self._eps)
        cos = (target_unit * pred_unit).sum(dim=1)
        if record:
            summary_utils.add_mean_hist_summary("cos", cos)
            summary_utils.add_mean_hist_summary(
                "predicted_projected_pred_norm", pred_out.norm(dim=1))
        return (1 - cos).reshape(B, T)
@alf.repr_wrapper
class MeanSquaredLoss(object):
    """Mean squared loss.

    For a prediction and target pair (x,y), the loss is ``((x - y) ** 2).mean()``.

    Args:
        batch_dims: the first so many dims of prediction and target are treated
            as batch dimension. The mean is performed on the rest of the dimensions.
        debug_summaries: whether to write debug summaries
        name: name of this loss
    """

    def __init__(self,
                 batch_dims: int = 1,
                 debug_summaries: bool = True,
                 name: str = "MSELoss"):
        super().__init__()
        self._debug_summaries = debug_summaries
        self._name = name
        self._batch_dims = batch_dims

    @alf.summary.enter_summary_scope
    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Calculate the loss.

        Args:
            pred: prediction of shape [B, ...]
            target: target of shape [B, ...]
        Returns:
            loss whose shape is the first ``batch_dims`` dims of ``pred``
            ([B] for the default ``batch_dims=1``)
        """
        assert pred.shape == target.shape
        if self._debug_summaries and alf.summary.should_record_summaries():
            pred = summary_utils.summarize_tensor_gradients(
                "pred_grad", pred, clone=True)
        ndim = pred.ndim
        assert ndim >= self._batch_dims
        loss = (pred - target)**2
        if ndim > self._batch_dims:
            # Average over every non-batch dimension.
            loss = loss.mean(dim=list(range(self._batch_dims, ndim)))
        return loss

    def __call__(self, *args, **kwargs):
        # This class is a plain object rather than an ``nn.Module``, so calling
        # an instance has to be routed to ``forward`` explicitly.
        return self.forward(*args, **kwargs)
class BipartiteMatchingLoss(object):
    r"""Bipartite matching loss.

    This order-invariant loss can be used to evaluate the matching between a predicted
    set and a target set. The idea is that for every forward, an optimal one-to-one
    mapping assignment from the predicted set to the target set is first found using
    some efficient bipartite graph matching algorithm, and the optimal loss is
    minimized.

    Mathematically, suppose there are :math:`N` objects in either set,
    :math:`L(x,y)` is the matching loss between any :math:`(x,y)` object pair,
    and :math:`\mathcal{G}_N` is the permutation space. The forward loss to be
    minimized is:

    .. math::

        \min_{g\in\mathcal{G}_N}\sum_n^N L(x_n(\theta),y_{g(n)})

    where :math:`\theta` is the model parameters.

    In practice, to find the optimal assignment, we simply use ``scipy.optimize.linear_sum_assignment``.

    References::

        `End-to-End Object Detection with Transformers <https://arxiv.org/pdf/2005.12872.pdf>`_, Carion et al.

        `<https://github.com/facebookresearch/detr/blob/main/models/matcher.py>`_
    """

    def __init__(self,
                 reduction: str = 'mean',
                 name: str = "BipartiteMatchingLoss"):
        """
        Args:
            reduction: 'sum', 'mean' or 'none'. This is how to reduce the matching
                loss. For the former two, the loss shape is ``[B]``, while for
                the 'none', the loss shape is ``[B,N]``.
            name: name of this loss
        """
        super().__init__()
        assert reduction in ['mean', 'sum', 'none'], (
            "reduction must be 'mean', 'sum' or 'none', got %r" %
            (reduction, ))
        self._reduction = reduction
        self._name = name

    def forward(self,
                matching_cost_mat: torch.Tensor,
                cost_mat: torch.Tensor = None):
        """Compute the optimal matching loss.

        Args:
            matching_cost_mat: the cost matrix used to determine the optimal
                matching. Its shape should be ``[B,N,N]``.
            cost_mat: the cost matrix used to compute the optimal loss once the
                optimal matching is found. According to the DETR paper, this
                cost matrix might be different from the one used for matching.
                If None, then it will be the same matrix for matching.
        Returns:
            tuple:
            - the optimal loss. If reduction is 'mean' or 'sum', its shape is
              ``[B]``, otherwise its shape is ``[B,N]``.
            - the optimal matching given the cost matrix. Its shape is ``[B,N]``,
              where the value of n-th entry is its mapped index in the target set.
        """
        if cost_mat is None:
            cost_mat = matching_cost_mat

        B, N = matching_cost_mat.shape[:2]
        if matching_cost_mat.numel() == 0:
            # Nothing to match. ``max()`` below is undefined for an empty
            # tensor, so produce the (empty) matching directly.
            col_ind = torch.zeros((B, N, 1),
                                  dtype=torch.int64,
                                  device=cost_mat.device)
        else:
            # The matching itself is not differentiable.
            with torch.no_grad():
                max_cost = matching_cost_mat.max() + 1.
                # [B*N, B*N]
                # Subtract all diag entries by a max cost so that no off-diag
                # matchings will be optimal: every in-block entry becomes
                # negative while off-block entries are 0.
                big_cost_mat = torch.block_diag(
                    *(matching_cost_mat - max_cost).unbind(0))
                # col_ind: [B*N]
                _, col_ind = linear_sum_assignment(big_cost_mat.cpu().numpy())
            # Convert global column ids back to per-sample target indices.
            col_ind = torch.as_tensor(
                col_ind % N, dtype=torch.int64,
                device=cost_mat.device).reshape(B, N, 1)

        # [B,N]; gathered outside of no_grad so that gradients can flow into
        # ``cost_mat``.
        optimal_loss = cost_mat.gather(dim=-1, index=col_ind).squeeze(-1)

        if self._reduction == 'mean':
            optimal_loss = optimal_loss.mean(-1)
        elif self._reduction == 'sum':
            optimal_loss = optimal_loss.sum(-1)
        return optimal_loss, col_ind.squeeze(-1)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)