# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from gym.wrappers.monitoring.video_recorder import VideoRecorder as GymVideoRecorder
from gym import error, logger
import alf
try:
# There might be importing errors about matplotlib on the cluster if the
# correct version (requires python3.7) of matplotlib is not installed.
# In this case we just skip importing because no rendering is needed on
# cluster.
import alf.summary.render as render
except ImportError:
render = None
[docs]@alf.configurable(whitelist=[
'frame_max_width', 'frames_per_sec', 'last_step_repeats',
'append_blank_frames'
])
class VideoRecorder(GymVideoRecorder):
"""A video recorder that renders frames and encodes them into a video file.
Besides rendering frames, it also supports plotting prediction info. Each
algorithm is responsible for adding rendered Image instances in its pred info
in order to be recorded here. See the docstring in ``alf.summary.render``
for more details.
"""
def __init__(self,
env,
frame_max_width=2560,
frames_per_sec=None,
last_step_repeats=0,
append_blank_frames=0,
**kwargs):
"""
Args:
env (Gym.env):
frame_max_width (int): the max width of a video frame. Scale if the
original width is bigger than this.
frames_per_sec (fps): if None, use fps from the env
last_step_repeats (int): repeat such number of times for the
last frame of each episode.
append_blank_frames (int): If >0, will append such number of blank
frames at the end of the episode in the rendered video file.
A negative value has the same effects as 0 and no blank frames
will be appended.
"""
super(VideoRecorder, self).__init__(env=env, **kwargs)
self._frame_width = frame_max_width
if frames_per_sec is not None:
self.frames_per_sec = frames_per_sec # overwrite the base class
self._last_step_repeats = last_step_repeats
self._append_blank_frames = append_blank_frames
self._blank_frame = None
self._pred_info_img_shapes = None
# two caches for lazy rendering of environmental frames and pred_info
self._frame_cache = []
self._pred_info_cache = []
[docs] def capture_frame(self, pred_info=None, is_last_step=False):
"""Render ``self.env`` and add the resulting frame to the video. Also
plot Image instances extracted from prediction info of ``policy_step``.
Args:
pred_info (None|nest): prediction step info for displaying: any Image
instance in the info nest will be recorded.
is_last_step (bool): whether the current time step is the last
step of the episode, either due to game over or time limits.
"""
if not self.functional: return
logger.debug('Capturing video frame: path=%s', self.path)
if pred_info is not None:
assert not self.ansi_mode, "Only supports rgb_array mode!"
render_mode = 'rgb_array'
else:
render_mode = 'ansi' if self.ansi_mode else 'rgb_array'
frame = self.env.render(mode=render_mode)
if frame is None:
if self._async:
return
else:
# Indicates a bug in the environment: don't want to raise
# an error here.
logger.warn(
'Env returned None on render(). Disabling further '
'rendering for video recorder by marking as disabled: '
'path=%s metadata_path=%s', self.path, self.metadata_path)
self.broken = True
else:
frame = self._plot_pred_info(frame, pred_info)
self._encode_frame(frame)
if is_last_step:
if self._last_step_repeats > 0:
for _ in range(self._last_step_repeats):
self._encode_frame(frame)
if self._append_blank_frames > 0:
if self._blank_frame is None:
self._blank_frame = np.zeros_like(frame)
for _ in range(self._append_blank_frames):
self._encode_frame(self._blank_frame)
assert not self.broken, (
"The output file is broken! Check warning messages.")
[docs] def capture_env_frame(self):
"""Return un-encoded env frame
"""
if not self.functional: return
logger.debug('Capturing video frame: path=%s', self.path)
render_mode = 'rgb_array'
frame = self.env.render(mode=render_mode)
assert frame is not None
return frame
[docs] def cache_frame_and_pred_info(self, frame, pred_info=None):
"""Cache the input frame and pred_info for video generation later.
Args:
frame (np.array): the environmental frame.
pred_info (None|nest): prediction step info for displaying: any Image
instance in the info nest will be recorded.
"""
self._frame_cache.append(frame)
self._pred_info_cache.append(pred_info)
[docs] def clear_cache(self):
"""Clear the cached contents.
"""
self._frame_cache = []
self._pred_info_cache = []
[docs] def generate_video_from_cache(self):
"""Generate the video from the cached frames. Also add the plot Image
instances extracted from cached prediction info.
The cache will be reset to empty afterwards.
"""
for i, (frame, pred_info) in enumerate(
zip(self._frame_cache, self._pred_info_cache)):
frame = self._plot_pred_info(frame, pred_info)
self._encode_frame(frame)
self.clear_cache()
if self._append_blank_frames > 0:
if self._blank_frame is None:
self._blank_frame = np.zeros_like(frame)
for _ in range(self._append_blank_frames):
self._encode_frame(self._blank_frame)
assert not self.broken, (
"The output file is broken! Check warning messages.")
def _encode_frame(self, frame):
"""Perform encoding of the input frame
Args:
frame(np.ndarray|str|StringIO): the frame to be encoded,
which is of type ``str`` or ``StringIO`` if ``ansi_mode`` is
True, and ``np.array`` otherwise.
"""
if self.ansi_mode:
self._encode_ansi_frame(frame)
else:
self._encode_image_frame(frame)
def _plot_pred_info(self, env_frame, pred_info):
r"""Search ``Image`` elements in ``pred_info``, merge them into a big
image, and stack it with ``env_frame``.
Args:
env_frame (numpy.ndarray): ``numpy.ndarray`` with shape
``(H, W, 3)``, representing RGB values for an :math:`H\times W`
image, output from ``env.render('rgb_array')``.
pred_info (nested): a nest. Any element that is ``Image`` will be
retrieved.
Returns:
np.ndarray:
"""
imgs = [
i for i in alf.nest.flatten(pred_info)
if isinstance(i, render.Image)
]
if self._pred_info_img_shapes is None:
self._pred_info_img_shapes = [i.shape for i in imgs]
else:
# Sometimes pyplot will automatically calculate figure sizes if no
# image height and widths are provided when rendering, which
# could result in a different rectangle packing result.
# In order to not break the video encoder, we need to make sure each
# pred info img has the same size with before.
for i, shape in zip(imgs, self._pred_info_img_shapes):
if i.shape != shape:
i.resize(*shape[:2])
frame = render.Image(env_frame)
if imgs:
info_img = render.Image.pack_image_nest(imgs)
# always put env frame on top/left; for simplicity here we generate
# both and compare their sizes.
horizontal = render.Image.stack_images([frame, info_img],
horizontal=True)
vertical = render.Image.stack_images([frame, info_img],
horizontal=False)
if np.product(horizontal.shape) < np.product(vertical.shape):
frame = horizontal
else:
frame = vertical
if frame.shape[1] > self._frame_width:
frame.resize(width=self._frame_width)
return frame.data