# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SegmentTree."""
import math
import torch
import torch.nn as nn
import alf
from alf.nest.utils import convert_device
class SegmentTree(nn.Module):
    """
    Data structure to allow efficient calculation of the summary statistics
    over a segment of elements.

    See https://en.wikipedia.org/wiki/Segment_tree for detail.

    In this implementation, ``values[1]`` is the root.
    ``values[capacity: 2*capacity]`` are the leaves. The two children of an
    internal node ``values[i]`` are ``values[2*i]`` and ``values[2*i+1]``, and
    ``values[i]`` is set to ``op(values[2*i], values[2*i+1])``.

    Each leaf represents a value set through ``__setitem__``. All the nodes of
    the tree are initialized to zero.
    """

    def __init__(self,
                 capacity,
                 op,
                 dtype=torch.float32,
                 device="cpu",
                 name="SegmentTree"):
        """
        Args:
            capacity (int): number of leaves, i.e. the number of elements the
                tree can hold.
            op (Callable): binary operation used to combine the two children
                of a node, called as ``op(left, right)`` on Tensors.
            dtype (torch.dtype): dtype of the node values.
            device (str): device passed to ``alf.device`` whenever the node
                values are created or accessed.
            name (str): name of this object.
        """
        super().__init__()
        self._name = name
        self._device = device
        with alf.device(self._device):
            self.register_buffer("_values",
                                 torch.zeros((2 * capacity, ), dtype=dtype))
        self._op = op
        self._capacity = capacity
        # ``_leftmost_leaf`` ends up as the smallest power of two that is
        # >= capacity; it is the storage position of external index 0.
        # ``_depth`` counts only the doublings that still left it below
        # capacity, so for capacity > 1: _leftmost_leaf == 2 ** (_depth + 1).
        self._leftmost_leaf = 1
        self._depth = 0
        while self._leftmost_leaf < capacity:
            self._leftmost_leaf *= 2
            if self._leftmost_leaf < capacity:
                self._depth += 1

    def __setitem__(self, indices, values):
        """Set the value of leaves and update the internal nodes.

        An empty update is a no-op.

        Args:
            indices (Tensor): 1-D int64 Tensor. Its values should be in range
                [0, capacity). NOTE(review): duplicates are not rejected; which
                of the duplicated values ends up stored is not controlled here.
            values (Tensor): 1-D Tensor with the same shape as ``indices``
        """
        op = self._op

        def _step(nodes):
            """Recompute the (unique) parents of ``nodes`` from their children.

            Returns:
                Tensor: the positions of the parents that were updated.
            """
            parents = torch.unique(nodes >> 1)
            left = self._values[parents * 2]
            right = self._values[parents * 2 + 1]
            self._values[parents] = op(left, right)
            return parents

        with alf.device(self._device):
            indices = convert_device(indices)
            values = convert_device(values)
            assert indices.ndim == 1
            assert values.ndim == 1
            assert indices.shape == values.shape, (
                "indices and values should be 1-D tensor with the same length. "
                "Got %s and %s." % (indices.shape, values.shape))
            if indices.numel() == 0:
                # Nothing to write; indexing indices[-1] etc. would fail.
                return
            # Reject negative as well as too-large indices: a negative index
            # would otherwise be mapped onto an unrelated node.
            self._check_indices(indices)

            indices, order = torch.sort(indices)
            values = values[order]
            nodes = self._index_to_leaf(indices)
            self._values[nodes] = values

            # Leaves at positions >= _leftmost_leaf sit one level deeper than
            # the wrapped-around ones. They correspond to the smallest
            # external indices, so after sorting they form a prefix. Lift
            # them one level first so that all nodes are on the same level.
            num_large = int((nodes >= self._leftmost_leaf).sum())
            if num_large > 0:
                lifted = _step(nodes[:num_large])
                nodes = torch.cat([lifted, nodes[num_large:]])
            for _ in range(self._depth):
                nodes = _step(nodes)

    def __getitem__(self, idx):
        """Get the values of leaves.

        Args:
            idx (Tensor): 1-D int64 Tensor. Its values should be in range
                [0, capacity).
        Returns:
            Tensor: with same shape as idx.
        """
        with alf.device(self._device):
            idx = convert_device(idx)
            self._check_indices(idx)
            result = self._values[self._index_to_leaf(idx)]
        return convert_device(result)

    def _check_indices(self, idx):
        """Assert that every element of ``idx`` is in [0, capacity).

        An empty ``idx`` passes trivially.
        """
        if idx.numel() > 0:
            assert 0 <= idx.min(), "indices must be non-negative"
            assert idx.max() < self._capacity, (
                "indices must be smaller than capacity")

    def _index_to_leaf(self, idx):
        """Convert external indices to positions in the node storage.

        Index 0 maps to ``_leftmost_leaf``. Indices whose shifted position
        would fall beyond the end of the storage (``2 * capacity``) wrap
        around by ``capacity``.
        """
        idx = idx + self._leftmost_leaf
        idx = torch.where(idx >= 2 * self._capacity, idx - self._capacity, idx)
        return idx

    def _leaf_to_index(self, leaf):
        """Convert leaf positions in the node storage to external indices.

        This is the inverse of ``_index_to_leaf`` for valid leaf positions.
        """
        idx = leaf - self._leftmost_leaf
        idx = torch.where(idx < 0, idx + self._capacity, idx)
        return idx

    def summary(self):
        """The summary of the tree.

        If ``op`` is ``torch.add``, it's the sum of all values.
        If ``op`` is ``torch.min``, it's the min of all values.
        If ``op`` is ``torch.max``, it's the max of all values.

        Returns:
            a scalar
        """
        return convert_device(self._values[1])
class SumSegmentTree(SegmentTree):
    """SegmentTree with sum operation.

    In addition to the sum, it tracks the number of non-zero leaves (``nnz``)
    and supports prefix-sum search via ``find_sum_bound``.
    """

    def __init__(self,
                 capacity,
                 dtype=torch.float32,
                 device="cpu",
                 name="SumSegmentTree"):
        """
        Args:
            capacity (int): number of leaves.
            dtype (torch.dtype): dtype of the node values.
            device (str): device forwarded to ``SegmentTree``.
            name (str): name of this object.
        """
        super().__init__(
            capacity, torch.add, dtype=dtype, device=device, name=name)
        self._nnz = 0

    def __setitem__(self, indices, values):
        """Set the value of leaves, update internal nodes and ``nnz``.

        Args:
            indices (Tensor): 1-D int64 Tensor. Its values should be in range
                [0, capacity).
            values (Tensor): 1-D Tensor of non-negative values with the same
                shape as ``indices``.
        """
        with alf.device(self._device):
            # Move the inputs next to ``_values`` before indexing with them.
            indices = convert_device(indices)
            values = convert_device(values)
            if values.numel() > 0:
                assert values.min() >= 0, "values must be non-negative"
            # Count non-zeros on the distinct touched leaves before and after
            # the write. This stays correct when ``indices`` has duplicates,
            # and ``_nnz`` is only changed once the base class has accepted
            # the update.
            leaves = torch.unique(self._index_to_leaf(indices))
            nnz_before = int((self._values[leaves] != 0).sum())
        super().__setitem__(indices, values)
        nnz_after = int((self._values[leaves] != 0).sum())
        self._nnz += nnz_after - nnz_before

    @property
    def nnz(self):
        """The number of non-zeros."""
        return self._nnz

    def find_sum_bound(self, thresholds):
        """
        The result is an int64 Tensor with the same shape as ``thresholds``.
        result[i] is the minimum idx such that

            thresholds[i] < values[0] + ... + values[idx]

        As long as ``summary()`` is positive, values[result[i]] is intended
        to never be 0.

        Args:
            thresholds (Tensor): 1-D Tensor. All the elements in
                ``thresholds`` should be no greater than ``self.summary()``.
        Returns:
            Tensor: 1-D int64 Tensor with the same shape as ``thresholds``.
                Note that if thresholds[i] == root, result[i] will be
                the index of the non-zero value with the largest index.
        Raises:
            ValueError: If one or more of ``thresholds`` is greater than
                ``summary()``.
        """

        def _step(nodes, residuals):
            """Choose one of the children of each node based on the residual.

            If the residual is greater than or equal to the left child, choose
            the right child and subtract the left child from the residual.
            Otherwise choose the left child and keep the residual unchanged.
            """
            left_child = nodes * 2
            left = self._values[left_child]
            right = self._values[left_child + 1]
            # The condition (residuals >= left) and (right == 0) is only
            # possible if the original threshold == summary(). In that case
            # stay on the left so that the result still corresponds to a
            # non-zero value.
            go_right = (residuals >= left) & (right > 0)
            nodes = torch.where(go_right, left_child + 1, left_child)
            residuals = torch.where(go_right, residuals - left, residuals)
            return nodes, residuals

        with alf.device(self._device):
            thresholds = convert_device(thresholds)
            if not torch.all(thresholds <= self.summary()):
                raise ValueError("thresholds cannot "
                                 "be greater than summary(): got %s vs. %s" %
                                 (thresholds.max(), self.summary()))
            # Start every query at the root and walk down ``_depth`` levels.
            nodes = torch.ones_like(thresholds, dtype=torch.int64)
            residuals = thresholds
            for _ in range(self._depth):
                nodes, residuals = _step(nodes, residuals)
            # Positions below ``capacity`` are still internal nodes (their
            # leaves are one level deeper); descend once more for those only.
            is_internal = nodes < self._capacity
            if int(is_internal.sum()) > 0:
                sel = torch.where(is_internal)[0]
                deeper, _ = _step(nodes[sel], residuals[sel])
                nodes[sel] = deeper
        return convert_device(self._leaf_to_index(nodes))
class MinSegmentTree(SegmentTree):
    """SegmentTree with min operation.

    Every internal node holds ``torch.min`` of its two children, so
    ``summary()`` is the minimum over all leaves.
    """

    def __init__(self,
                 capacity,
                 dtype=torch.float32,
                 device="cpu",
                 name="MinSegmentTree"):
        """
        Args:
            capacity (int): number of leaves.
            dtype (torch.dtype): dtype of the node values.
            device (str): device forwarded to ``SegmentTree``.
            name (str): name of this object.
        """
        super().__init__(
            capacity, op=torch.min, dtype=dtype, device=device, name=name)
class MaxSegmentTree(SegmentTree):
    """SegmentTree with max operation.

    Every internal node holds ``torch.max`` of its two children, so
    ``summary()`` is the maximum over all leaves.
    """

    def __init__(self,
                 capacity,
                 dtype=torch.float32,
                 device="cpu",
                 name="MaxSegmentTree"):
        """
        Args:
            capacity (int): number of leaves.
            dtype (torch.dtype): dtype of the node values.
            device (str): device forwarded to ``SegmentTree``.
            name (str): name of this object.
        """
        super().__init__(
            capacity, op=torch.max, dtype=dtype, device=device, name=name)