# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes for storing data for sampling."""
import functools
from multiprocessing import Event, RLock
import time
import torch
import torch.nn as nn
import alf
from alf.nest import get_nest_batch_size
from alf.tensor_specs import TensorSpec
from alf.nest.utils import convert_device
def atomic(func):
    """Make a class member function atomic by holding ``self._lock``.

    Can only be applied to instance methods whose owning object has a
    ``_lock`` attribute. When ``_lock`` is falsy (e.g. ``None``), the method is
    called directly. Otherwise ``_lock`` is used as a context manager (e.g. a
    ``multiprocessing.RLock``) around the call.

    Args:
        func (callable): the method to be wrapped.
    Returns:
        the wrapped method, carrying ``func``'s name and docstring.
    """

    @functools.wraps(func)
    def atomic_wrapper(self, *args, **kwargs):
        guard = getattr(self, '_lock')
        if not guard:
            # No lock configured: run unsynchronized.
            return func(self, *args, **kwargs)
        with guard:
            return func(self, *args, **kwargs)

    return atomic_wrapper
class RingBuffer(nn.Module):
    """Batched Ring Buffer.

    Multiprocessing safe, optionally via: ``allow_multiprocess`` flag, blocking
    modes to ``enqueue`` and ``dequeue``, a stop event to terminate blocked
    processes, and putting buffer into shared memory.

    This is the underlying implementation of ``ReplayBuffer`` and ``Queue``.

    Different from ``tf_agents.replay_buffers.tf_uniform_replay_buffer``, this
    buffer allows users to specify the environment id when adding batch.
    Thus, multiple actors can store experience in the same buffer.

    Once stop event is set, all blocking ``enqueue`` and ``dequeue`` calls that
    happen afterwards will be skipped, unless the operation already started.

    Terminology: we use ``pos`` as in ``_current_pos`` to refer to the always
    increasing position of an element in the infinitely long buffer, and
    ``idx`` as the actual index of the element in the underlying store
    (``_buffer``). That means ``idx == pos % _max_length`` is always true, and
    one should use ``_buffer[idx]`` to retrieve the stored data.
    """

    def __init__(self,
                 data_spec,
                 num_environments,
                 max_length=1024,
                 device="cpu",
                 allow_multiprocess=False,
                 name="RingBuffer"):
        """
        Args:
            data_spec (nested TensorSpec): spec describing a single item that
                can be stored in this buffer.
            num_environments (int): number of environments or total batch size.
            max_length (int): The maximum number of items that can be stored
                for a single environment.
            device (str): A torch device to place the Variables and ops.
            allow_multiprocess (bool): if ``True``, allows multiple processes
                to write and read the buffer asynchronously.
            name (str): name of the replay buffer.
        """
        super().__init__()
        self._name = name
        self._max_length = max_length
        self._num_envs = num_environments
        self._device = device
        self._allow_multiprocess = allow_multiprocess
        # allows outside to stop enqueue and dequeue processes from waiting
        self._stop = Event()
        if allow_multiprocess:
            self._lock = RLock()  # re-entrant lock
            # notify a finished dequeue event, so blocked enqueues may start
            self._dequeued = Event()
            self._dequeued.set()
            # notify a finished enqueue event, so blocked dequeues may start
            self._enqueued = Event()
            self._enqueued.clear()
        else:
            self._lock = None
            self._dequeued = None
            self._enqueued = None

        buffer_id = [0]

        def _create_buffer(spec_path, tensor_spec):
            # Storage of shape [num_environments, max_length] + spec.shape.
            buf = tensor_spec.zeros((num_environments, max_length))
            if spec_path != '':
                # buffer name cannot contain '.', which is used as the
                # delimiter by ``py_map_structure_with_path`` in the
                # generated path
                spec_name = spec_path.replace('.', '|')
                self.register_buffer(spec_name, buf)
            else:
                self.register_buffer("_buffer%s" % buffer_id[0], buf)
                buffer_id[0] += 1
            return buf

        with alf.device(self._device):
            self.register_buffer(
                "_current_size",
                torch.zeros(num_environments, dtype=torch.int64))
            # Current *ending* positions of data in the buffer without modulo.
            # The next experience will be stored at this position after
            # modulo. These pos always increases. To get the index in the
            # RingBuffer, use ``circular()``, e.g.
            # ``last_idx = self.circular(pos - 1)``.
            self.register_buffer(
                "_current_pos",
                torch.zeros(num_environments, dtype=torch.int64))
            self._buffer = alf.nest.py_map_structure_with_path(
                _create_buffer, data_spec)
            # Views merging the (env, time) dims, so a single flat index
            # ``env_id * max_length + idx`` addresses one slot.
            self._flattened_buffer = alf.nest.map_structure(
                lambda x: x.view(-1, *x.shape[2:]), self._buffer)

        if allow_multiprocess:
            self.share_memory()

    @property
    def device(self):
        """The device where the data is stored in."""
        return self._device

    def circular(self, pos):
        """Mod pos by _max_length to get the actual index in the _buffer."""
        return pos % self._max_length

    def check_convert_env_ids(self, env_ids):
        """Normalize ``env_ids`` into a 1D int64 tensor on the buffer device.

        Args:
            env_ids (Tensor|None): environment ids. ``None`` means all
                environments, i.e. ``arange(num_environments)``.
        Returns:
            Tensor: 1D int64 tensor of environment ids.
        Raises:
            AssertionError: if ``env_ids`` is not 1D.
        """
        with alf.device(self._device):
            if env_ids is None:
                env_ids = torch.arange(self._num_envs)
            else:
                env_ids = env_ids.to(torch.int64)
            env_ids = convert_device(env_ids)
            assert len(env_ids.shape) == 1, (
                "env_ids {}, should be a 1D tensor".format(env_ids.shape))
            return env_ids

    def has_space(self, env_ids):
        """Check free space for one batch of data for env_ids.

        Args:
            env_ids (Tensor): Assumed not ``None``, properly checked by
                ``check_convert_env_ids()``.
        Returns:
            bool: ``True`` if every environment in ``env_ids`` has at least
            one free slot.
        """
        current_size = self._current_size[env_ids]
        max_size = current_size.max()
        return max_size < self._max_length

    def enqueue(self, batch, env_ids=None, blocking=False):
        """Add a batch of items to the buffer.

        Note, when ``blocking == False``, it always succeeds, overwriting
        oldest data if there is no free slot.

        Args:
            batch (Tensor): of shape ``[batch_size] + tensor_spec.shape``
            env_ids (Tensor): If ``None``, ``batch_size`` must be
                ``num_environments``. If not ``None``, its shape should be
                ``[batch_size]``. We assume there are no duplicate ids in
                ``env_id``. ``batch[i]`` is generated by environment
                ``env_ids[i]``.
            blocking (bool): If ``True``, blocks if there is no free slot to
                add data. If ``False``, enqueue can overwrite oldest data.
        Returns:
            True on success, False only in blocking mode when queue is stopped.
        """
        if blocking:
            assert self._allow_multiprocess, (
                "Set allow_multiprocess to enable blocking mode.")
            env_ids = self.check_convert_env_ids(env_ids)
            while not self._stop.is_set():
                with self._lock:
                    if self.has_space(env_ids):
                        self._enqueue(batch, env_ids)
                        return True
                # The wait here is outside the lock, so multiple dequeue and
                # enqueue could theoretically happen before the wait. The
                # wait only acts as a more responsive sleep, and the return
                # value is not used. We anyways need to check has_space after
                # the wait timed out.
                self._dequeued.wait(timeout=0.2)
            return False
        else:
            self._enqueue(batch, env_ids)
            return True

    @atomic
    def _enqueue(self, batch, env_ids=None):
        """Add a batch of items to the buffer (atomic).

        Args:
            batch (Tensor): shape should be
                ``[batch_size] + tensor_spec.shape``.
            env_ids (Tensor): If ``None``, ``batch_size`` must be
                ``num_environments``. If not ``None``, its shape should be
                ``[batch_size]``. We assume there are no duplicate ids in
                ``env_id``. ``batch[i]`` is generated by environment
                ``env_ids[i]``.
        """
        batch_size = alf.nest.get_nest_batch_size(batch)
        with alf.device(self._device):
            batch = convert_device(batch)
            env_ids = self.check_convert_env_ids(env_ids)
            assert batch_size == env_ids.shape[0], (
                "batch and env_ids do not have same length %s vs. %s" %
                (batch_size, env_ids.shape[0]))

            # Make sure that there is no duplicate in `env_id`
            # torch.unique(env_ids, return_counts=True)[1] is the counts for
            # each unique item
            assert torch.unique(
                env_ids, return_counts=True)[1].max() == 1, (
                    "There are duplicated ids in env_ids %s" % env_ids)

            current_pos = self._current_pos[env_ids]
            # Flat index into ``_flattened_buffer`` for each env's next slot.
            indices = env_ids * self._max_length + self.circular(current_pos)
            alf.nest.map_structure(
                lambda buf, bat: buf.__setitem__(indices, bat.detach()),
                self._flattened_buffer, batch)

            self._current_pos[env_ids] = current_pos + 1
            current_size = self._current_size[env_ids]
            self._current_size[env_ids] = torch.clamp(
                current_size + 1, max=self._max_length)
        # set flags if they exist to unblock potential consumers
        if self._enqueued is not None:
            self._enqueued.set()
            self._dequeued.clear()

    def has_data(self, env_ids, n=1):
        """Check ``n`` steps of data available for ``env_ids``.

        Args:
            env_ids (Tensor): Assumed not ``None``, properly checked by
                ``check_convert_env_ids()``.
            n (int): Number of time steps to check.
        Returns:
            bool: ``True`` if every environment in ``env_ids`` holds at least
            ``n`` steps.
        """
        current_size = self._current_size[env_ids]
        min_size = current_size.min()
        return min_size >= n

    def dequeue(self, env_ids=None, n=1, blocking=False):
        """Return earliest ``n`` steps and mark them removed in the buffer.

        Args:
            env_ids (Tensor): If None, ``batch_size`` must be
                num_environments. If not None, dequeue from these
                environments. We assume there is no duplicate ids in
                ``env_id``. ``result[i]`` will be from environment
                ``env_ids[i]``.
            n (int): Number of steps to dequeue.
            blocking (bool): If ``True``, blocks if there is not enough data
                to dequeue.
        Returns:
            nested Tensors or None when blocking dequeue gets terminated by
            stop event. The shape of the Tensors is ``[batch_size, n, ...]``.
        Raises:
            AssertionError: when not enough data is present, in non-blocking
                mode.
        """
        assert n <= self._max_length
        if blocking:
            assert self._allow_multiprocess, (
                "Set allow_multiprocess to enable blocking mode.")
            env_ids = self.check_convert_env_ids(env_ids)
            while not self._stop.is_set():
                with self._lock:
                    if self.has_data(env_ids, n):
                        return self._dequeue(env_ids=env_ids, n=n)
                # The wait here is outside the lock, so multiple dequeue and
                # enqueue could theoretically happen before the wait. The
                # wait only acts as a more responsive sleep, and the return
                # value is not used. We anyways need to check has_data after
                # the wait timed out.
                self._enqueued.wait(timeout=0.2)
            return None
        else:
            return self._dequeue(env_ids=env_ids, n=n)

    @atomic
    def _dequeue(self, env_ids=None, n=1):
        """Return earliest ``n`` steps and mark them removed in the buffer.

        Args:
            env_ids (Tensor): If None, ``batch_size`` must be
                num_environments. If not None, dequeue from these
                environments. We assume there is no duplicate ids in
                ``env_id``. ``result[i]`` will be from environment
                env_ids[i].
            n (int): Number of steps to dequeue.
        Returns:
            nested Tensors of shape ``[batch_size, n, ...]``.
        Raises:
            AssertionError: when not enough data is present.
        """
        with alf.device(self._device):
            env_ids = self.check_convert_env_ids(env_ids)
            current_size = self._current_size[env_ids]
            min_size = current_size.min()
            assert min_size >= n, (
                "Not all environments have enough data. The smallest data "
                "size is: %s Try storing more data before calling dequeue" %
                min_size)
            batch_size = env_ids.shape[0]
            # Earliest stored pos for each env; modulo applied below.
            pos = self._current_pos[env_ids] - current_size
            b_indices = env_ids.reshape(batch_size, 1).expand(-1, n)
            t_range = torch.arange(n).reshape(1, -1)
            t_indices = self.circular(pos.reshape(batch_size, 1) + t_range)
            result = alf.nest.map_structure(
                lambda b: b[(b_indices, t_indices)], self._buffer)
            self._current_size[env_ids] = current_size - n
        # set flags if they exist to unblock potential producers
        if self._dequeued is not None:
            self._dequeued.set()
            self._enqueued.clear()
        return convert_device(result)

    @atomic
    def remove_up_to(self, n, env_ids=None):
        """Mark as removed earliest up to ``n`` steps.

        Args:
            n (int): max number of steps to mark removed from buffer, applied
                independently to each environment in ``env_ids``. An
                environment holding fewer than ``n`` steps becomes empty.
                Negative values remove nothing.
            env_ids (Tensor): optional 1D tensor of environment ids. ``None``
                means all environments.
        """
        with alf.device(self._device):
            env_ids = self.check_convert_env_ids(env_ids)
            current_size = self._current_size[env_ids]
            # Clamp against the sizes of the *selected* environments so the
            # removed amount lines up element-wise with ``env_ids``. Clamping
            # against all environments fails (shape mismatch) when
            # ``env_ids`` is a subset and misaligns when it is a permutation.
            max_removed = torch.as_tensor(
                n, dtype=torch.int64, device=current_size.device).clamp(min=0)
            num_removed = torch.min(current_size, max_removed)
            self._current_size[env_ids] = current_size - num_removed
        # Removing data frees slots, so wake up blocked ``enqueue`` calls.
        if self._dequeued is not None:
            self._dequeued.set()

    @atomic
    def clear(self, env_ids=None):
        """Clear the buffer.

        Args:
            env_ids (Tensor): optional list of environment ids to clear
        """
        with alf.device(self._device):
            env_ids = self.check_convert_env_ids(env_ids)
            self._current_size.scatter_(0, env_ids, 0)
            self._current_pos.scatter_(0, env_ids, 0)
        if self._dequeued is not None:
            self._dequeued.set()
            self._enqueued.clear()

    def stop(self):
        """Stop waiting processes from being blocked.

        Only checked in blocking mode of dequeue and enqueue.

        All blocking enqueue and dequeue calls that happen afterwards will
        be skipped (return ``None`` for dequeue or ``False`` for enqueue),
        unless the operation already started.
        """
        self._stop.set()

    def revive(self):
        """Clears the stop Event so blocking mode will start working again.

        Only checked in blocking mode of dequeue and enqueue.
        """
        self._stop.clear()

    @property
    def num_environments(self):
        """Number of environments (rows) this buffer stores."""
        return self._num_envs

    def get_earliest_position(self, env_ids):
        """The earliest position that is still in the replay buffer.

        Args:
            env_ids (Tensor): int64 Tensor of environment ids
        Returns:
            Tensor with the same shape as ``env_ids``, whose each entry is the
            earliest position that is still in the replay buffer for
            corresponding environment.
        """
        return self._current_pos[env_ids] - self._current_size[env_ids]

    def get_current_position(self):
        """Get the current position for each environment.

        Note: this is the internal ``_current_pos`` tensor itself, not a copy.

        Returns:
            Tensor: with shape [num_environments].
        """
        return self._current_pos
class DataBuffer(RingBuffer):
    """Fixed-capacity circular store of single items with random sampling.

    Internally this is a ``RingBuffer`` with exactly one environment, and
    ``add_batch`` lays the items of a batch out along that environment's time
    axis. Items coming from different sources are therefore interleaved
    arbitrarily, i.e. temporal structure is not preserved.

    Not multiprocessing safe.
    """

    def __init__(self,
                 data_spec: TensorSpec,
                 capacity,
                 device='cpu',
                 name="DataBuffer"):
        """
        Args:
            data_spec (nested TensorSpec): spec for the data item (without
                batch dimension) to be stored.
            capacity (int): maximum number of items kept; once full, the
                oldest items are overwritten.
            device (str): which device to store the data on.
            name (str): name of the buffer.
        """
        super().__init__(
            data_spec=data_spec,
            num_environments=1,
            max_length=capacity,
            device=device,
            allow_multiprocess=False,
            name=name)
        self._capacity = torch.as_tensor(
            self._max_length, dtype=torch.int64, device=device)
        # Per-field views of environment 0, i.e. shaped ``[capacity, ...]``.
        self._derived_buffer = alf.nest.map_structure(
            lambda storage: storage[0], self._buffer)

    def add_batch(self, batch):
        r"""Append every item of ``batch`` to the buffer.

        Unlike ``RingBuffer.enqueue`` (one step per environment), this writes
        all ``batch_size`` items consecutively along the single environment's
        time axis. If ``batch_size > capacity`` only the last ``capacity``
        items are kept.

        Args:
            batch (Tensor): of shape ``[batch_size] + tensor_spec.shape``
        """
        num_items = alf.nest.get_nest_batch_size(batch)
        with alf.device(self._device):
            batch = convert_device(batch)
            n = torch.clamp(self._capacity, max=num_items)
            # 0-d views into ``_current_pos`` / ``_current_size``; updated in
            # place below via ``copy_``.
            pos = self.current_pos
            size = self.current_size
            slots = self.circular(torch.arange(pos, pos + n))

            def _write(storage, data):
                storage[slots] = data[-n:].detach()

            alf.nest.map_structure(_write, self._derived_buffer, batch)
            pos.copy_(pos + n)
            size.copy_(torch.min(size + n, self._capacity))

    def get_batch(self, batch_size):
        r"""Sample ``batch_size`` stored items uniformly, with replacement.

        Args:
            batch_size (int): number of samples to draw.
        Returns:
            Tensor of shape ``[batch_size] + tensor_spec.shape``
        """
        with alf.device(self._device):
            sample_idx = torch.randint(
                low=0,
                high=self.current_size,
                size=(batch_size, ),
                dtype=torch.int64)
            sampled = self.get_batch_by_indices(sample_idx)
        return convert_device(sampled)

    def get_batch_by_indices(self, indices):
        r"""Fetch items by their age-ordered index.

        ``index == 0`` is the earliest item still held in the DataBuffer.

        Args:
            indices (Tensor): indices of the samples
        Returns:
            Tensor:
                Tensor of shape ``[batch_size] + tensor_spec.shape``, where
                ``batch_size`` is ``indices.shape[0]``
        """
        with alf.device(self._device):
            indices = convert_device(indices)
            # Shift by the earliest stored position, then wrap into storage.
            slots = self.circular(indices + self.current_pos -
                                  self.current_size)
            picked = alf.nest.map_structure(
                lambda storage: storage[slots], self._derived_buffer)
        return convert_device(picked)

    def is_full(self):
        """Whether ``current_size`` has reached capacity (as a numpy bool)."""
        full = self.current_size == self._capacity
        return full.cpu().numpy()

    @property
    def current_size(self):
        """Number of valid items (0-d view into ``_current_size``)."""
        return self._current_size[0]

    @property
    def current_pos(self):
        """Total items ever written (0-d view into ``_current_pos``)."""
        return self._current_pos[0]

    def get_all(self):
        """Return the whole storage, ``capacity`` long per field.

        This includes slots that have not been written yet and is in storage
        order, not chronological order.
        """
        everything = alf.nest.map_structure(lambda storage: storage,
                                            self._derived_buffer)
        return convert_device(everything)