Source code for alf.bin.play

# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Play a trained model.

You can visualize playing of the trained model by running:

.. code-block:: bash

    cd ${PROJECT}/alf/examples;
    python -m alf.bin.play \
    --root_dir=~/tmp/cart_pole \
    --alsologtostderr

"""

from absl import app
from absl import flags
from absl import logging
import copy
import os
import subprocess
import sys
import torch

from alf.algorithms.data_transformer import create_data_transformer
from alf.environments.utils import create_environment
from alf.trainers import policy_trainer
from alf.utils import common
import alf.summary.render as render
import alf.utils.external_configurables


def _define_flags():
    """Register the command line flags used by this script with absl.

    This is called only from the ``__main__`` guard of this module, so merely
    importing the module does not register any flag.
    """
    flags.DEFINE_string(
        'root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
        'Root directory for writing logs/summaries/checkpoints.')
    flags.DEFINE_integer(
        'checkpoint_step', None,
        "the number of training steps which is used to "
        "specify the checkpoint to be loaded. If None, the latest checkpoint under "
        "train_dir will be used.")
    flags.DEFINE_integer('random_seed', None, "random seed")
    flags.DEFINE_bool(
        'force_torch_deterministic', True,
        'torch.use_deterministic_algorithms when random_seed is set. '
        'When it is False, deterministic behavior is not guaranteed, '
        'but could still be deterministic, e.g. for sac_breakout_conf.py. '
        'Setting a random seed without setting this to False, training '
        'could throw this error: _scatter_add kernel does not have a '
        'deterministic implementation.')
    flags.DEFINE_integer('num_episodes', 10, "number of episodes to play")
    flags.DEFINE_integer(
        'last_step_repeats', 0,
        "If >0, will repeat such number of times for the last "
        "frame of each episode in the rendered video file.")
    flags.DEFINE_integer(
        'append_blank_frames', 0,
        "If >0, will append such number of blank frames at the "
        "end of each episode in the rendered video file.")
    flags.DEFINE_float('sleep_time_per_step', 0.01,
                       "sleep so many seconds for each step")
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # fragment of a multi-part help text must carry its own separating space.
    flags.DEFINE_string(
        'record_file', None, "If provided, video will be recorded "
        "to a file instead of shown on the screen.")
    # use '--norender' to disable frame rendering
    flags.DEFINE_bool(
        'render', True,
        "Whether render ('human'|'rgb_array') the frames or not")
    # use '--alg_render' to enable algorithm specific rendering
    flags.DEFINE_bool('alg_render', False,
                      "Whether enable algorithm specific rendering")
    flags.DEFINE_string('gin_file', None, 'Path to the gin-config file.')
    flags.DEFINE_multi_string('gin_param', None, 'Gin binding parameters.')
    flags.DEFINE_string('conf', None, 'Path to the alf config file.')
    flags.DEFINE_multi_string('conf_param', None, 'Config binding parameters.')
    flags.DEFINE_string(
        'ignored_parameter_prefixes', "",
        "Comma separated strings to ignore the parameters whose name has one of "
        "these prefixes in the checkpoint.")
    flags.DEFINE_bool(
        'use_alf_snapshot', False,
        'Whether to use ALF snapshot stored in the model dir (if any). You can set '
        'this flag to play a model trained with legacy ALF code.')
    flags.DEFINE_integer('parallel_play', 1,
                         'Play so many simulations simultaneously')

    flags.DEFINE_bool(
        'selective_mode', False, "Whether use the selective mode. "
        "If True, only save the discovered selective cases within "
        "the `num_episodes` of test episode. This mode "
        "should be used together with the video recording mode.")


# Shorthand for absl's global flag container; every ``FLAGS.<name>`` read in
# this module goes through it.
FLAGS = flags.FLAGS


def play():
    """Play a trained model as specified by the command line flags.

    The steps are:

    1. Select the CUDA device if one is available and switch algorithm
       specific rendering on/off according to ``--alg_render``.
    2. Configure ``create_environment`` for evaluation (parallel if
       ``--parallel_play`` > 1) and fix the random seed in ``TrainerConfig``.
    3. Parse the conf file returned by ``common.get_conf_file()``.
    4. Build the data transformer and the algorithm from ``TrainerConfig``.
    5. Run ``policy_trainer.play()`` on the checkpoint under ``--root_dir``.

    ``alf.close_env()`` is called before this function returns or raises,
    whichever step after the start of conf parsing fails.

    Raises:
        AssertionError: if ``--selective_mode`` is set without
            ``--record_file``, or if no conf file can be found.
    """
    # Pure flag validation: do it first so that an invalid command line is
    # rejected before any environment is created.
    if FLAGS.selective_mode:
        assert FLAGS.record_file is not None, ("Should provide a valid value "
                                               "for `record_file`")

    if torch.cuda.is_available():
        alf.set_default_device("cuda")
    render.enable_rendering(FLAGS.alg_render)

    seed = common.set_random_seed(FLAGS.random_seed)
    if FLAGS.parallel_play > 1:
        alf.config(
            'create_environment',
            for_evaluation=True,
            num_parallel_environments=FLAGS.parallel_play,
            mutable=False)
    else:
        alf.config('create_environment', for_evaluation=True, nonparallel=True)
    alf.config('TrainerConfig', mutable=False, random_seed=seed)

    conf_file = common.get_conf_file()
    assert conf_file is not None, "Conf file not found! Check your root_dir"

    # An empty flag value means "ignore nothing" rather than [""].
    if FLAGS.ignored_parameter_prefixes:
        ignored_parameter_prefixes = FLAGS.ignored_parameter_prefixes.split(
            ",")
    else:
        ignored_parameter_prefixes = []

    # A single try/finally covers conf parsing, algorithm construction and the
    # actual playing, so the environment is closed no matter where a failure
    # happens (previously a failure between parsing and playing leaked it).
    try:
        common.parse_conf_file(conf_file)

        config = policy_trainer.TrainerConfig(root_dir="")

        env = alf.get_env()
        env.reset()
        data_transformer = create_data_transformer(
            config.data_transformer_ctor, env.observation_spec())
        config.data_transformer = data_transformer

        # keep compatibility with previous gin based config
        common.set_global_env(env)
        observation_spec = data_transformer.transformed_observation_spec
        common.set_transformed_observation_spec(observation_spec)

        algorithm_ctor = config.algorithm_ctor
        algorithm = algorithm_ctor(
            observation_spec=observation_spec,
            action_spec=env.action_spec(),
            reward_spec=env.reward_spec(),
            config=config)
        algorithm.set_path('')

        policy_trainer.play(
            common.abs_path(FLAGS.root_dir),
            env,
            algorithm,
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            num_episodes=FLAGS.num_episodes,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            append_blank_frames=FLAGS.append_blank_frames,
            last_step_repeats=FLAGS.last_step_repeats,
            render=FLAGS.render,
            selective_mode=FLAGS.selective_mode,
            ignored_parameter_prefixes=ignored_parameter_prefixes)
    finally:
        alf.close_env()
def launch_snapshot_play():
    """Play a trained model using the ALF snapshot stored in ``root_dir``.

    This play function uses the historical ALF snapshot, consistent with the
    code snapshot that trained the model. In the newer version of
    ``train.py``, an ALF snapshot is saved to ``root_dir`` right before the
    training begins. So this function runs ``alf.bin.play`` again in a child
    process whose environment variables come from
    ``common.get_alf_snapshot_env_vars(root_dir)``, which allows the snapshot
    ALF repo in that place to be used.

    Note that for any old training ``root_dir`` prior to snapshot being
    enabled, this function doesn't have any effect and the most up-to-date ALF
    will be used by play.

    Raises:
        SystemExit: with the child's exit status if the child process fails,
            so that the failure is visible to whoever invoked this script.
    """
    # NOTE: the current path should not be ALF_ROOT because sys.path will
    # always prepend the current path to the path list, which makes our
    # snapshot ALF path shadowed.
    root_dir = common.abs_path(FLAGS.root_dir)
    env_vars = common.get_alf_snapshot_env_vars(root_dir)

    # Forward the original command line, but turn snapshot usage off in the
    # child so that it plays directly instead of spawning yet another process.
    forwarded_args = sys.argv[1:] + ['--nouse_alf_snapshot']

    # The command is passed as an argv list without a shell: every forwarded
    # argument reaches the child exactly as this process received it, even if
    # it contains spaces, quotes or shell metacharacters (as ``--conf_param``
    # values often do). ``sys.executable`` keeps the child on the interpreter
    # that is running this process.
    command = [sys.executable or 'python', '-m', 'alf.bin.play'
               ] + forwarded_args

    returncode = subprocess.call(
        command, env=env_vars, stdout=sys.stdout, stderr=sys.stdout)
    if returncode != 0:
        # The child has already reported its own error, so nothing more is
        # printed here; only its exit status is propagated.
        sys.exit(returncode)
def main(_):
    """Entry point handed to ``app.run``.

    Plays in a child process running the ALF snapshot when
    ``--use_alf_snapshot`` is set, and plays in this process otherwise.

    Args:
        _: the positional argument passed in by ``app.run``; unused.
    """
    runner = launch_snapshot_play if FLAGS.use_alf_snapshot else play
    runner()
if __name__ == '__main__':
    # The flags must be registered before ``root_dir`` can be marked as
    # required and before ``app.run`` hands control to ``main``.
    _define_flags()
    flags.mark_flag_as_required('root_dir')

    logging.set_verbosity(logging.INFO)
    app.run(main)