Source code for alf.utils.per_process_context
# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.multiprocessing as mp
class PerProcessContext(object):
    """A singleton that maintains the per-process runtime properties.

    It is used mainly in multi-process distributed training mode,
    where properties such as the rank of the process and the total
    number of processes can be accessed via this interface.

    Every call to ``PerProcessContext()`` returns the same instance. The
    setters raise ``AttributeError`` once ``finalize()`` has been called.
    """

    # The single shared instance, created lazily on first construction.
    _instance = None

    def __new__(cls):
        """Construct the singleton instance.

        The first call creates the instance and assigns default values to
        all properties; later calls return that same instance unchanged.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._read_only = False
            instance._ddp_rank = -1
            instance._num_processes = 1
            # Give the queue a default so that reading ``paras_queue``
            # before ``set_paras_queue()`` yields None instead of raising
            # AttributeError, consistent with the other properties.
            instance._paras_queue = None
            cls._instance = instance
        return cls._instance

    def finalize(self) -> None:
        """Lock the context so that it becomes read only."""
        self._read_only = True

    def set_distributed(self, rank: int, num_processes: int) -> None:
        """Set the distributed properties.

        Args:
            rank (int): the ID of the process
            num_processes (int): the total number of processes

        Raises:
            AttributeError: if the context has already been finalized.
        """
        if self._read_only:
            raise AttributeError(
                'Cannot mutate PerProcessContext after it is finalized')
        self._ddp_rank = rank
        self._num_processes = num_processes

    def set_paras_queue(self, paras_queue: mp.Queue) -> None:
        """Set the parameter queue.

        The queue is used for checking the consistency of model parameters
        across different worker processes, if multi-gpu training is used.

        Args:
            paras_queue (mp.Queue): the queue to store in the context

        Raises:
            AttributeError: if the context has already been finalized.
        """
        if self._read_only:
            raise AttributeError(
                'Cannot mutate PerProcessContext after it is finalized')
        self._paras_queue = paras_queue

    @property
    def paras_queue(self) -> mp.Queue:
        """The queue given to ``set_paras_queue()``, or ``None`` if unset."""
        return self._paras_queue

    @property
    def is_distributed(self):
        """Whether a non-negative rank has been set (the default is -1)."""
        return self._ddp_rank >= 0

    @property
    def ddp_rank(self):
        """The rank set via ``set_distributed()``; -1 by default."""
        return self._ddp_rank

    @property
    def num_processes(self):
        """The process count set via ``set_distributed()``; 1 by default."""
        return self._num_processes