# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various functions used by different alf modules."""
from absl import flags
from absl import logging
import contextlib
import copy
from fasteners.process_lock import InterProcessLock
from functools import wraps
import gin
import glob
import math
import numpy as np
import os
import pathlib
import pprint
import random
import shutil
import socket
import subprocess
import sys
import time
import torch
import torch.distributions as td
import torch.nn as nn
import traceback
import types
from typing import Callable, List, Dict
import alf
from alf.algorithms.config import TrainerConfig
import alf.nest as nest
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils.spec_utils import zeros_from_spec as zero_tensor_from_nested_spec
from alf.utils.per_process_context import PerProcessContext
from . import dist_utils, gin_utils
def add_method(cls):
    """Return a decorator that installs the decorated function on ``cls``.

    The function is attached to ``cls`` under its own ``__name__`` through a
    thin ``functools.wraps`` wrapper. The function itself is returned
    unchanged, so it also stays usable as a plain module-level function.

    Example usage:

    .. code-block:: python

        class A:
            pass

        @add_method(A)
        def new_method(self):
            print('new method added')

        # now new_method() is added to class A and is ready to be used
        a = A()
        a.new_method()

    Args:
        cls (type): the class that receives the new method.
    Returns:
        Callable: a decorator that registers its argument on ``cls``.
    """

    def _register(fn):
        @wraps(fn)
        def _bound(*args, **kwargs):
            return fn(*args, **kwargs)

        setattr(cls, fn.__name__, _bound)
        return fn

    return _register
def as_list(x):
    """Normalize ``x`` into a list.

    Conversion rules:

    .. code-block:: python

        None  => []
        list  => x          # returned as is (same object)
        tuple => list(x)
        other => [x]

    Args:
        x (any): the object to be converted
    Returns:
        list:
    """
    if x is None:
        result = []
    elif isinstance(x, list):
        result = x
    elif isinstance(x, tuple):
        result = list(x)
    else:
        result = [x]
    return result
def tuplify2d(x):
    """Turn ``x`` into a pair.

    Conversion rules:

    .. code-block:: python

        x => x       if isinstance(x, tuple) and len(x) == 2
        x => (x, x)  if not isinstance(x, tuple)

    A tuple whose length is not 2 triggers an ``AssertionError``.

    Args:
        x (any): the object to be converted
    Returns:
        tuple:
    """
    if not isinstance(x, tuple):
        return (x, x)
    assert len(x) == 2
    return x
class Periodically(nn.Module):
    """Calls ``body`` on every ``period``-th invocation of ``forward``."""

    def __init__(self, body, period, name='periodically'):
        """Periodically performs the operation defined in body.

        Args:
            body (Callable): callable to be performed every time
                an internal counter is divisible by the period.
            period (int|None): inverse frequency with which to perform the
                operation. If None, ``body`` is never performed.
            name (str): name of the object.
        Raises:
            TypeError: if body is not a callable.
        """
        super().__init__()
        if not callable(body):
            raise TypeError('body must be callable.')
        self._body = body
        self._period = period
        self._counter = 0
        self._name = name

    def forward(self):
        """Advance the internal counter and run ``body`` when it is due."""
        self._counter += 1
        # Check for ``None`` first: ``counter % None`` raises TypeError, which
        # previously made the "period is None" branch unreachable.
        if self._period is None:
            return
        if self._counter % self._period == 0:
            self._body()
@alf.configurable
class TargetUpdater(nn.Module):
    r"""Performs a soft update of the target model parameters.

    For each weight :math:`w_s` in the model, and its corresponding
    weight :math:`w_t` in the target_model, a soft update is:

    .. math::

        w_t = (1 - \tau) * w_t + \tau * w_s.

    Note: we only perform soft updates for parameters and always copy buffers.
    Parameters/buffers that are the very same object in the source and the
    target are skipped; a warning is logged for them at construction time.

    Args:
        models (Network | list[Network] | Parameter | list[Parameter] ): the
            current model or parameter.
        target_models (Network | list[Network] | Parameter | list[Parameter]):
            the model or parameter to be updated. Must have the same number of
            entries as ``models``, with matching parameter/buffer names.
        tau (float): A float scalar in :math:`[0, 1]`. Default :math:`\tau=1.0`
            means hard update.
        period (int): Step interval at which the target model is updated.
        init_copy (bool): If True, also copy ``models`` to ``target_models`` in the
            beginning.
        delayed_update (bool): if True, ``target_models`` is updated using
            recent_models every ``period`` steps. If ``tau`` is 1, the
            recent_models is ``models`` ``period`` steps before. If ``tau`` is
            not 1, recent_models is an exponential moving average of ``models``
            with rate ``tau``. The use of delayed_update may help to improve
            the stability of TD learning when a small ``period`` is used.
    """

    def __init__(self,
                 models,
                 target_models,
                 tau=1.0,
                 period=1,
                 init_copy=True,
                 delayed_update: bool = False):
        super().__init__()
        models = as_list(models)
        target_models = as_list(target_models)
        assert len(models) == len(target_models), (
            "The length of models and "
            "target_models are different: %s vs. %s" % (len(models),
                                                        len(target_models)))
        for model, target_model in zip(models, target_models):
            self._validate(model, target_model)
        # Stored as plain Python lists. nn.Module only registers attributes
        # that are Modules/Parameters, so these are not submodules of the
        # updater.
        self._models = models
        self._target_models = target_models
        if delayed_update:
            # Intermediate copies that sit between ``models`` and
            # ``target_models``; see ``forward``.
            self._recent_models = list(
                map(self._make_copy, models, target_models))
        self._tau = tau
        self._period = period
        self._delayed_update = delayed_update
        self._counter = 0
        if init_copy:
            for model, target_model in zip(models, target_models):
                self._copy_model_or_parameter(model, target_model)

    def _make_copy(self, s, t):
        """Create the "recent" copy of source ``s`` used by delayed updates.

        Entries that ``s`` shares with ``t`` (same object) are reused as they
        are. Every other parameter/buffer is deep-copied. For a module, the
        copies are collected in an ``nn.ParameterList``, with buffers
        registered as ``b0``, ``b1``, ... in the iteration order of
        ``s.parameters()`` / ``s.buffers()``.

        Args:
            s (Network | Parameter): source model or parameter.
            t (Network | Parameter): corresponding target model or parameter.
        Returns:
            Parameter | nn.ParameterList: the copy.
        """
        if isinstance(s, nn.Parameter):
            if id(s) == id(t):
                return s
            else:
                return copy.deepcopy(s)
        else:
            module = nn.ParameterList()
            for ws, wt in zip(s.parameters(), t.parameters()):
                if id(ws) == id(wt):
                    module.append(ws)
                else:
                    module.append(copy.deepcopy(ws))
            for i, (ws, wt) in enumerate(zip(s.buffers(), t.buffers())):
                if id(ws) == id(wt):
                    module.register_buffer("b%s" % i, ws)
                else:
                    module.register_buffer("b%s" % i, copy.deepcopy(ws))
            return module

    def _validate(self, s, t):
        """Check that ``s`` and ``t`` are structurally compatible.

        For modules, asserts that the named parameters and the named buffers
        have the same count and pair up by name in iteration order. Logs a
        warning (via the module-level ``warning``) for every pair that is the
        same object, because such pairs are never updated.

        Args:
            s (Network | Parameter): source model or parameter.
            t (Network | Parameter): target model or parameter.
        Raises:
            AssertionError: on a count or name mismatch.
        """

        def _error_msg(ns, nt):
            return ("The corresponding parameter/buffer of the source model "
                    "and the target model have different name: %s vs %s" %
                    (ns, nt))

        def _warning_msg(n):
            warning(
                "The corresponding parameter/buffer %s of the source model "
                "and the target model are same object. They will be ignored by "
                "TargetUpdater." % n)

        if isinstance(s, nn.Parameter):
            if id(s) == id(t):
                warning("target and the source parameter are same object. It "
                        "will be ignored by the TargetUpdater.")
        else:
            sparams = list(s.named_parameters())
            tparams = list(t.named_parameters())
            assert len(sparams) == len(tparams), (
                "The source model and the "
                "target models have different number of parameters: %s vs. %s"
                % (len(sparams), len(tparams)))
            for (ns, ws), (nt, wt) in zip(sparams, tparams):
                assert ns == nt, _error_msg(ns, nt)
                if id(ws) == id(wt):
                    _warning_msg(ns)
            sbuffers = list(s.named_buffers())
            tbuffers = list(t.named_buffers())
            assert len(sbuffers) == len(tbuffers), (
                "The source model and the "
                "target models have different number of buffers: %s vs. %s" %
                (len(sbuffers), len(tbuffers)))
            for (ns, ws), (nt, wt) in zip(sbuffers, tbuffers):
                assert ns == nt, _error_msg(ns, nt)
                if id(ws) == id(wt):
                    _warning_msg(ns)

    def _copy_model_or_parameter(self, s, t):
        """Hard-copy ``s`` into ``t`` in place, skipping shared objects.

        Parameters are written through ``.data``, presumably so that the
        in-place copy is not tracked by autograd. Buffers are copied directly.
        Pairs are matched positionally by iteration order.
        """
        if isinstance(s, nn.Parameter):
            if id(s) != id(t):
                t.data.copy_(s)
        else:
            for ws, wt in zip(s.parameters(), t.parameters()):
                if id(ws) != id(wt):
                    wt.data.copy_(ws)
            for ws, wt in zip(s.buffers(), t.buffers()):
                if id(ws) != id(wt):
                    wt.copy_(ws)

    def _lerp_model_or_parameter(self, s, t):
        """Soft-update ``t`` towards ``s`` with rate ``self._tau``, in place.

        For parameters this is ``t <- t + tau * (s - t)`` (``Tensor.lerp_``),
        which equals the formula in the class docstring. Buffers are always
        hard-copied. Shared objects are skipped.
        """
        if isinstance(s, nn.Parameter):
            if id(s) != id(t):
                t.data.lerp_(s, self._tau)
        else:
            for ws, wt in zip(s.parameters(), t.parameters()):
                if id(ws) != id(wt):
                    wt.data.lerp_(ws, self._tau)
            for ws, wt in zip(s.buffers(), t.buffers()):
                if id(ws) != id(wt):
                    wt.copy_(ws)

    def forward(self):
        """Advance the step counter and update the targets when due.

        Without delayed update, on every ``period``-th call the targets are
        either soft-updated from ``models`` (``tau != 1``) or hard-copied from
        them (``tau == 1``).

        With delayed update, on every ``period``-th call the targets are
        hard-copied from the recent copies *before* those copies are refreshed.
        The recent copies are then refreshed from ``models``: by a soft update
        on every call when ``tau != 1``, or by a hard copy on every
        ``period``-th call when ``tau == 1``.
        """
        self._counter += 1
        if self._counter % self._period == 0:
            if self._delayed_update:
                for model, target_model in zip(self._recent_models,
                                               self._target_models):
                    self._copy_model_or_parameter(model, target_model)
            elif self._tau != 1.0:
                for model, target_model in zip(self._models,
                                               self._target_models):
                    self._lerp_model_or_parameter(model, target_model)
            else:
                for model, target_model in zip(self._models,
                                               self._target_models):
                    self._copy_model_or_parameter(model, target_model)
        # Refresh the recent copies after the targets have consumed them, so
        # the targets lag one period behind.
        if self._delayed_update:
            if self._tau != 1.0:
                for model, target_model in zip(self._models,
                                               self._recent_models):
                    self._lerp_model_or_parameter(model, target_model)
            elif self._counter % self._period == 0:
                for model, target_model in zip(self._models,
                                               self._recent_models):
                    self._copy_model_or_parameter(model, target_model)
def expand_dims_as(x, y, end=True):
    """Pad ``x`` with size-1 dimensions so that it broadcasts against ``y``.

    Args:
        x (Tensor): source tensor
        y (Tensor): target tensor. Only its shape is consulted.
        end (bool): If True, the new dimensions are appended after the
            existing ones; otherwise they are prepended. The existing
            dimensions of ``x`` must match the corresponding part of
            ``y.shape``.
    Returns:
        ``x`` itself when the ranks already agree, otherwise a reshaped view
        of ``x`` with the extra singular dimensions.
    """
    assert x.ndim <= y.ndim
    num_extra = y.ndim - x.ndim
    if num_extra == 0:
        return x
    ones = [1] * num_extra
    if not end:
        assert x.shape == y.shape[num_extra:]
        return x.reshape(*ones, *x.shape)
    assert x.shape == y.shape[:x.ndim]
    return x.reshape(*x.shape, *ones)
def reset_state_if_necessary(state, initial_state, reset_mask):
    """Replace the entries of ``state`` selected by ``reset_mask`` with
    the corresponding entries of ``initial_state``.

    When no entry is selected, ``state`` is returned untouched.

    Args:
        state (nested Tensor): the current batched states
        initial_state (nested Tensor): batched initial states
        reset_mask (nested Tensor): with ``shape=(batch_size,), dtype=torch.bool``
    Returns:
        nested Tensor
    """
    if not torch.any(reset_mask):
        return state

    def _select(init, cur):
        # Broadcast the per-sample mask over the trailing dims of each leaf.
        mask = expand_dims_as(reset_mask, init)
        return torch.where(mask, init.to(cur.dtype), cur)

    return alf.nest.map_structure(_select, initial_state, state)
def run_under_record_context(func,
                             summary_dir,
                             summary_interval,
                             flush_secs,
                             summarize_first_interval=True,
                             summary_max_queue=10):
    """Run ``func`` under summary record context.

    Args:
        func (Callable): the function to be executed.
        summary_dir (str): directory to store summary. A directory starting with
            ``~/`` will be expanded to ``$HOME/``.
        summary_interval (int): how often to generate summary based on the
            global counter
        flush_secs (int): flush summary to disk every so many seconds
        summarize_first_interval (bool): whether to summarize every step of
            the first interval (default True). It might be better to turn
            this off for an easier post-processing of the curve.
        summary_max_queue (int): the largest number of summaries to keep in a queue;
            will flush once the queue gets bigger than this. Defaults to 10.
    """
    # Disable summary if in distributed mode and the running process isn't the
    # master process (i.e. rank = 0)
    if PerProcessContext().ddp_rank > 0:
        func()
        return
    summary_dir = os.path.expanduser(summary_dir)
    summary_writer = alf.summary.create_summary_writer(
        summary_dir, flush_secs=flush_secs, max_queue=summary_max_queue)
    global_step = alf.summary.get_global_counter()

    def _cond():
        # We always write summary in the initial `summary_interval` steps
        # because there might be important changes at the beginning.
        return (alf.summary.is_summary_enabled() and
                ((global_step < summary_interval and summarize_first_interval)
                 or global_step % summary_interval == 0))

    try:
        with alf.summary.push_summary_writer(summary_writer):
            with alf.summary.record_if(_cond):
                func()
    finally:
        # Close the writer even if ``func`` raises. Otherwise its resources
        # leak (and, presumably, queued summaries are never flushed).
        summary_writer.close()
def _markdownify_gin_config_str(string, description=''):
"""Convert an gin config string to markdown format.
Args:
string (str): the string from ``gin.operative_config_str()``.
description (str): Optional long-form description for this config str.
Returns:
string: the markdown version of the config string.
"""
# This function is from gin.tf.utils.GinConfigSaverHook
# TODO: Total hack below. Implement more principled formatting.
def _process(line):
"""Convert a single line to markdown format."""
if not line.startswith('#'):
return ' ' + line
line = line[2:]
if line.startswith('===='):
return ''
if line.startswith('None'):
return ' # None.'
if line.endswith(':'):
return '#### ' + line
return line
output_lines = []
if description:
output_lines.append(" # %s\n" % description)
for line in string.splitlines():
procd_line = _process(line)
if procd_line is not None:
output_lines.append(procd_line)
return '\n'.join(output_lines)
def get_gin_confg_strs():
    """Return the operative and inoperative gin configs, markdown-formatted.

    The operative configuration holds every parameter value used by a
    configurable function that was actually called during this run. The
    inoperative configuration holds values that were configured but never
    consumed. See ``gin.operative_config_str()`` and
    ``gin_utils.inoperative_config_str`` for how each raw string is produced.

    Returns:
        tuple:
        - md_operative_config_str (str): a markdown-formatted operative str
        - md_inoperative_config_str (str): a markdown-formatted inoperative
          str, or the raw (falsy) value when there is nothing inoperative
    """
    operative = _markdownify_gin_config_str(
        gin.operative_config_str(),
        'All parameter values used by configurable functions that are actually called'
    )
    inoperative = gin_utils.inoperative_config_str()
    if inoperative:
        inoperative = _markdownify_gin_config_str(
            inoperative,
            "All parameter values configured but not used by program. The configured "
            "functions are either not called or called with explicit parameter values "
            "overriding the config.")
    return operative, inoperative
def summarize_gin_config():
    """Write the operative (and, if any, inoperative) gin config as text
    summaries.
    """
    operative, inoperative = get_gin_confg_strs()
    alf.summary.text('gin/operative_config', operative)
    if not inoperative:
        return
    alf.summary.text('gin/inoperative_config', inoperative)
def copy_gin_configs(root_dir, gin_files):
    """Copy gin config files into ``root_dir``.

    Each file keeps its base name. ``root_dir`` has ``~`` expanded and is
    created if it does not exist.

    Args:
        root_dir (str): directory path
        gin_files (None|list[str]): list of file paths. ``None`` copies
            nothing.
    """
    root_dir = os.path.expanduser(root_dir)
    os.makedirs(root_dir, exist_ok=True)
    # ``None`` is an accepted value, but iterating it raised TypeError.
    if gin_files is None:
        return
    for f in gin_files:
        shutil.copyfile(f, os.path.join(root_dir, os.path.basename(f)))
def get_gin_file():
    """Get the gin configuration file(s).

    If ``FLAGS.gin_file`` is set, it is returned as is. Otherwise all
    ``*.gin`` files found directly under ``FLAGS.root_dir`` are returned. If
    no ``gin_file`` flag is defined at all, ``''`` is returned.

    Returns:
        the gin file(s): the value of ``FLAGS.gin_file``, a list of paths
        found under ``FLAGS.root_dir``, or ``''``.
    Raises:
        AssertionError: if ``FLAGS.gin_file`` is unset and no ``*.gin`` file
            exists under ``FLAGS.root_dir``.
    """
    if not hasattr(flags.FLAGS, "gin_file"):
        return ''
    gin_file = flags.FLAGS.gin_file
    if gin_file is None:
        root_dir = os.path.expanduser(flags.FLAGS.root_dir)
        gin_file = glob.glob(os.path.join(root_dir, "*.gin"))
        # Tell the user how to fix the problem; the previous message was
        # cut off after "Please provide".
        assert gin_file, (
            "No gin files are found! Please provide --gin_file or put *.gin "
            "files under %s" % root_dir)
    return gin_file
# File name of the Python-based ALF config. ``get_conf_file`` looks for it under
# ``root_dir``, and ``write_config`` saves the config (via ``alf.save_config``)
# under this name.
ALF_CONFIG_FILE = 'alf_config.py'
def get_conf_file(root_dir=None):
    """Locate the configuration file.

    Resolution order:

    1. ``None`` if neither a ``conf`` nor a ``gin_file`` flag is defined.
    2. ``FLAGS.conf`` if it is set, else ``FLAGS.gin_file`` if it is set.
    3. ``<root_dir>/alf_config.py`` if it exists.
    4. The single ``*.gin`` file in ``root_dir`` (assertion error if there are
       several), or ``None`` if there is none.

    Args:
        root_dir (str): directory to search. When None, ``FLAGS.root_dir``
            (with ``~`` expanded) is used.
    Returns:
        str: the name of the conf file. None if there is no conf file
    """
    if not any(hasattr(flags.FLAGS, name) for name in ("conf", "gin_file")):
        return None
    for flag_name in ('conf', 'gin_file'):
        explicit = getattr(flags.FLAGS, flag_name, None)
        if explicit is not None:
            return explicit
    if root_dir is None:
        root_dir = os.path.expanduser(flags.FLAGS.root_dir)
    alf_conf = os.path.join(root_dir, ALF_CONFIG_FILE)
    if os.path.exists(alf_conf):
        return alf_conf
    candidates = glob.glob(os.path.join(root_dir, "*.gin"))
    if not candidates:
        return None
    assert len(
        candidates) == 1, "Multiple *.gin files are found in %s" % root_dir
    return candidates[0]
def parse_conf_file(conf_file):
    """Parse the configuration stored in ``conf_file``.

    ``.gin`` files are parsed together with ``FLAGS.gin_param``. When the
    configured ``TrainerConfig.ml_type`` is ``'rl'``, the global environment
    is then created via ``alf.get_env()`` (which also initializes the random
    seed). Any other file is parsed by ``alf.parse_config`` together with
    ``FLAGS.conf_param``.

    Args:
        conf_file (str): the full path to the config file
    """
    if not conf_file.endswith(".gin"):
        alf.parse_config(conf_file, getattr(flags.FLAGS, 'conf_param', None))
        return
    gin.parse_config_files_and_bindings(
        [conf_file], getattr(flags.FLAGS, 'gin_param', None))
    if alf.get_config_value('TrainerConfig.ml_type') == 'rl':
        # Create the global environment and initialize random seed
        alf.get_env()
def get_epsilon_greedy(config: TrainerConfig):
    """Return the epsilon-greedy value.

    Args:
        config: the trainer config to read from. When None, the value is
            looked up from the global ``TrainerConfig.epsilon_greedy`` config.
    Returns:
        the ``epsilon_greedy`` setting.
    """
    if config is None:
        return alf.get_config_value('TrainerConfig.epsilon_greedy')
    return config.epsilon_greedy
def summarize_config():
    """Write config to TensorBoard.

    Falls back to ``summarize_gin_config`` when there is no conf file or the
    conf file is a ``.gin`` file.
    """

    def _format(configs):
        # Pretty-print as a dict and indent every line by one space.
        text = os.linesep + pprint.pformat(dict(configs))
        return " ".join(text.splitlines(keepends=True))

    conf_file = get_conf_file()
    if conf_file is None or conf_file.endswith('.gin'):
        return summarize_gin_config()
    operative = alf.get_operative_configs()
    inoperative = alf.get_inoperative_configs()
    alf.summary.text('config/operative_config', _format(operative))
    if inoperative:
        alf.summary.text('config/inoperative_config', _format(inoperative))
def read_conf_file(root_dir: str) -> str:
    """Read the content of the conf file.

    The conf file is resolved by ``get_conf_file``. A file given on the
    command line takes precedence; otherwise ``root_dir`` is searched.

    Args:
        root_dir: alf log directory path (``~`` is expanded).
    Returns:
        the content of the conf file as a str. ``None`` if conf file is not
        specified through commandline and cannot be found in root_dir
    """
    if root_dir is not None:
        root_dir = os.path.expanduser(root_dir)
    # Forward ``root_dir``: it was previously ignored, so the lookup always
    # fell back to ``FLAGS.root_dir`` regardless of the argument.
    conf_file = get_conf_file(root_dir)
    if conf_file is None:
        return None
    with open(conf_file, 'r') as f:
        return f.read()
def write_config(root_dir: str):
    """Write the current config to a file under directory ``root_dir``.

    Gin-based setups (no conf file, or a ``.gin`` one) are dumped to
    ``configured.gin`` via ``write_gin_configs``. Otherwise the ALF config is
    saved as ``alf_config.py``. Configs from FLAGS.conf_param are also
    recorded.

    Args:
        root_dir: directory path
    """
    conf_file = get_conf_file()
    use_gin = conf_file is None or conf_file.endswith('.gin')
    if use_gin:
        return write_gin_configs(root_dir, 'configured.gin')
    target = os.path.join(os.path.expanduser(root_dir), ALF_CONFIG_FILE)
    alf.save_config(target)
def get_initial_policy_state(batch_size, policy_state_spec):
    """Build zero-valued initial policy states.

    Args:
        batch_size (int): number of policy states created
        policy_state_spec (nested structure): each item is a tensor spec for
            a state
    Returns:
        state (nested structure): each item is a zero tensor whose first dim
        is ``batch_size`` and whose remaining dims follow the matching spec in
        ``policy_state_spec``.
    """
    initial_state = zero_tensor_from_nested_spec(policy_state_spec, batch_size)
    return initial_state
def get_initial_time_step(env, first_env_id=0):
    """Return the environment's current time step with offset env ids.

    Args:
        env (AlfEnvironment): environment providing ``current_time_step()``.
        first_env_id (int): the environment ID for the first sample in this
            batch; it is added to the time step's ``env_id``.
    Returns:
        TimeStep: the current time step with ``env_id`` shifted by
        ``first_env_id``.
    """
    current = env.current_time_step()
    shifted_ids = current.env_id + first_env_id
    return current._replace(env_id=shifted_ids)
# Module-level environment registered through ``set_global_env``. The spec
# getters in this module read from it and assert that it has been set.
_env = None


def set_global_env(env):
    """Set global env.

    Args:
        env (AlfEnvironment): the environment to store in the module-level
            ``_env`` (it is what ``get_env()`` returns).
    """
    global _env
    _env = env
@alf.configurable
def get_raw_observation_spec(field=None):
    """Return the observation spec of the global environment, untransformed.

    Args:
        field (str): optional dotted path such as ``"A.B.C"``. Each component
            indexes one level deeper into the nested spec.
    Returns:
        nested TensorSpec: a spec that describes the observation.
    """
    assert _env, "set a global env by `set_global_env` before using the function"
    spec = _env.observation_spec()
    path = field.split('.') if field else []
    for key in path:
        spec = spec[key]
    return spec
# Observation spec after the data transformers have been applied. It is read by
# ``get_observation_spec``, which asserts that it has been assigned first.
_transformed_observation_spec = None
@alf.configurable
def get_observation_spec(field=None):
    """Return the observation spec as produced by the data transformers.

    The data transformers are specified by ``TrainerConfig.data_transformer_ctor``.

    Args:
        field (str): optional dotted path such as ``"A.B.C"`` selecting a
            nested sub-spec.
    Returns:
        nested TensorSpec: a spec that describes the observation.
    """
    assert _transformed_observation_spec is not None, (
        "This function should be "
        "called after the global variable _transformed_observation_spec is set"
    )
    spec = _transformed_observation_spec
    for key in (field.split('.') if field else ()):
        spec = spec[key]
    return spec
@alf.configurable
def get_states_shape():
    """Return the shape of the agent's internal states in the global env.

    Returns:
        ``0`` when the observation spec is not a dict containing ``'states'``.
        Otherwise the ``shape`` of that entry. No error is raised, so callers
        can use this to check whether ``env`` provides states input.
    """
    assert _env, "set a global env by `set_global_env` before using the function"
    obs_spec = _env.observation_spec()
    if isinstance(obs_spec, dict) and 'states' in obs_spec:
        return obs_spec['states'].shape
    return 0
@alf.configurable
def get_action_spec():
    """Return the action spec of the global environment.

    Returns:
        nested TensorSpec: shape and dtype of each tensor that ``step()`` of
        the global environment expects.
    """
    assert _env, "set a global env by `set_global_env` before using the function"
    action_spec = _env.action_spec()
    return action_spec
@alf.configurable
def get_reward_spec():
    """Return the reward spec of the global environment.

    Returns:
        nested TensorSpec: shape and dtype of each reward tensor.
    """
    assert _env, "set a global env by `set_global_env` before using the function"
    reward_spec = _env.reward_spec()
    return reward_spec
def get_env():
    """Return the environment registered with ``set_global_env``.

    Raises:
        AssertionError: if no (truthy) global environment has been set.
    """
    assert _env, "set a global env by `set_global_env` before using the function"
    env = _env
    return env
@alf.configurable
def get_vocab_size():
    """Return the vocabulary size of the global environment's observations.

    Returns:
        int: ``maximum + 1`` of the ``'sentence'`` observation spec, i.e. the
        size of the environment's/teacher's vocabulary. ``0`` if language is
        not part of the observation. No error is raised, so callers can use
        this to check whether the env has language input.
    """
    assert _env, "set a global env by `set_global_env` before using the function"
    obs_spec = _env.observation_spec()
    if isinstance(obs_spec, dict) and 'sentence' in obs_spec:
        # ``obs_spec['sentence'].shape[0]`` would be the sentence length, not
        # the vocabulary size, so ``maximum`` is used instead.
        return obs_spec['sentence'].maximum + 1
    return 0
@alf.configurable
def active_action_target_entropy(active_action_portion=0.2, min_entropy=0.3):
    """Compute a target entropy from the global action spec.

    Only discrete actions are supported. The idea is that a good policy puts
    uniform probability on :math:`Nk` actions, where :math:`N` is the total
    number of discrete actions and :math:`k` is the active action portion.
    The target entropy is therefore :math:`log(Nk)`, floored at
    ``min_entropy``.

    TODO: incorporate this function into ``EntropyTargetAlgorithm`` if it proves
    to be effective.

    Args:
        active_action_portion (float): a number in :math:`(0, 1]`. Ideally
            greater than ``1/num_actions``; otherwise :math:`log(Nk)` is
            negative and ``min_entropy`` takes over.
        min_entropy (float): lower bound on the returned entropy.
    Returns:
        float: the target entropy for ``EntropyTargetAlgorithm``.
    """
    assert 0 < active_action_portion <= 1.0
    spec = get_action_spec()
    assert spec.is_discrete(spec), "only support discrete actions!"
    n_actions = spec.maximum - spec.minimum + 1
    return max(math.log(n_actions * active_action_portion), min_entropy)
def write_gin_configs(root_dir, gin_file):
    """Dump the gin configuration actually in effect to a file.

    The original conf files are not simply copied, because the user may
    1) change the gin confs in code after loading a conf file, or
    2) include a gin file from another gin file, while only the latter might
    be copied to ``root_dir``.
    So the operative and inoperative config strings are written instead.

    Args:
        root_dir (str): directory path (``~`` expanded, created if missing)
        gin_file (str): a single file path for storing the gin configs. Only
            the basename of the path will be used.
    """
    root_dir = os.path.expanduser(root_dir)
    os.makedirs(root_dir, exist_ok=True)
    path = os.path.join(root_dir, os.path.basename(gin_file))
    operative, inoperative = get_gin_confg_strs()
    # the mark-down string can just be safely written as a python file
    with open(path, "w") as out:
        out.write('\n\n'.join((operative, inoperative)))
@logging.skip_log_prefix
def warning_once(msg, *args):
    """Generate warning message ``msg % args`` once.

    Note that the current implementation resembles that of the ``log_every_n()``
    function in ``logging`` but reduces the calling stack by one to ensure
    the multiple warning once messages generated at different places can be
    displayed correctly.

    The message is wrapped in ANSI escape codes (bold red) before logging.

    Args:
        msg: str, the message to be logged.
        *args: The args to be substituted into the msg.
    """
    # Identify the call site; the per-call-site counter below is keyed on it.
    # NOTE: this relies on private absl APIs and on the exact stack depth of
    # this function, so do not move these calls into a helper.
    caller = logging.get_absl_logger().findCaller()
    count = logging._get_next_log_count_per_token(caller)
    # Only the first call from a given call site (count == 0) is emitted.
    logging.log_if(logging.WARNING, "\033[1;31m" + msg + "\033[1;0m",
                   count == 0, *args)
@logging.skip_log_prefix
def warning(msg, *args):
    """Log ``msg % args`` at WARNING level, highlighted in bold red.

    Args:
        msg: str, the message to be logged.
        *args: The args to be substituted into the msg.
    """
    colored = "\033[1;31m" + msg + "\033[1;0m"
    logging.log(logging.WARNING, colored, *args)
@logging.skip_log_prefix
def info(msg, *args):
    """Log ``msg % args`` at INFO level, highlighted in bold blue.

    Args:
        msg: str, the message to be logged.
        *args: The args to be substituted into the msg.
    """
    colored = "\033[1;34m" + msg + "\033[1;0m"
    logging.log(logging.INFO, colored, *args)
@logging.skip_log_prefix
def info_once(msg, *args):
    """Generate info message ``msg % args`` once.

    The message is wrapped in ANSI escape codes (bold blue) before logging.
    See ``warning_once`` for why the call-site lookup is done inline.

    Args:
        msg: str, the message to be logged.
        *args: The args to be substituted into the msg.
    """
    # NOTE: stack-depth sensitive (findCaller); keep these calls in this frame.
    caller = logging.get_absl_logger().findCaller()
    count = logging._get_next_log_count_per_token(caller)
    # Only the first call from a given call site (count == 0) is emitted.
    logging.log_if(logging.INFO, "\033[1;34m" + msg + "\033[1;0m", count == 0,
                   *args)
def set_random_seed(seed):
    """Seed ``random``, ``numpy`` and ``torch`` (including CUDA if available).

    When an explicit seed is given, cuDNN is put into deterministic mode and
    ``torch.use_deterministic_algorithms`` is set according to
    ``FLAGS.force_torch_deterministic`` (default True). When ``seed`` is None,
    a seed is derived from the pid and the current time, and no determinism
    settings are changed.

    Note: if someone runs an experiment with a pre-selected manual seed, they
    can reproduce the results with the same seed. However, if the experiment
    is run with seed=None and then re-run with the seed returned by this
    function (e.g. logged to Tensorboard), and cudnn is used in the code,
    there is no guarantee that the results will be reproduced.

    Args:
        seed (int|None): seed to be used. If None, a default seed based on
            pid and time will be used.
    Returns:
        int: the seed actually used, i.e. the (possibly generated) seed
        reduced modulo ``2**32``.
    """
    if seed is not None:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        force_torch_deterministic = getattr(flags.FLAGS,
                                            'force_torch_deterministic', True)
        # causes RuntimeError: scatter_add_cuda_kernel does not have a deterministic implementation
        torch.use_deterministic_algorithms(force_torch_deterministic)
    else:
        seed = abs(hash(str(os.getpid()) + '|' + str(time.time())))
    # The seed can be very big, but np.random.seed only accepts seeds smaller
    # than 2**32, so fold it into that range before seeding anything.
    seed %= 2**32
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    return seed
def log_metrics(metrics, prefix=''):
    """Log ``name = result`` for each metric at INFO level, one per line.

    Args:
        metrics (list[alf.metrics.StepMetric]): list of metrics to be logged
        prefix (str): prefix to the log segment
    """
    lines = [f'{m.name} = {m.result()}' for m in metrics]
    logging.info('%s \n\t\t %s', prefix, '\n\t\t '.join(lines))
def create_ou_process(action_spec, ou_stddev, ou_damping):
    """Create nested zero-mean Ornstein-Uhlenbeck processes.

    The temporal update equation is:

    .. code-block:: python

        x_next = (1 - damping) * x + N(0, std_dev)

    Note: if ``action_spec`` is nested, the returned nested OUProcess will not be
    checkpointed.

    Args:
        action_spec (nested BoundedTensorSpec): action spec
        ou_stddev (float): Standard deviation of the Gaussian component.
        ou_damping (float): Damping rate in the above equation. We must have
            :math:`0 <= damping <= 1`.
    Returns:
        nested ``OUProcess`` with the same structure as ``action_spec``, each
        starting from ``spec.zeros()``.
    """
    return alf.nest.map_structure(
        lambda spec: dist_utils.OUProcess(spec.zeros(), ou_damping, ou_stddev),
        action_spec)
def detach(nests: alf.nest.Nest):
    """Detach every tensor or distribution in a nest.

    Distributions are rebuilt from their (recursively detached) parameters
    through ``dist_utils._get_builder``. Everything else is ``.detach()``-ed.

    Args:
        nests: tensors or distributions to be detached
    Returns:
        detached Tensors/Distributions with same structure as nests
    """

    def _detach_one(item):
        if not isinstance(item, td.Distribution):
            return item.detach()
        builder, params = dist_utils._get_builder(item)
        return builder(**detach(params))

    return nest.map_structure(_detach_one, nests)
# A catch all mode. Currently includes on-policy training on unrolled experience.
EXE_MODE_OTHER = 0
# Unroll (collect experience) during training
EXE_MODE_ROLLOUT = 1
# Replay: policy evaluation on experience and training
EXE_MODE_REPLAY = 2
# Evaluation / testing or playing a learned model
EXE_MODE_EVAL = 3
# Pretrain mode
EXE_MODE_PRETRAIN = 4
# Global execution mode to track where the program is in the RL training process.
# This is currently used by observation normalization to only update statistics
# during training (vs unroll). It is also used in tensorboard plotting of
# network output values: evaluations of the same network during rollout, eval
# and replay are plotted to different graphs.
_exe_mode = EXE_MODE_OTHER
# Human-readable names indexed by the EXE_MODE_* values above (used by
# ``exe_mode_name()``); keep the order in sync with those constants.
_exe_mode_strs = ["other", "rollout", "replay", "eval", "pretrain"]
def set_exe_mode(mode):
    """Set the global execution mode.

    Code elsewhere may check this flag (e.g. via ``is_rollout()``) to change
    its behavior accordingly.

    Args:
        mode (int): the new mode, one of the ``EXE_MODE_*`` constants.
    Returns:
        int: the execution mode that was active before this call.
    """
    global _exe_mode
    previous, _exe_mode = _exe_mode, mode
    return previous
def exe_mode_name():
    """Return the name of the current execution mode (e.g. ``"rollout"``)."""
    names = _exe_mode_strs
    return names[_exe_mode]
def is_replay():
    """Whether the current code is running in the replay phase.

    Replaying implies off-policy training. Any code under
    ``train_from_replay_buffer()`` of any algorithm counts as replaying. The
    phase spans from sampling experience from the replay buffer all the way
    to the parameter update.
    """
    return EXE_MODE_REPLAY == _exe_mode
def is_rollout():
    """Whether the current code is running in the unroll phase.

    For on-policy algorithms, unrolling can be treated as part of training,
    since it usually produces the training info used to compute the loss.
    Any code under ``unroll()`` of the root RL algorithm counts as unrolling;
    this is the phase that collects experience for training.
    """
    return EXE_MODE_ROLLOUT == _exe_mode
def is_eval():
    """Whether the current code is evaluating or playing a learned model."""
    return EXE_MODE_EVAL == _exe_mode
def is_pretrain():
    """Whether the current code is running in the pre-train phase.

    Code inside a function decorated by ``mark_pretrain``, or inside a
    ``pretrain_context`` block, is flagged as ``pretrain``.
    """
    return EXE_MODE_PRETRAIN == _exe_mode
def is_training(alg):
    """Whether the current code is in a training phase of ``alg``.

    For an on-policy algorithm the rollout phase counts as training. The
    replay phase counts as training for any algorithm.

    .. note::

        Currently this function returns False for the code under ``train_from_unroll()``.

    Args:
        alg (Algorithm): the algorithm to be decided
    """
    rollout_training = alg.on_policy and is_rollout()
    return rollout_training or is_replay()
def mark_eval(func):
    """A decorator that sets the ``_exe_mode`` flag to ``EXE_MODE_EVAL`` while
    ``func`` runs and restores the previous mode afterwards.

    The previous mode is restored even if ``func`` raises.

    Args:
        func (Callable): a function
    Returns:
        Callable: the wrapped function (name/docstring preserved).
    """

    @wraps(func)
    def _func(*args, **kwargs):
        old_mode = _exe_mode
        set_exe_mode(EXE_MODE_EVAL)
        try:
            return func(*args, **kwargs)
        finally:
            # Restore on every exit path; an exception must not leave the
            # process stuck in eval mode.
            set_exe_mode(old_mode)

    return _func
def mark_replay(func):
    """A decorator that sets the ``_exe_mode`` flag to ``EXE_MODE_REPLAY``
    while ``func`` runs and restores the previous mode afterwards.

    The previous mode is restored even if ``func`` raises.

    Args:
        func (Callable): a function
    Returns:
        Callable: the wrapped function (name/docstring preserved).
    """

    @wraps(func)
    def _func(*args, **kwargs):
        old_mode = _exe_mode
        set_exe_mode(EXE_MODE_REPLAY)
        try:
            return func(*args, **kwargs)
        finally:
            # Restore on every exit path; an exception must not leave the
            # process stuck in replay mode.
            set_exe_mode(old_mode)

    return _func
def mark_rollout(func):
    """A decorator that runs ``func`` with ``_exe_mode`` set to ``EXE_MODE_ROLLOUT``.

    The previous mode is restored when ``func`` returns *or raises*, so an
    exception escaping a rollout function cannot leave the global mode stuck
    at ``EXE_MODE_ROLLOUT``.

    Args:
        func (Callable): a function
    Returns:
        Callable: the wrapped function (name and docstring preserved).
    """

    @wraps(func)
    def _func(*args, **kwargs):
        old_mode = _exe_mode
        set_exe_mode(EXE_MODE_ROLLOUT)
        try:
            return func(*args, **kwargs)
        finally:
            # Restore the caller's mode even if ``func`` raised.
            set_exe_mode(old_mode)

    return _func
def mark_pretrain(func):
    """A decorator that runs ``func`` with ``_exe_mode`` set to ``EXE_MODE_PRETRAIN``.

    The previous mode is restored when ``func`` returns *or raises*, so an
    exception escaping a pretrain function cannot leave the global mode stuck
    at ``EXE_MODE_PRETRAIN``.

    Args:
        func (Callable): a function
    Returns:
        Callable: the wrapped function (name and docstring preserved).
    """

    @wraps(func)
    def _func(*args, **kwargs):
        old_mode = _exe_mode
        set_exe_mode(EXE_MODE_PRETRAIN)
        try:
            return func(*args, **kwargs)
        finally:
            # Restore the caller's mode even if ``func`` raised.
            set_exe_mode(old_mode)

    return _func
class eval_context(object):
    """A context manager that sets the ``_exe_mode`` flag to ``EXE_MODE_EVAL``
    when entering the context and reverts to the mode that was active at
    entry when exiting.

    Exceptions raised inside the context are propagated, not suppressed.
    """

    def __init__(self):
        self._old_mode = _exe_mode

    def __enter__(self):
        # Capture the mode at entry time so an instance created earlier
        # restores the mode that was actually active when it was entered.
        self._old_mode = _exe_mode
        set_exe_mode(EXE_MODE_EVAL)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        set_exe_mode(self._old_mode)
        # Returning a truthy value would silently swallow any exception
        # raised in the ``with`` body; let it propagate instead.
        return False
class replay_context(object):
    """A context manager that sets the ``_exe_mode`` flag to ``EXE_MODE_REPLAY``
    when entering the context and reverts to the mode that was active at
    entry when exiting.

    Exceptions raised inside the context are propagated, not suppressed.
    """

    def __init__(self):
        self._old_mode = _exe_mode

    def __enter__(self):
        # Capture the mode at entry time so an instance created earlier
        # restores the mode that was actually active when it was entered.
        self._old_mode = _exe_mode
        set_exe_mode(EXE_MODE_REPLAY)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        set_exe_mode(self._old_mode)
        # Returning a truthy value would silently swallow any exception
        # raised in the ``with`` body; let it propagate instead.
        return False
class rollout_context(object):
    """A context manager that sets the ``_exe_mode`` flag to ``EXE_MODE_ROLLOUT``
    when entering the context and reverts to the mode that was active at
    entry when exiting.

    Exceptions raised inside the context are propagated, not suppressed.
    """

    def __init__(self):
        self._old_mode = _exe_mode

    def __enter__(self):
        # Capture the mode at entry time so an instance created earlier
        # restores the mode that was actually active when it was entered.
        self._old_mode = _exe_mode
        set_exe_mode(EXE_MODE_ROLLOUT)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        set_exe_mode(self._old_mode)
        # Returning a truthy value would silently swallow any exception
        # raised in the ``with`` body; let it propagate instead.
        return False
class pretrain_context(object):
    """A context manager that sets the ``_exe_mode`` flag to ``EXE_MODE_PRETRAIN``
    when entering the context and reverts to the mode that was active at
    entry when exiting.

    Exceptions raised inside the context are propagated, not suppressed.
    """

    def __init__(self):
        self._old_mode = _exe_mode

    def __enter__(self):
        # Capture the mode at entry time so an instance created earlier
        # restores the mode that was actually active when it was entered.
        self._old_mode = _exe_mode
        set_exe_mode(EXE_MODE_PRETRAIN)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        set_exe_mode(self._old_mode)
        # Returning a truthy value would silently swallow any exception
        # raised in the ``with`` body; let it propagate instead.
        return False
@alf.configurable
def flattened_size(spec):
    """Compute how many elements a tensor of shape ``spec.shape`` holds.

    This is ``np.prod(spec.shape)``, always returned as an integer.

    Args:
        spec (alf.TensorSpec): a TensorSpec object
    Returns:
        np.int64: the size of the flattened shape (1 for an empty shape)
    """
    num_elements = np.prod(spec.shape)
    # For an empty shape np.prod returns the float 1.0, hence the int cast.
    return np.int64(num_elements)
def is_inside_docker_container():
    """Check whether this process appears to run inside a docker container.

    The check looks for the ``/.dockerenv`` marker file. See discussions at
    `<https://stackoverflow.com/questions/23513045/how-to-check-if-a-process-is-running-inside-docker-container>`_

    Returns:
        bool: True if ``/.dockerenv`` exists.
    """
    docker_marker = "/.dockerenv"
    return os.path.exists(docker_marker)
def check_numerics(nested):
    """Assert all the tensors in ``nested`` are finite.

    Args:
        nested (nested Tensor): nested Tensor to be checked.
    Raises:
        AssertionError: if any tensor contains a non-finite value. The message
            shows ``nested`` with every all-finite tensor replaced by ``()``
            so that only the offending tensors are printed.
    """
    nested_finite = alf.nest.map_structure(
        lambda x: torch.all(torch.isfinite(x)), nested)
    if all(alf.nest.flatten(nested_finite)):
        return
    bad = alf.nest.map_structure(lambda x, finite: () if finite else x,
                                 nested, nested_finite)
    # Wrap ``bad`` in a 1-tuple: if the nest is itself a tuple/namedtuple,
    # ``%`` would otherwise unpack it as format arguments and raise TypeError.
    # Raise explicitly (instead of ``assert``) so the check survives ``-O``.
    raise AssertionError("Some tensor in nested is not finite: %s" % (bad, ))
def get_all_parameters(obj):
    """Get all the parameters under ``obj`` and its descendants.

    Performs a breadth-first search over sub-objects reachable through
    namedtuple fields, dict values, tuple/list/set elements, ``nn.Module``
    children and parameters, and the non-dunder attributes of any object.
    Python modules and non-parameter tensors are not traversed into.
    Sub-objects are deduplicated by ``id``.

    Note: This function assumes all the parameters can be reached through tuple,
    list, dict, set, nn.Module or the attributes of an object. If a parameter is
    held in a strange way, it will not be included by this function.

    Args:
        obj (object): will look for parameters under this object.
    Returns:
        list: list of ``(path, Parameter)``, where ``path`` is the
        dot-separated access path from ``obj`` (``''`` if ``obj`` itself is
        a parameter).
    """
    from collections import deque

    all_parameters = []
    memo = set()
    # A deque gives O(1) pops from the front of the BFS queue.
    unprocessed = deque([(obj, '')])

    def _enqueue(value, value_path):
        # Schedule ``value`` for traversal unless it was scheduled before.
        if id(value) not in memo:
            unprocessed.append((value, value_path))
            memo.add(id(value))

    # BFS for all subobjects
    while unprocessed:
        obj, path = unprocessed.popleft()
        if isinstance(obj, types.ModuleType):
            # Do not traverse into a module. There is too much stuff inside a
            # module.
            continue
        if isinstance(obj, nn.Parameter):
            all_parameters.append((path, obj))
            continue
        if isinstance(obj, torch.Tensor):
            continue
        if path:
            path += '.'
        if nest.is_namedtuple(obj):
            for name, value in nest.extract_fields_from_nest(obj):
                _enqueue(value, path + str(name))
        elif isinstance(obj, dict):
            # The keys of a generic dict are not necessarily str, and cannot be
            # handled by nest.extract_fields_from_nest.
            for name, value in obj.items():
                _enqueue(value, path + str(name))
        elif isinstance(obj, (tuple, list, set)):
            for i, value in enumerate(obj):
                _enqueue(value, path + str(i))
        elif isinstance(obj, nn.Module):
            for name, m in obj.named_children():
                _enqueue(m, path + name)
            for name, p in obj.named_parameters():
                _enqueue(p, path + name)
        for name in dir(obj):
            if name.startswith('__') and name.endswith('__'):
                # Ignore system attributes.
                continue
            try:
                attr = getattr(obj, name)
            except Exception:
                # Some attributes are property functions, which may raise an
                # exception when called in a wrong context (e.g.
                # Algorithm.experience_spec). Skip them, but do not swallow
                # KeyboardInterrupt/SystemExit like a bare ``except`` would.
                continue
            if attr is not None:
                _enqueue(attr, path + name)
    return all_parameters
def snapshot_repo_roots() -> Dict[str, str]:
    """Return a dict of repo root dirs for snapshot.

    The paths are read from the environment variable
    ``ALF_SNAPSHOT_REPO_ROOTS``, in the following format:

    .. code-block:: bash

        export ALF_SNAPSHOT_REPO_ROOTS="<module_name1>=<repo_root1>:<module_name2>=<repo_root2>:..."

    where pairs of "<module_name>=<repo_root>" are separated by ":". Note that
    ``<repo_root>`` should be the parent dir of the module package dir. Empty
    segments (e.g. from a trailing ":") are ignored, and only the first "="
    of a pair separates the module name from the path, so a path may itself
    contain "=".

    Returns:
        dict[str]: a dict of ``{module_name: repo_root}`` (absolute paths),
        excluding the alf repo itself.
    """
    repo_roots_envar = os.getenv('ALF_SNAPSHOT_REPO_ROOTS')
    repo_roots = {}
    if repo_roots_envar is None:
        return repo_roots
    for p in repo_roots_envar.split(':'):
        if not p:
            continue
        assert '=' in p, (
            "Each repo str must be in the format '<module>=<repo_root>'! "
            f"Got {p}")
        module, repo_root = p.split('=', 1)
        repo_roots[module] = str(pathlib.Path(repo_root).absolute())
    return repo_roots
def generate_alf_snapshot(alf_root: str, conf_file: str, dest_path: str):
    """Given a destination path, copy the local ALF root dir to the path.

    To save disk space, only source-like files (see ``includes`` below) are
    copied. Besides ALF, every repo returned by ``snapshot_repo_roots()`` is
    snapshotted too. Each module dir is copied into ``dest_path`` and then
    compressed into ``<dest_path>/<module>.tar.gz``; the uncompressed copy is
    removed afterwards.

    This function can be used to generate a snapshot of the repo so that the
    exactly same code status will be recovered when later playing a trained
    model or launching a grid-search job in the waiting queue.

    Args:
        alf_root: the parent path of the 'alf' module
        conf_file: the alf config file (not referenced by this function)
        dest_path: the path to generate a snapshot of ALF repo
    Raises:
        AssertionError: if ``dest_path`` is inside one of the repo roots.
        subprocess.CalledProcessError: if ``rsync`` or ``tar`` fails.
    """

    def _is_subdir(path, directory):
        # True if ``path`` is ``directory`` itself or lies underneath it.
        relative = os.path.relpath(path, directory)
        return not (relative == os.pardir
                    or relative.startswith(os.pardir + os.sep))

    def _rsync(src, target, includes):
        # Recursively copy ``src`` to ``target``, keeping only files matching
        # one of ``includes`` (all directories are traversed).
        args = ['rsync', '-rI', '--include=*/']
        args += ['--include=%s' % i for i in includes]
        args += ['--exclude=*', src, target]
        # An argument list without a shell passes the ``*`` patterns to
        # rsync verbatim and handles paths with spaces/metacharacters.
        subprocess.check_call(args, stdout=sys.stdout, stderr=sys.stdout)

    includes = [
        "*.py", "*.gin", "*.so", "*.json", "*.xml", "*.cpp", "*.c", "*.hpp",
        "*.h", "*.stl"
    ]
    repo_roots = {**snapshot_repo_roots(), **{'alf': alf_root}}
    for name, root in repo_roots.items():
        assert not _is_subdir(dest_path, root), (
            "Snapshot path '%s' is not allowed under any repo root '%s'! " %
            (dest_path, root) + "Use a different one!")
        # Only copy the module dir because the root dir might contain many
        # other modules in the case where repo is pip installed in 'site-packages'.
        _rsync(root + f'/{name}', dest_path, includes)
        # Compress the snapshot repo into a ".tar.gz" file; the copied dir is
        # removed only after tar succeeded.
        subprocess.check_call(['tar', '-czf', f'{name}.tar.gz', name],
                              cwd=dest_path)
        shutil.rmtree(os.path.join(dest_path, name))
        info(f"Generated a snapshot {name}@{root}")
def unzip_alf_snapshot(root_dir: str):
    """Restore an ALF snapshot from a job directory by unzipping the snapshot
    'tar.gz' files.

    Every ``*.tar.gz`` directly under ``root_dir`` is assumed to be a repo
    snapshot. Any previously extracted ``<root_dir>/<name>`` is removed
    before extracting.

    Args:
        root_dir: the tensorboard job directory
    Returns:
        list[str]: the module names restored, i.e. each archive's file name up
        to its first '.' (e.g. ``alf`` for ``alf.tar.gz``).
    Raises:
        subprocess.CalledProcessError: if extracting an archive fails.
    """
    module_names = []
    # Escape ``root_dir`` so glob metacharacters in it match literally.
    pattern = os.path.join(glob.escape(root_dir), "*.tar.gz")
    for zipped_repo in glob.glob(pattern):
        # assuming all '*.tar.gz' under root_dir are repo snapshots
        archive = os.path.basename(zipped_repo)
        name = archive.split('.')[0]
        info("=== Using an ALF snapshot at '%s' ===", zipped_repo)
        target = os.path.join(root_dir, name)
        # Equivalent of ``rm -rf target``: a symlink or file is unlinked, a
        # real directory is removed recursively.
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        elif os.path.lexists(target):
            os.remove(target)
        subprocess.check_call(['tar', '-xzf', archive], cwd=root_dir)
        module_names.append(name)
    return module_names
def get_alf_snapshot_env_vars(root_dir):
    """Given a ``root_dir``, return a modified copy of the environment so that
    ``PYTHONPATH`` points to the ALF snapshot under this directory.

    The snapshots under ``root_dir`` are unzipped first. For the ``alf``
    module, ``<root>/alf/examples`` is added as well. A module unzipped later
    is placed earlier in ``PYTHONPATH``, ahead of the existing entries.

    Args:
        root_dir: the job directory containing the snapshot archives.
    Returns:
        dict: a copy of ``os.environ`` with ``PYTHONPATH`` updated. The
        current process environment is not modified.
    Raises:
        AssertionError: if the current working directory is itself a repo root
            of one of the snapshot modules.
    """
    module_names = unzip_alf_snapshot(root_dir)
    cwd = os.getcwd()
    new_paths = []
    for name in module_names:
        assert not is_repo_root(cwd, name), (
            "Using a snapshot is not allowed under a valid repo root: " +
            "'%s' (contains '%s')!" % (cwd, name) +
            " Try running the command in a different directory.")
        root = root_dir
        if name == "alf":
            legacy_alf_root = os.path.join(root, "alf")
            # legacy alf repo path for backward compatibility
            # legacy tb dirs: root_dir/alf/alf/__init__.py
            if is_repo_root(legacy_alf_root, "alf"):
                root = legacy_alf_root
            alf_examples = os.path.join(root, "alf/examples")
            module_paths = [root, alf_examples]
        else:
            module_paths = [root]
        # Prepend, so that later modules come before earlier ones.
        new_paths[:0] = module_paths
    old_python_path = os.environ.get("PYTHONPATH", "")
    # Avoid appending an empty trailing component when PYTHONPATH is unset;
    # an empty entry is commonly interpreted as the current directory.
    if old_python_path:
        new_paths.append(old_python_path)
    # ``os.environ.copy()`` returns an independent dict, whereas
    # ``copy.copy(os.environ)`` shares its data with the real environment.
    env_vars = os.environ.copy()
    env_vars["PYTHONPATH"] = os.pathsep.join(new_paths)
    return env_vars
def abs_path(path):
    """Resolve ``path`` into a canonical absolute path.

    A leading ``~`` is expanded to the user's home directory first, then the
    result is made absolute with symlinks resolved.
    """
    expanded = os.path.expanduser(path)
    return os.path.realpath(expanded)
# Cached result of ``alf_root()``; computed on first call.
_alf_root = None


def alf_root():
    """Get the ALF root path, i.e. the parent dir of the ``alf`` package.

    Returns:
        str: absolute path of the directory containing the ``alf`` package.
    """
    global _alf_root
    if _alf_root is not None:
        return _alf_root
    # alf.__file__ == <ALF_ROOT>/alf/__init__.py
    package_init = pathlib.Path(alf.__file__)
    _alf_root = str(package_init.parent.parent.absolute())
    return _alf_root
def is_repo_root(dir, module_name):
    """Check whether ``dir`` is a valid repo root for ``module_name``.

    Currently a directory counts as a repo root when it contains the file
    ``<module_name>/__init__.py``.

    Returns:
        bool: True if that ``__init__.py`` exists as a regular file.
    """
    init_file = os.path.join(dir, module_name, '__init__.py')
    return os.path.isfile(init_file)
def compute_summary_or_eval_interval(config, summary_or_eval_calls=100):
    """Derive a summary or eval interval from the training length.

    Instead of hand-computing an interval, specify how many summary or eval
    calls you expect over the whole training run, and the interval is the
    total number of iterations divided by that count, rounded up.

    .. warning::

        This function might not work for algorithms that change the global
        counter themselves, e.g., ``LMAlgorithm``.

    Args:
        config (TrainerConfig): the configuration object for training
        summary_or_eval_calls (int): the expected number of summary
            or eval calls throughout the training process. This number can
            control the time consumed on summary or eval. Note that it might
            not be satisfied exactly, because the interval is rounded up.
    Returns:
        int: summary or eval interval
    """
    # Not supported for now: with per-mini-batch counter updates the summary
    # global counter would differ from the iteration number.
    assert not config.update_counter_every_mini_batch, (
        "This function currently doesn't support update_counter_every_mini_batch=True!"
    )
    if config.num_iterations > 0:
        num_iterations = config.num_iterations
    else:
        # Otherwise the training length must be given by ``num_env_steps``.
        assert config.num_env_steps
        # The rollout env is always created with ``nonparallel=False``.
        num_envs = alf.get_config_value(
            "create_environment.num_parallel_environments")
        steps_per_iteration = num_envs * config.unroll_length
        num_iterations = config.num_env_steps / steps_per_iteration
    interval = math.ceil(num_iterations / summary_or_eval_calls)
    info_once("A summary or eval interval=%d is calculated" % interval)
    return interval
def call_stack() -> List[str]:
    """Describe the current function call stack, for debugging.

    Returns:
        List[str]: one entry per stack frame as produced by
        ``traceback.format_stack()``, with surrounding whitespace stripped.
    """
    frames = traceback.format_stack()
    return list(map(str.strip, frames))
@contextlib.contextmanager
def get_unused_port(start, end=65536, n=1):
    """Get ``n`` consecutive unused ports in the range [start, end).

    For each candidate port, an inter-process file lock
    (``/tmp/socialbot/<port>.lock``) is acquired without blocking, and the
    port is then tested by binding a socket to it. Whenever a port fails
    either test, all locks held for the current run are released and the
    search restarts at the next port, so the yielded ports are consecutive.
    The locks are held until the ``with`` block exits.

    Args:
        start (int): port range start
        end (int): port range end (exclusive)
        n (int): get ``n`` consecutive unused ports
    Yields:
        int if ``n == 1``, otherwise a list of ``n`` ints.
    Raises:
        ValueError: if ``n < 1``
        socket.error: if no ``n`` consecutive unused ports are available
    """
    if n < 1:
        raise ValueError("n must be at least 1, got {}".format(n))
    process_locks = []
    unused_ports = []

    def _reset_run():
        # The run of consecutive ports is broken: release its locks and
        # forget its ports so the search starts over.
        for process_lock in process_locks:
            process_lock.release()
        process_locks.clear()
        unused_ports.clear()

    try:
        for port in range(start, end):
            lock = InterProcessLock(path='/tmp/socialbot/{}.lock'.format(port))
            if not lock.acquire(blocking=False):
                # Another process holds this port; close the opened lock file.
                lock.lockfile.close()
                _reset_run()
                continue
            process_locks.append(lock)
            try:
                with contextlib.closing(socket.socket()) as sock:
                    sock.bind(('', port))
            except socket.error:
                _reset_run()
                continue
            unused_ports.append(port)
            if len(unused_ports) == n:
                break
        if len(unused_ports) < n:
            raise socket.error("No unused port in [{}, {})".format(start, end))
        if n == 1:
            yield unused_ports[0]
        else:
            yield unused_ports
    finally:
        for process_lock in process_locks:
            process_lock.release()