Source code for alf.trainers.policy_trainer

# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for training an Algorithm on given environments."""

import abc
from absl import logging
from functools import partial
from typing import Dict
import math
import os
from pathlib import Path
import re
import signal
import threading
import sys
import time
import torch
import torch.nn as nn
from PIL import Image
import numpy as np

import alf
from alf.algorithms.algorithm import Algorithm, Loss
from alf.networks import Network
from alf.algorithms.config import TrainerConfig
from alf.algorithms.data_transformer import (create_data_transformer,
                                             IdentityDataTransformer)
from alf.data_structures import StepType
from alf.environments.utils import create_environment
from alf.nest import map_structure
from alf.tensor_specs import TensorSpec
from alf.utils import common
from alf.utils import git_utils
from alf.utils import math_ops
from alf.utils.pretty_print import pformat_pycolor
from alf.utils.checkpoint_utils import Checkpointer
import alf.utils.datagen as datagen
from alf.utils.per_process_context import PerProcessContext
from alf.utils.summary_utils import record_time
from .evaluator import Evaluator


class _TrainerProgress(nn.Module):
    def __init__(self):
        super(_TrainerProgress, self).__init__()
        self.register_buffer("_iter_num", torch.zeros((), dtype=torch.int64))
        self.register_buffer("_env_steps", torch.zeros((), dtype=torch.int64))
        self._num_iterations = None
        self._num_env_steps = None
        self._progress = None

    def set_termination_criterion(self, num_iterations, num_env_steps=0):
        self._num_iterations = float(num_iterations)
        self._num_env_steps = float(num_env_steps)
        # might be loaded from a checkpoint, so we update first
        self.update()

    def update(self, iter_num=None, env_steps=None):
        if iter_num is not None:
            self._iter_num.fill_(iter_num)
        if env_steps is not None:
            self._env_steps.fill_(env_steps)

        assert not (self._num_iterations is None
                    and self._num_env_steps is None), (
                        "You must first call set_terimination_criterion()!")
        if self._num_iterations > 0:
            self._progress = float(
                self._iter_num.to(torch.float64) / self._num_iterations)
        else:
            self._progress = float(
                self._env_steps.to(torch.float64) / self._num_env_steps)

    def set_progress(self, value: float):
        """Manually set the current progress.

        Args:
            value: a float number in [0, 1]
        """
        self._progress = value

    @property
    def progress(self):
        assert self._progress is not None, "Must call update() first!"
        return self._progress


def _visualize_alf_tree(module: Algorithm):
    """Generate a graphviz graph of the module tree structure.

    This is useful to visualize the hierarchy of the current AFL algorithm.

    Args:
        module: An ALF algorithm.

    Returns:
        A graphviz directed graph that can be rendered as pdf.
    """
    try:
        import graphviz
    except ImportError:
        logging.warn(
            'Need "graphviz" installed if you want to visualize modules')
        return None

    def _is_layer(node):
        class_name = node.__class__.__name__
        return (isinstance(node, nn.Module) and class_name in dir(alf.layers))

    def _visual_style(node: torch.nn.Module) -> Dict[str, str]:
        """Loss: 'gray',
           Algorithm: 'blue',
           Network: 'orange',
           Layer: 'yellow'
        """
        if isinstance(node, Loss):
            return {
                'style': 'filled',
                'fillcolor': '#DCDCDC',
            }
        elif isinstance(node, Algorithm):
            return {
                'style': 'filled',
                'fillcolor': '#00BFFF',
            }
        elif isinstance(node, Network):
            return {
                'style': 'filled',
                'fillcolor': '#FF8C00',
            }
        elif _is_layer(node):
            return {'style': 'filled', 'fillcolor': '#ffdc7d', 'fontsize': '8'}
        return {}

    def _generate_node_label(node):
        """Generate the proper label for a given node.
        """

        def _get_func_name(match_obj):
            """Further extract the function name from the <...> representation.

            For example, if the match_obj corresponds to a string like below:

                <built-in method relu_ of type object at 0x7ff7a790f620>

            This function extracts "relu_" out of it.
            """
            # Such representation can start with either "bound method",
            # "built-in method" or "function".
            res = re.match(
                r'<(bound method|built-in method|function) (\S+) .*>',
                match_obj.group())
            if res is None:
                # In case there is an outlier, return "NOT_PARSED" instead.
                return 'NOT_PARSED'
            if len(res.group(2)) > 10:
                # Shorten the function name if it is very long.
                return f'{res.group(2)[:10]}...'
            return res.group(2)

        if _is_layer(node):
            # We need to parse function repr with pattern <... at 0x???> because
            # graphviz doesn't support '<' or '>' in the label
            return re.sub("<[^<]*>", _get_func_name, repr(node))
        else:
            return getattr(node, "name", type(node).__name__)

    def _filter_child(field, child):
        """A set of rules to filter out certain components in the rendered graph.
        """
        conditions = [
            # Every Algorithm will contain a default identity transformer.
            (field == "_data_transformer"
             and isinstance(child, IdentityDataTransformer)),
        ]
        return any(conditions)

    dot = graphviz.Digraph()
    dot.attr('node', shape='record')
    dot.graph_attr['rankdir'] = 'LR'

    def _visit(node, idx, visited):
        """Visit a node by depth-first search. For each algorithm node, we create
        a subgraph that encloses all its children.
        """
        idx[0] += 1
        node_index = idx[0]
        visited[node] = node_index
        label = _generate_node_label(node)
        node_records = ["<caption> " + label + f"(id={node_index})"]
        edges = []

        for field, child in node.named_children():
            if _filter_child(field, child):
                continue
            if child not in visited:
                edges += _visit(child, idx, visited)
            child_idx = visited[child]
            node_records.append(f'<{field}> ({field})')
            edge = (f'{node_index}:{field}', f'{child_idx}:caption')
            edges.append(edge)

        dot.node(
            str(node_index),
            label='|'.join(node_records),
            **_visual_style(node))

        if isinstance(node, Algorithm):
            # NOTE: the subgraph name needs to begin with 'cluster' (all lowercase)
            #       so that Graphviz recognizes it as a special cluster subgraph
            with dot.subgraph(name=f'cluster_{node_index}') as c:
                c.attr(color='green')
                if node_index != 0:
                    # Do not draw duplicate edges for subgraphs
                    c.edge_attr['style'] = 'invis'
                c.edges(edges)
                c.attr(label=label)

        return edges

    _visit(module, idx=[-1], visited={})

    return dot


[docs]class Trainer(object): """Base class for trainers. Trainer is responsible for creating algorithm and dataset/environment, setting up summary, checkpointing, running training iterations, and evaluating periodically. """ _trainer_progress = _TrainerProgress() def __init__(self, config: TrainerConfig, ddp_rank: int = -1): """ Args: config: configuration used to construct this trainer ddp_rank: process (and also device) ID of the process, if the process participates in a DDP process group to run distributed data parallel training. A value of -1 indicates regular single process training. """ Trainer._trainer_progress = _TrainerProgress() root_dir = config.root_dir self._root_dir = root_dir self._train_dir = os.path.join(root_dir, 'train') self._eval_dir = os.path.join(root_dir, 'eval') self._algorithm_ctor = config.algorithm_ctor self._algorithm = None self._num_checkpoints = config.num_checkpoints self._checkpointer = None self._evaluate = config.evaluate self._eval_uncertainty = config.eval_uncertainty if config.num_evals is not None: self._eval_interval = common.compute_summary_or_eval_interval( config, config.num_evals) else: self._eval_interval = config.eval_interval if config.num_summaries is not None: self._summary_interval = common.compute_summary_or_eval_interval( config, config.num_summaries) else: self._summary_interval = config.summary_interval self._summaries_flush_secs = config.summaries_flush_secs self._summary_max_queue = config.summary_max_queue self._debug_summaries = config.debug_summaries self._summarize_grads_and_vars = config.summarize_grads_and_vars self._config = config self._random_seed = config.random_seed self._rank = ddp_rank self._pid = None
[docs] def train(self): """Perform training.""" self._restore_checkpoint() alf.summary.enable_summary() if self._pid is None: self._pid = os.getpid() self._checkpoint_requested = False if threading.current_thread() == threading.main_thread(): signal.signal(signal.SIGUSR2, self._request_checkpoint) # kill -12 PID logging.info( "Use `kill -%s %s` to request checkpoint during training." % (int(signal.SIGUSR2), self._pid)) self._debug_requested = False if threading.current_thread() == threading.main_thread(): # kill -10 PID signal.signal(signal.SIGUSR1, self._request_debug) logging.info("Use `kill -%s %s` to request debugging." % (int( signal.SIGUSR1), self._pid)) checkpoint_saved = False try: if self._config.profiling: import cProfile, pstats, io pr = cProfile.Profile() pr.enable() common.run_under_record_context( self._train, summary_dir=self._train_dir, summary_interval=self._summary_interval, summarize_first_interval=self._config.summarize_first_interval, flush_secs=self._summaries_flush_secs, summary_max_queue=self._summary_max_queue) if self._config.profiling: pr.disable() s = io.StringIO() ps = pstats.Stats(pr, stream=s).sort_stats('time') ps.print_stats() ps = pstats.Stats(pr, stream=s).sort_stats('cumulative') ps.print_stats() ps.print_callees() logging.info(s.getvalue()) self._save_checkpoint() checkpoint_saved = True finally: if (self._config.confirm_checkpoint_upon_crash and not checkpoint_saved and self._rank <= 0): # Prompts for checkpoint only when running single process # training (rank is -1) or master process of DDP training (rank # is 0). ans = input("Do you want to save checkpoint? (y/n): ") if ans.lower().startswith('y'): self._save_checkpoint() self._close()
[docs] @staticmethod def progress(): """A static method that returns the current training progress, provided that only one trainer will be used for training. Returns: float: a number in :math:`[0,1]` indicating the training progress. """ return Trainer._trainer_progress.progress
[docs] @staticmethod def current_iterations(): return Trainer._trainer_progress._iter_num
[docs] @staticmethod def current_env_steps(): return Trainer._trainer_progress._env_steps
def _train(self): """Perform training according the the learning type. """ pass def _close(self): """Closing operations after training. """ pass def _summarize_training_setting(self): # We need to wait for one iteration to get the operative args # Right just give a fixed gin file name to store operative args common.write_config(self._root_dir) with alf.summary.record_if(lambda: True): def _markdownify(paragraph): return " ".join( (os.linesep + paragraph).splitlines(keepends=True)) common.summarize_config() alf.summary.text('commandline', ' '.join(sys.argv)) alf.summary.text( 'optimizers', _markdownify(self._algorithm.get_optimizer_info())) alf.summary.text( 'unoptimized_parameters', _markdownify(self._algorithm.get_unoptimized_parameter_info())) repo_roots = { **common.snapshot_repo_roots(), **{ 'alf': common.alf_root() } } for name, root in repo_roots.items(): alf.summary.text(f'{name}/revision', git_utils.get_revision(f'{root}/{name}')) alf.summary.text( f'{name}/diff', _markdownify(git_utils.get_diff(f'{root}/{name}'))) alf.summary.text('seed', str(self._random_seed)) # Save a rendered directed graph of the algorithm to the root # directory. algorithm_structure_graph = _visualize_alf_tree(self._algorithm) if algorithm_structure_graph is not None: import graphviz try: algorithm_structure_graph.render( Path(self._root_dir, 'algorithm_sturcture'), format='png', quiet=True) except graphviz.backend.CalledProcessError as e: # graphviz will treat any warning in the rendering as error # and panic. We should just warn instead. logging.warn(f'Graphviz rendering: {str(e)}') image_path = Path(self._root_dir, 'algorithm_sturcture.png') if image_path.exists(): img = np.array(Image.open(image_path)) alf.summary.images( 'algorithm_structure', img, dataformat='HWC', step=0) if self._config.code_snapshots is not None: for f in self._config.code_snapshots: path = os.path.join( os.path.abspath(os.path.dirname(__file__)), "..", f) if not os.path.isfile(path): common.warning_once( "The code file '%s' for summary is invalid" % path) continue with open(path, 'r') as fin: code = fin.read() # adding "<pre>" will make TB show raw text instead of MD alf.summary.text('code/%s' % f, "<pre>" + code + "</pre>") def _request_checkpoint(self, signum, frame): self._checkpoint_requested = True def _request_debug(self, signum, frame): self._debug_requested = True def _save_checkpoint(self): # Saving checkpoint is only enabled when running single process training # (rank is -1) or master process of DDP training (rank is 0). if self._rank <= 0: global_step = alf.summary.get_global_counter() self._checkpointer.save(global_step=global_step) def _restore_checkpoint(self, checkpointer): """Retore from saved checkpoint. Args: checkpointer (Checkpointer): """ if checkpointer.has_checkpoint(): # Some objects (e.g. ReplayBuffer) are constructed lazily in algorithm. # They only appear after one training iteration. So we need to run # train_iter() once before loading the checkpoint self._algorithm.train_iter() try: recovered_global_step = checkpointer.load() self._trainer_progress.update() except RuntimeError as e: raise RuntimeError( ("Checkpoint loading failed from the provided root_dir={}. " "Typically this is caused by using a wrong checkpoint. \n" "Please make sure the root_dir is set correctly. " "Use a new value for it if " "planning to train from scratch. \n" "Detailed error message: {}").format(self._root_dir, e)) if recovered_global_step != -1: alf.summary.set_global_counter(recovered_global_step) self._checkpointer = checkpointer
[docs]class RLTrainer(Trainer): """Trainer for reinforcement learning. """ def __init__(self, config: TrainerConfig, ddp_rank: int = -1): """ Args: config (TrainerConfig): configuration used to construct this trainer ddp_rank (int): process (and also device) ID of the process, if the process participates in a DDP process group to run distributed data parallel training. A value of -1 indicates regular single process training. """ super().__init__(config, ddp_rank) self._num_env_steps = config.num_env_steps self._num_iterations = config.num_iterations assert self._num_iterations + self._num_env_steps > 0, \ "Must provide #iterations or #env_steps for training!" if self._num_iterations > 0 and self._num_env_steps > 0: num_envs = alf.get_config_value( "create_environment.num_parallel_environments") num_iterations_with_env_interations = config.num_env_steps / ( num_envs * config.unroll_length) pure_train_iters = self._num_iterations - num_iterations_with_env_interations assert pure_train_iters >= 0, ( f"num_iterations={self._num_iterations} is not enough for " f"num_env_steps={self._num_env_steps}") logging.info("There is no environmental interation in the last" f"{pure_train_iters} iterations") self._trainer_progress.set_termination_criterion( self._num_iterations, self._num_env_steps) self._num_eval_episodes = config.num_eval_episodes alf.summary.should_summarize_output(config.summarize_output) env = alf.get_env() logging.info( "observation_spec=\n%s" % pformat_pycolor(env.observation_spec())), logging.info("action_spec=\n%s" % pformat_pycolor(env.action_spec())) # for offline buffer construction untransformed_observation_spec = env.observation_spec() data_transformer = create_data_transformer( config.data_transformer_ctor, untransformed_observation_spec) self._config.data_transformer = data_transformer # keep compatibility with previous gin based config common.set_global_env(env) observation_spec = data_transformer.transformed_observation_spec common.set_transformed_observation_spec(observation_spec) logging.info("transformed_observation_spec=%s" % pformat_pycolor(observation_spec)) self._algorithm = self._algorithm_ctor( observation_spec=observation_spec, action_spec=env.action_spec(), reward_spec=env.reward_spec(), env=env, config=self._config, debug_summaries=self._debug_summaries) # recover offline buffer self._algorithm.load_offline_replay_buffer( untransformed_observation_spec) self._algorithm.set_path('') if ddp_rank >= 0: # Activate the DDP training self._algorithm.activate_ddp(ddp_rank) # Make sure the BN statistics of different processes are synced # https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm # This conversion needs to be performed before wrapping modules with DDP. self._algorithm = torch.nn.SyncBatchNorm.convert_sync_batchnorm( self._algorithm) # Create a thread env to expose subprocess gin/alf configurations # which otherwise will be marked as "inoperative". Only created when # ``TrainerConfig.no_thread_env_for_conf=False``. self._thread_env = None def _env_in_subprocess(e): if isinstance( e, alf.environments.alf_wrappers.AlfEnvironmentBaseWrapper): return _env_in_subprocess(e.wrapped_env()) # TODO: One special case is alf_wrappers.MultitaskWrapper which is # an alf wrapper but not a subclass of AlfEnvironmentBaseWrapper. # Its env members might be in the main process or might not. return isinstance( e, (alf.environments.parallel_environment.ParallelAlfEnvironment, alf.environments.fast_parallel_environment. FastParallelEnvironment)) # See ``alf/docs/notes/knowledge_base.rst``` # (ParallelAlfEnvironment and ThreadEnvironment) for details. if not config.no_thread_env_for_conf and _env_in_subprocess(env): self._thread_env = create_environment( nonparallel=True, seed=self._random_seed) if self._evaluate: self._evaluator = Evaluator(self._config, common.get_conf_file()) def _close_envs(self): """Close all envs to release their resources.""" alf.close_env() if self._thread_env is not None: self._thread_env.close() def _train(self): env = alf.get_env() env.reset() iter_num = int(self._trainer_progress._iter_num) training_setting_summarized = False checkpoint_interval = math.ceil( (self._num_iterations or self._num_env_steps) / self._num_checkpoints) if self._num_iterations: time_to_checkpoint = self._trainer_progress._iter_num + checkpoint_interval else: time_to_checkpoint = self._trainer_progress._env_steps + checkpoint_interval if self._evaluate and iter_num == 0: self._eval() while True: t0 = time.time() with record_time("time/train_iter"): train_steps = self._algorithm.train_iter() t = time.time() - t0 logging.log_every_n_seconds( logging.INFO, '%s [pid: %s] %s -> %s: %s time=%.3f throughput=%0.2f' % ('' if self._rank == -1 else f'[rank {self._rank:02d}] ', self._pid, common.get_conf_file(), os.path.basename(self._root_dir.strip('/')), iter_num, t, int(train_steps) / t), n_seconds=1) just_evaluated = False if self._evaluate and (iter_num + 1) % self._eval_interval == 0: if (self._config.num_evals is None or (iter_num + 1) // self._eval_interval < self._config.num_evals): # If num_evals is specified, the last evaluation will be # performed after training finishes. self._eval() just_evaluated = True if not training_setting_summarized and train_steps > 0: self._summarize_training_setting() training_setting_summarized = True # check termination env_steps_metric = self._algorithm.get_step_metrics()[1] total_time_steps = env_steps_metric.result() iter_num += 1 self._trainer_progress.update(iter_num, total_time_steps) if ((self._num_iterations and iter_num >= self._num_iterations) or (not self._num_iterations and total_time_steps >= self._num_env_steps)): # Evaluate before exiting so that the eval curve shown in TB # will align with the final iter/env_step. if self._evaluate and not just_evaluated: self._eval() break self._check_dpp_paras_consistency(iter_num, training_setting_summarized) if ((self._num_iterations and iter_num >= time_to_checkpoint) or (not self._num_iterations and self._num_env_steps and total_time_steps >= time_to_checkpoint)): self._save_checkpoint() time_to_checkpoint += checkpoint_interval elif self._checkpoint_requested: logging.info("Saving checkpoint upon request...") self._save_checkpoint() self._checkpoint_requested = False if self._debug_requested: self._debug_requested = False import pdb pdb.set_trace() def _check_dpp_paras_consistency(self, iter_num: int, training_started: bool): """Periodically check the consistency of model parameters of different DDP processes. Note that DDP can only make sure that the parameters requiring gradients are always consistent across processes, but cannot guarantee the same thing for those without gradients. An example scenario is the target model of SAC. Here we check the parameters for both cases. If any inconsistency is found, a warning message will be printed, without interrupting the training. Even if all model parameters require gradients, this function can still serve as a sanity checker for our DDP implementation. Args: iter_num: current training iteration training_started: if training has started or not. ALF only wraps DDP around algorithms after first training forward, when model parameters are first synced. So only consistency checking after this is meaingful. Note that for off-policy training before first gradient update, different processes might use different models to unroll. """ if not training_started: return proc_cxt = PerProcessContext() if not (proc_cxt.is_distributed and self._config.ddp_paras_check_interval > 0 # Assume that DDP will make sure that this modulo check won't # cause a dead lock, i.e., all workers have the same ``iter_num`` # at any moment. and iter_num % self._config.ddp_paras_check_interval == 0): return with alf.summary.record_if(lambda: True): with record_time("time/para_consistency"): paras_stat = self._algorithm.compute_paras_statistics() queue = proc_cxt.paras_queue if self._rank > 0: # Put para stat into the queue. Don't need to wait rank=0's # return, because DDP will sync processes at the next # gradient update. queue.put(paras_stat) else: consistent = True # rank=0 gets all other para stats for i in range(proc_cxt.num_processes - 1): their_paras_stat = queue.get() is_close = map_structure( partial(np.isclose, atol=1e-6), paras_stat, their_paras_stat) for k, v in is_close.items(): if not np.all(v): consistent = False common.warning( "Found inconsistent parameter '%s' across " "DDP processes: %s vs. %s" % (k, paras_stat[k], their_paras_stat[k])) if not consistent: common.warning( "Your model parameters are not consistent across" " DDP processes. Please make sure to check if there" " is any computation that relies on local-batch " "statistics in the algorithm.") else: common.info("Model parameters are consistent") alf.summary.scalar("DDP/para_consistency", torch.tensor(float(consistent))) def _close(self): """Closing operations after training. """ self._algorithm.finish_train() self._close_envs() if self._evaluate: self._evaluator.close() def _restore_checkpoint(self): checkpointer = Checkpointer( ckpt_dir=os.path.join(self._train_dir, 'algorithm'), algorithm=self._algorithm, metrics=nn.ModuleList(self._algorithm.get_metrics()), trainer_progress=self._trainer_progress) super()._restore_checkpoint(checkpointer) def _eval(self): step_metrics = self._algorithm.get_step_metrics() step_metrics = dict((m.name, int(m.result())) for m in step_metrics) self._evaluator.eval(self._algorithm, step_metrics)
[docs]class SLTrainer(Trainer): """Trainer for supervised learning. """ def __init__(self, config: TrainerConfig): """Create a SLTrainer Args: config (TrainerConfig): configuration used to construct this trainer """ super().__init__(config) assert config.num_iterations > 0, \ "Must provide num_iterations for training!" self._num_epochs = config.num_iterations self._trainer_progress.set_termination_criterion(self._num_epochs) self._algorithm = config.algorithm_ctor(config=config) self._algorithm.set_path('') def _train(self): begin_epoch_num = int(self._trainer_progress._iter_num) epoch_num = begin_epoch_num checkpoint_interval = math.ceil( self._num_epochs / self._num_checkpoints) time_to_checkpoint = begin_epoch_num + checkpoint_interval logging.info("==> Begin Training") while True: t0 = time.time() with record_time("time/train_iter"): train_steps = self._algorithm.train_iter() train_steps = train_steps or 1 t = time.time() - t0 logging.log_every_n_seconds( logging.INFO, '%s -> %s: %s time=%.3f throughput=%0.2f' % (common.get_conf_file(), os.path.basename(self._root_dir.strip('/')), epoch_num, t, int(train_steps) / t), n_seconds=1) if (epoch_num + 1) % self._eval_interval == 0: if self._evaluate: self._algorithm.evaluate() if self._eval_uncertainty: self._algorithm.eval_uncertainty() if epoch_num == begin_epoch_num: self._summarize_training_setting() # check termination epoch_num += 1 self._trainer_progress.update(epoch_num) if (self._num_epochs and epoch_num >= self._num_epochs): if self._evaluate: self._algorithm.evaluate() if self._eval_uncertainty: self._algorithm.eval_uncertainty() break if self._num_epochs and epoch_num >= time_to_checkpoint: self._save_checkpoint() time_to_checkpoint += checkpoint_interval elif self._checkpoint_requested: logging.info("Saving checkpoint upon request...") self._save_checkpoint() self._checkpoint_requested = False if self._debug_requested: self._debug_requested = False import pdb pdb.set_trace() def _restore_checkpoint(self): checkpointer = Checkpointer( ckpt_dir=os.path.join(self._train_dir, 'algorithm'), algorithm=self._algorithm, trainer_progress=self._trainer_progress) super()._restore_checkpoint(checkpointer)
@torch.no_grad() def _step(algorithm, env, time_step, policy_state, trans_state, metrics, render=False, recorder=None, sleep_time_per_step=0, selective_criteria_func=None): """Perform one step interaction using the outpupt action from ``algorithm`` taking ``time_step`` as input. Also record the metrics. Note that this function is used both in ``play`` below and ``evaluate`` in ``evaluator.py``. Args: algorithm (RLAlgorithm): the algorithm under evaluation env: the environment time_step (TimeStep): current time step policy_state (nested Tensor): state of the policy trans_state (nested Tensor): state of the transformer(s) metrics (StepMetric): a list of metrics that will be updated based on ``time_step``. render (bool|False): if True, display the frames of ``env`` on a screen. recorder (VideoRecorder|None): recorder the frames of ``env`` and other additional images in prediction step info if present. sleep_time_per_step (int|0): The sleep time between two frames when ``render`` is True. selective_criteria_func (callable|None): a callable for determining whether an episode will be saved to the video file when a valid recorder is provided. This function takes two input arguments: - return (float): return of the current episode. This is useful for implementing return based selective criteria. - env_info (dict): a dictionary containing information returned by the environment. This is useful for implementing task specific selective criteria using information contained ``env_info``, e.g., success, infraction etc. Returns: - next time step (TimeStep): the next time step after taking an action in ``env`` - policy step (AlgStep): the output from ``algorithm.predict_step`` - new state of the transformer(s) (nested Tensor) """ for metric in metrics: metric(time_step.cpu()) policy_state = common.reset_state_if_necessary( policy_state, algorithm.get_initial_predict_state(env.batch_size), time_step.is_first()) transformed_time_step, trans_state = algorithm.transform_timestep( time_step, trans_state) policy_step = algorithm.predict_step(transformed_time_step, policy_state) if recorder and selective_criteria_func is None: recorder.capture_frame(policy_step.info, time_step.is_last()) elif recorder and selective_criteria_func is not None: env_frame = recorder.capture_env_frame() recorder.cache_frame_and_pred_info(env_frame, policy_step.info) if time_step.is_last(): if selective_criteria_func( map_structure(lambda x: x.cpu().numpy(), metrics[1].latest()), map_structure(lambda x: x.cpu().numpy(), metrics[3].latest())): logging.info( "+++++++++ Selective Case Discovered! +++++++++++") recorder.generate_video_from_cache() else: recorder.clear_cache() elif render: if env.batch_size > 1: env.envs[0].render(mode='human') else: env.render(mode='human') time.sleep(sleep_time_per_step) next_time_step = env.step(policy_step.output) return next_time_step, policy_step, trans_state @common.mark_eval def play(root_dir, env, algorithm, checkpoint_step="latest", num_episodes=10, sleep_time_per_step=0.01, record_file=None, last_step_repeats=0, append_blank_frames=0, render=True, selective_mode=False, ignored_parameter_prefixes=[]): """Play using the latest checkpoint under `train_dir`. The following example record the play of a trained model to a mp4 video: .. code-block:: bash python -m alf.bin.play \ --root_dir=~/tmp/bullet_humanoid/ppo2/ppo2-11 \ --num_episodes=1 \ --record_file=ppo_bullet_humanoid.mp4 Args: root_dir (str): same as the root_dir used for `train()` env (AlfEnvironment): the environment algorithm (RLAlgorithm): the training algorithm checkpoint_step (int|str): the number of training steps which is used to specify the checkpoint to be loaded. If checkpoint_step is 'latest', the most recent checkpoint named 'latest' will be loaded. num_episodes (int): number of episodes to play sleep_time_per_step (float): sleep so many seconds for each step record_file (str): if provided, video will be recorded to a file instead of shown on the screen. last_step_repeats (int): repeat such number of times for the last frame of each episode. append_blank_frames (int): If >0, wil append such number of blank frames at the end of the episode in the rendered video file. A negative value has the same effects as 0 and no blank frames will be appended. This option has no effects when displaying the frames on the screen instead of recording to a file. render (bool): If False, then this function only evaluates the trained model without calling rendering functions. This value will be ignored if a ``record_file`` argument is provided. selective_mode (bool): whether to save the selective cases discovered according to a ``selective_criteria_func``. ignored_parameter_prefixes (list[str]): ignore the parameters whose name has one of these prefixes in the checkpoint. """ train_dir = os.path.join(root_dir, 'train') ckpt_dir = os.path.join(train_dir, 'algorithm') checkpointer = Checkpointer( ckpt_dir=ckpt_dir, algorithm=algorithm, trainer_progress=Trainer._trainer_progress) recovered_global_step = checkpointer.load( checkpoint_step, ignored_parameter_prefixes=ignored_parameter_prefixes, including_optimizer=False, including_replay_buffer=False, including_data_transformers=True, strict=True) # The behavior of some algorithms is based by scheduler using training # progress or global step. So we need to set a valid value for progress # and global step if recovered_global_step != -1: alf.summary.set_global_counter(recovered_global_step) Trainer._trainer_progress.set_termination_criterion( alf.get_config_value('TrainerConfig.num_iterations'), alf.get_config_value('TrainerConfig.num_env_steps')) Trainer._trainer_progress.update() logging.info("global_step=%s TrainerProgress=%s" % (recovered_global_step, Trainer.progress())) batch_size = env.batch_size recorder = None if record_file is not None: assert batch_size == 1, 'video recording is not supported for parallel play' # Note that ``VideoRecorder`` will import ``matplotlib`` which might have # some side effects on xserver (if its backend needs graphics). # This is incompatible with RLBench parallel envs >1 (or other # envs requiring xserver) for some unknown reasons, so we have a lazy import here. from alf.utils.video_recorder import VideoRecorder recorder = VideoRecorder( env, last_step_repeats=last_step_repeats, append_blank_frames=append_blank_frames, path=record_file) elif render: if batch_size > 1: env.envs[0].render(mode='human') else: # pybullet_envs need to render() before reset() to enable mode='human' env.render(mode='human') env.reset() time_step = common.get_initial_time_step(env) algorithm.eval() policy_state = algorithm.get_initial_predict_state(env.batch_size) trans_state = algorithm.get_initial_transform_state(env.batch_size) episodes_per_env = (num_episodes + batch_size - 1) // batch_size env_episodes = torch.zeros(batch_size, dtype=torch.int32) episode_reward = torch.zeros(batch_size) episode_length = torch.zeros(batch_size, dtype=torch.int32) episodes = 0 metrics = [ alf.metrics.NumberOfEpisodes(), alf.metrics.AverageReturnMetric( buffer_size=num_episodes, example_time_step=time_step), alf.metrics.AverageEpisodeLengthMetric( example_time_step=time_step, buffer_size=num_episodes), alf.metrics.AverageEnvInfoMetric( example_time_step=time_step, buffer_size=num_episodes), alf.metrics.AverageDiscountedReturnMetric( buffer_size=num_episodes, example_time_step=time_step) ] if selective_mode: # Below is an example selective criteria based on return. # This should be adjusted according to the particular task. selective_criteria_func = lambda return_value, env_info: return_value < 500 else: selective_criteria_func = None while episodes < num_episodes: # For parallel play, we cannot naively pick the first finished `num_episodes` # episodes to estimate the average return (or other statitics) as it can be # biased. Instead, we stick to using the first episodes_per_env episodes # from each environment to calculate the statistics and ignore the potentially # extra episodes from each environment. invalid = env_episodes >= episodes_per_env # Force the step_type of the extra episodes to be StepType.FIRST so that # these time steps do not affect metrics as the metrics are only updated # at StepType.LAST. The metric computation uses cpu version of time_step. time_step.cpu().step_type[invalid] = StepType.FIRST next_time_step, policy_step, trans_state = _step( algorithm=algorithm, env=env, time_step=time_step, policy_state=policy_state, trans_state=trans_state, metrics=metrics, render=render, recorder=recorder, sleep_time_per_step=sleep_time_per_step, selective_criteria_func=selective_criteria_func) time_step.step_type[invalid] = StepType.FIRST started = time_step.step_type != StepType.FIRST episode_length += started episode_reward += started * time_step.reward.sum() for i in range(batch_size): if time_step.step_type[i] == StepType.LAST: logging.info( "episode_length=%s episode_reward=%s" % (episode_length[i].item(), episode_reward[i].item())) episode_reward[i] = 0. episode_length[i] = 0 env_episodes[i] += 1 episodes += 1 common.log_metrics(metrics) policy_state = policy_step.state time_step = next_time_step env.reset() if recorder: recorder.close()