# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
import functools
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import alf
from alf.experience_replayers.replay_buffer import ReplayBuffer
class _MethodPerformer(torch.nn.Module):
    """Module wrapper whose ``forward()`` runs one chosen method of another
    module.

    The goal is to make a single TARGET METHOD data distributed. DDP hijacks
    the ``forward()`` of the module it wraps to inject its synchronization
    hooks. By routing this wrapper's ``forward()`` to the target method, and
    letting DDP wrap this wrapper, those hooks effectively apply to the
    target method of the wrapped module.
    """

    def __init__(self, module: torch.nn.Module, perform: Callable[..., Any]):
        """Constructs a _MethodPerformer.

        Args:
            module: the module whose method is delegated to. It is registered
                as a submodule of this wrapper so that its parameters are
                inherited, which is how DDP learns which parameters to cover.
            perform: the target method of ``module``. It is invoked with
                ``module`` bound as its first argument.
        """
        super().__init__()
        # Attribute assignment registers ``module`` as a submodule, so its
        # parameters and buffers become visible through this wrapper.
        self._wrapped_module = module
        self._perform = functools.partial(perform, module)

        # DDP fails if the wrapped module has a state_dict() entry that is
        # not a Tensor. Such entries are listed in
        # ``_ddp_params_and_buffers_to_ignore``, which by contract tells the
        # DDP wrapper to leave them out of parameter/buffer synchronization.
        non_tensor_names = [
            key for key, val in self.state_dict().items()
            if type(val) is not torch.Tensor
        ]

        # Buffers owned by any ReplayBuffer submodule (e.g. when ``module`` is
        # an Algorithm) must not be synchronized across processes either.
        # ALF keeps them out of state_dict(), but they are still reported by
        # named_buffers(). They are therefore matched by object identity.
        replay_buffer_ids = set()
        for child in module.modules():
            if isinstance(child, ReplayBuffer):
                replay_buffer_ids.update(id(buf) for buf in child.buffers())
        replay_buffer_names = [
            key for key, buf in self.named_buffers()
            if id(buf) in replay_buffer_ids
        ]

        self._ddp_params_and_buffers_to_ignore = (
            non_tensor_names + replay_buffer_names)

        # TODO(breakds): In the future when needed, we can do explicit filtering
        # if the wrapped module is an Algorithm. All parameters and buffers that
        # are not within the optimizer can be added to ignore list.

    def forward(self, *args, **kwargs):
        """Runs the target method on the wrapped module with the given
        arguments and returns its result."""
        return self._perform(*args, **kwargs)
def data_distributed(method):
    """Decorator that makes a method of a module data distributed via DDP.

    This provides a simple and transparent way to enable DDP for specific
    pieces of code logic.

    When a method is decorated with ``@data_distributed``, the output tensors
    of the method get gradient synchronization hooks attached to them. Later,
    when those outputs are used in ``backward()`` to compute gradients, the
    hooks are called to synchronize across all processes. As a result, the
    corresponding parameters receive not only the gradients from this
    process, but also the gradients from the other processes. Note that each
    process will be BLOCKED in the ``backward()`` call that involves those
    output tensors until all processes have finished back propagation and
    the gradients have been synchronized.

    Example usage:

    .. code-block:: python

        class A(nn.Module):
            # ...
            @data_distributed
            def compute_something(self, input):
                return self._network1(input), self._network2(input)
            # ...

    In the above code, after applying the decorator, the method
    ``compute_something`` will be made data distributed if the following
    conditions are met:

    1. Multiple processes within the same process group create instances of
       ``A`` and call ``compute_something()`` individually.
    2. All such ``A`` instances have ``self._ddp_activated_rank`` set to the
       correct rank of the GPU device that belongs to them.

    Otherwise the method ``compute_something()`` behaves normally.

    Args:
        method: the method to decorate. Its first argument (``self``) must
            be a ``torch.nn.Module`` instance.
    Returns:
        The decorated method; equivalent to
        ``data_distributed_when(None)(method)``.
    """
    return data_distributed_when(None)(method)
def data_distributed_when(
        cond: Optional[Callable[[torch.nn.Module], bool]] = None):
    """This is ``@data_distributed`` with an extra condition.

    The condition is a function that returns True or False given the wrapped
    module (the ``self`` of the decorated method) as input. If the condition
    evaluates to False, DDP will not be activated and the original method
    will be called directly.

    Args:
        cond: optional predicate on the module. ``None`` means no extra
            condition: DDP is used whenever the module's
            ``_ddp_activated_rank`` is not -1. The predicate is only
            evaluated when DDP would otherwise be active.
    Returns:
        A decorator to apply to a method of a ``torch.nn.Module``.
    """

    def decorator(method):
        @functools.wraps(method)
        def wrapped(*args, **kwargs):
            # The first argument to the method is ``self``, i.e. the instance
            # that the method belongs to. It is the module to wrap.
            module_to_wrap = args[0]
            assert isinstance(module_to_wrap, torch.nn.Module), (
                f'Cannot apply @data_distributed on {type(module_to_wrap)}')

            # A rank of -1 means DDP is not activated for this module. In that
            # case, or when the extra condition is False, just perform the
            # normal method call.
            ddp_rank = getattr(module_to_wrap, '_ddp_activated_rank', -1)
            if ddp_rank == -1 or (cond is not None
                                  and not cond(module_to_wrap)):
                return method(*args, **kwargs)

            # DDP-wrapped _MethodPerformer instances are created lazily and
            # cached in a per-module map called ``_ddp_performer_map``.
            performer_map = getattr(module_to_wrap, '_ddp_performer_map',
                                    None)
            if performer_map is None:
                performer_map = {}
                setattr(module_to_wrap, '_ddp_performer_map', performer_map)

            # Key by the qualified name rather than the bare ``__name__``.
            # Otherwise an overriding method and the base-class method it
            # reaches via ``super()`` would share (and mis-dispatch to) one
            # performer.
            key = method.__qualname__
            performer = performer_map.get(key)
            if performer is None:
                performer = make_ddp_performer(module_to_wrap, method,
                                               ddp_rank)
                performer_map[key] = performer
            return performer(*args[1:], **kwargs)

        return wrapped

    return decorator