Source code for alf.environments.process_environment

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Step a single env in a separate process for lock free paralellism.

Adapted from TF-Agents Environment API as seen in:
    https://github.com/tensorflow/agents/blob/master/tf_agents/environments/parallel_py_environment.py
"""

from absl import logging
import atexit
from enum import Enum
from functools import partial
import multiprocessing
import numpy as np
import sys
import torch
import traceback

import alf
from alf.data_structures import TimeStep
import alf.nest as nest
from . import _penv


class _MessageType(Enum):
    """Message types for communication via the pipe.

    The ProcessEnvironment uses a pipe to perform IPC, where each message
    has a message type. This Enum provides all the available message types.

    Except for ``READY``, which the worker sends as a bare value, messages
    travel over the pipe as ``(message_type, payload)`` tuples.
    """
    # Worker -> parent: sent once after the environment has been constructed.
    READY = 1
    # Parent -> worker: payload is the name of an attribute to read from the
    # environment.
    ACCESS = 2
    # Parent -> worker: payload is ``(name, args, kwargs)`` describing a
    # method of the environment to invoke.
    CALL = 3
    # Worker -> parent: payload is the value produced by an ACCESS or CALL.
    RESULT = 4
    # Worker -> parent: payload is the formatted stack trace of an exception
    # raised in the worker.
    EXCEPTION = 5
    # Parent -> worker: request to close the environment (payload ``None``).
    # Also sent worker -> parent when the worker gets a KeyboardInterrupt.
    CLOSE = 6


def _send_quietly(conn, message):
    """Send ``message`` through ``conn``, ignoring a pipe that is already gone.

    Used only on the worker's shutdown paths, where the parent may have closed
    its end of the pipe (or exited) before the worker reports back. A failed
    send there must not turn into a second, unrelated traceback.

    Args:
        conn (multiprocessing.connection.Connection): connection to the parent.
        message: the object to send.
    """
    try:
        conn.send(message)
    except OSError:
        # Covers BrokenPipeError and "handle is closed": nobody is listening
        # any more, so there is nothing useful left to do.
        pass


def _worker(conn,
            env_constructor,
            env_id=None,
            flatten=False,
            fast=False,
            num_envs=0,
            name=''):
    """The process waits for actions and sends back environment results.

    The environment is constructed inside this process. Once it is ready,
    ``_MessageType.READY`` is sent to the parent and requests are served until
    the worker is asked to stop. Any exception escaping the serving loop is
    logged and reported to the parent as an ``(EXCEPTION, stacktrace)``
    message instead of being raised; a ``KeyboardInterrupt`` is reported as a
    ``(CLOSE, None)`` message. The connection is always closed on exit.

    Args:
        conn (multiprocessing.connection): Connection for communication to the main process.
        env_constructor (Callable): callable environment creator. It is called
            as ``env_constructor(env_id=env_id)``.
        env_id (int): ID of the environment, forwarded to ``env_constructor``.
        flatten (bool): whether to assume flattened actions and time_steps
          during communication to avoid overhead.
        fast (bool): whether created by ``FastParallelEnvironment`` or not.
        num_envs (int): number of environments in the ``FastParallelEnvironment``.
            Only used if ``fast`` is True.
        name (str): name of the FastParallelEnvironment. Only used if ``fast``
            is True.
    """
    try:
        alf.set_default_device("cpu")
        env = env_constructor(env_id=env_id)
        action_spec = env.action_spec()
        if fast:
            penv = _penv.ProcessEnvironment(
                env, partial(process_call, conn, env, flatten, action_spec),
                env_id, num_envs, env.batch_size, action_spec,
                env.time_step_spec()._replace(env_info=env.env_info_spec()),
                name)
            conn.send(_MessageType.READY)  # Ready.
            try:
                penv.worker()
            except KeyboardInterrupt:
                penv.quit()
            except Exception:
                traceback.print_exc()
                penv.quit()
        else:
            conn.send(_MessageType.READY)  # Ready.
            while True:
                if not process_call(conn, env, flatten, action_spec):
                    break
    except KeyboardInterrupt:
        # When worker receives interruption from keyboard (i.e. Ctrl-C), notify
        # the parent process to shut down quietly by sending the CLOSE message.
        #
        # This is to avoid sometimes tens of environment processes panicking
        # simultaneously. The parent may already have gone away on Ctrl-C, so
        # the notification itself must not raise.
        _send_quietly(conn, (_MessageType.CLOSE, None))
    except Exception:  # pylint: disable=broad-except
        etype, evalue, tb = sys.exc_info()
        stacktrace = ''.join(traceback.format_exception(etype, evalue, tb))
        message = 'Error in environment process: {}'.format(stacktrace)
        logging.error(message)
        # The error is already logged above; a dead pipe must not mask it.
        _send_quietly(conn, (_MessageType.EXCEPTION, stacktrace))
    finally:
        conn.close()


def process_call(conn, env, flatten, action_spec):
    """Serve a single request coming from the parent process.

    Args:
        conn: connection to the parent process; one ``(message, payload)``
            request is read from it and, for ACCESS and CALL requests, one
            ``(RESULT, value)`` reply is written to it.
        env: the environment the request is applied to.
        flatten (bool): if True, the argument of ``step`` is packed according
            to ``action_spec`` before the call, and the results of ``step``
            and ``reset`` are flattened before being sent back.
        action_spec: the action spec used for packing flattened actions.

    Returns:
        True: continue to work
        False: end the worker

    Raises:
        KeyError: when the received message has an unknown type.
    """
    try:
        # Only block for short times to have keyboard exceptions be raised.
        while not conn.poll(0.1):
            pass
        message, payload = conn.recv()
    except (EOFError, KeyboardInterrupt):
        return False

    if message == _MessageType.CLOSE:
        assert payload is None
        env.close()
        return False

    if message == _MessageType.ACCESS:
        attribute = getattr(env, payload)
        conn.send((_MessageType.RESULT, attribute))
        return True

    if message == _MessageType.CALL:
        name, args, kwargs = payload
        if flatten and name == 'step':
            # The parent sent a flat action; restore its nested structure.
            args = [nest.pack_sequence_as(action_spec, args[0])]
        method = getattr(env, name)
        result = method(*args, **kwargs)
        if flatten and name in ('step', 'reset'):
            result = nest.flatten(result)
            no_tensors = all(
                [not isinstance(item, torch.Tensor) for item in result])
            assert no_tensors, ("Tensor result is not allowed: %s" % name)
        conn.send((_MessageType.RESULT, result))
        return True

    raise KeyError('Received message of unknown type {}'.format(message))
class ProcessEnvironment(object):
    """Proxy for an environment that lives in a separate worker process."""

    def __init__(self,
                 env_constructor,
                 env_id=None,
                 flatten=False,
                 fast=False,
                 num_envs=0,
                 name=""):
        """Step environment in a separate process for lock free parallelism.

        The environment is created in an external process by calling the
        provided callable. This can be an environment class, or a function
        creating the environment and potentially wrapping it. The returned
        environment should not access global variables.

        Args:
            env_constructor (Callable): callable environment creator.
            env_id (torch.int32): ID of the env
            flatten (bool): whether to assume flattened actions and time_steps
                during communication to avoid overhead.
            fast (bool): whether created by ``FastParallelEnvironment`` or not.
            num_envs (int): number of environments in the ``FastParallelEnvironment``.
                Only used if ``fast`` is True.
            name (str): name of the FastParallelEnvironment. Only used if ``fast``
                is True.

        Attributes:
            observation_spec: The cached observation spec of the environment.
            action_spec: The cached action spec of the environment.
            time_step_spec: The cached time step spec of the environment.
        """
        self._env_constructor = env_constructor
        self._flatten = flatten
        self._env_id = env_id
        # Specs are fetched lazily from the worker and cached; ``None`` means
        # "not fetched yet".
        self._observation_spec = None
        self._action_spec = None
        self._reward_spec = None
        self._time_step_spec = None
        self._env_info_spec = None
        # Parent end of the pipe; created by ``start()``.
        self._conn = None
        self._fast = fast
        self._num_envs = num_envs
        self._name = name
        if fast:
            self._penv = _penv.ProcessEnvironmentCaller(env_id, name)

    def start(self, wait_to_start=True):
        """Start the process.

        Args:
            wait_to_start (bool): Whether the call should wait for an env
                initialization.
        """
        assert not self._conn, "Cannot start() ProcessEnvironment multiple times"
        # The following context makes sure that the newly created child process
        # (for environment) is started using the "fork" start method.
        #
        # This is to prevent multiprocessing from accidentally creating the
        # child process with the "spawn" start method. Using "fork" start method
        # is required here because we would like to have the child process
        # inherit the alf configurations from the parent process, so that such
        # configuration are effective for the to-be-created environments in the
        # child process.
        mp_ctx = multiprocessing.get_context('fork')
        self._conn, conn = mp_ctx.Pipe()
        self._process = mp_ctx.Process(
            target=_worker,
            args=(conn, self._env_constructor, self._env_id, self._flatten,
                  self._fast, self._num_envs, self._name))
        atexit.register(self.close)
        self._process.start()
        if wait_to_start:
            self.wait_start()

    def wait_start(self):
        """Wait for the started process to finish initialization.

        Raises:
            Exception: the worker failed while initializing. The message is
                the stack trace reported by the worker.
        """
        assert self._conn, "Run ProcessEnvironment.start() first"
        result = self._conn.recv()
        if isinstance(result, Exception):
            self._conn.close()
            self._process.join(5)
            raise result
        # The worker reports failures as an ``(EXCEPTION, stacktrace)`` tuple,
        # not as an exception instance, so that form is handled explicitly:
        # release the pipe, reap the process and surface the worker's trace.
        if (isinstance(result, tuple) and len(result) == 2
                and result[0] == _MessageType.EXCEPTION):
            self._conn.close()
            self._process.join(5)
            raise Exception(result[1])
        assert result == _MessageType.READY, result

    def env_info_spec(self):
        """Return the env info spec, fetching it from the worker only once."""
        # ``is None`` rather than truthiness: an empty spec is a valid value
        # and must not trigger a new round trip on every call.
        if self._env_info_spec is None:
            self._env_info_spec = self.call('env_info_spec')()
        return self._env_info_spec

    def observation_spec(self):
        """Return the observation spec, fetching it from the worker only once."""
        if self._observation_spec is None:
            self._observation_spec = self.call('observation_spec')()
        return self._observation_spec

    def action_spec(self):
        """Return the action spec, fetching it from the worker only once."""
        if self._action_spec is None:
            self._action_spec = self.call('action_spec')()
        return self._action_spec

    def reward_spec(self):
        """Return the reward spec, fetching it from the worker only once."""
        if self._reward_spec is None:
            self._reward_spec = self.call('reward_spec')()
        return self._reward_spec

    def time_step_spec(self):
        """Return the time step spec, fetching it from the worker only once."""
        if self._time_step_spec is None:
            self._time_step_spec = self.call('time_step_spec')()
        return self._time_step_spec

    def __getattr__(self, name):
        """Request an attribute from the environment.

        Note that this involves communication with the external process, so it
        can be slow.

        Args:
            name (str): Attribute to access.

        Returns:
            Value of the attribute.
        """
        if name == '_conn':
            # ``__getattr__`` only runs when normal lookup failed. If ``_conn``
            # itself is missing (instance created without ``__init__``), the
            # check below would re-enter this method forever.
            raise AttributeError(name)
        assert self._conn, "Run ProcessEnvironment.start() first"
        if self._fast:
            self._penv.call()
        self._conn.send((_MessageType.ACCESS, name))
        return self._receive()

    def call(self, name, *args, **kwargs):
        """Asynchronously call a method of the external environment.

        Args:
            name (str): Name of the method to call.
            *args: Positional arguments to forward to the method.
            **kwargs: Keyword arguments to forward to the method.

        Returns:
            Promise object that blocks and provides the return value when called.
        """
        assert self._conn, "Run ProcessEnvironment.start() first"
        if self._fast:
            self._penv.call()
        payload = name, args, kwargs
        self._conn.send((_MessageType.CALL, payload))
        return self._receive

    def close(self):
        """Send a close message to the external process and join it."""
        try:
            if self._fast:
                self._penv.close()
            else:
                self._conn.send((_MessageType.CLOSE, None))
            self._conn.close()
        except IOError:
            # The connection was already closed.
            pass
        self._process.join()

    def step(self, action, blocking=True):
        """Step the environment.

        Args:
            action (nested tensors): The action to apply to the environment.
            blocking (bool): Whether to wait for the result.

        Returns:
            time step when blocking, otherwise callable that returns the time
            step.
        """
        promise = self.call('step', action)
        if blocking:
            return promise()
        else:
            return promise

    def reset(self, blocking=True):
        """Reset the environment.

        Args:
            blocking (bool): Whether to wait for the result.

        Returns:
            New observation when blocking, otherwise callable that returns the
            new observation.
        """
        promise = self.call('reset')
        if blocking:
            return promise()
        else:
            return promise

    def _receive(self):
        """Wait for a message from the worker process and return its payload.

        Raises:
            Exception: An exception was raised inside the worker process.
            KeyError: The received message is of an unknown type.

        Returns:
            Payload object of the message.
        """
        assert self._conn, "Run ProcessEnvironment.start() first"
        message, payload = self._conn.recv()
        # Re-raise exceptions in the main process.
        if message == _MessageType.EXCEPTION:
            stacktrace = payload
            raise Exception(stacktrace)
        elif message == _MessageType.RESULT:
            return payload
        elif message == _MessageType.CLOSE:
            # When notified that the child process is going to shut down, do not
            # panic and handle it quietly.
            return None
        self.close()
        raise KeyError(
            'Received message of unexpected type {}'.format(message))

    def render(self, mode='human'):
        """Render the environment.

        Args:
            mode (str): One of ['rgb_array', 'human']. Renders to an numpy array,
                or brings up a window where the environment can be visualized.

        Returns:
            An ndarray of shape [width, height, 3] denoting an RGB image if mode
            is `rgb_array`. Otherwise return nothing and render directly to a
            display window.

        Raises:
            NotImplementedError: If the environment does not support rendering.
        """
        return self.call('render', mode)()