Source code for alf.summary.render

# Copyright (c) 2021 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import io
import numpy as np
import matplotlib
matplotlib.use('Agg')  # 'Agg' no need for xserver!
import matplotlib.pyplot as plt
# Style gallery: https://tonysyu.github.io/raw_content/matplotlib-style-gallery/gallery.html
plt.style.use('seaborn-dark')
try:
    import rpack
except ImportError:
    rpack = None

import cv2

import torch.distributions as td

import alf
import alf.nest as nest
from alf.utils import dist_utils
"""To use the rendering functions in this file, when playing a model, specify the
flags '--alg_render' and '--record_file'.

Also in your algorithm, put the rendered images in ``alg_step.info`` of
``predict_step()``.

Example:

.. code-block:: python

    import alf.summary.render as render

    action_dist, action = self._predict_action(time_step.observation)

    with alf.summary.scope(scope_name):
        action_img = render.render_action(
            name="action", action=action, action_spec=self._action_spec)
        action_heatmap = render.render_heatmap(
            name="action_heatmap", data=action, val_label="action")
        act_dist_curve = render.render_action_distribution(
            name="action_dist", act_dist=action_dist, action_spec=self._action_spec)

        return AlgStep(
            output=action,
            info=dict(
                action_img=action_img,
                action_heatmap=action_heatmap,
                action_dist_curve=act_dist_curve))
"""


[docs]class Image(object): """A simple image class.""" def __init__(self, img): """ Args: img (np.ndarray): a numpy array image of shape ``[H,W]`` (gray-scale) or ``[H,W,3]`` (RGB). """ assert isinstance(img, np.ndarray), "Image must be a numpy array!" shape = img.shape assert (len(shape) == 2) or (len(shape) == 3 and shape[-1] == 3), ( "Image shape should be [H,W] or [H,W,3]!") if len(shape) == 2: self._img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) else: self._img = img @property def shape(self): """Return the shape of the image.""" return self._img.shape @property def data(self): """Return the image numpy array which is always RGB.""" return self._img
[docs] def resize(self, height=None, width=None): """Resize the image in-place given the desired width and/or height. Args: height (int): the desired output image height. If ``None``, this will be scaled to keep the original aspect ratio if ``width`` is provided. width (int): the desired output image width. If ``None``, this will be scaled to keep the original aspect ratio if ``height`` is provided. """ if width is not None and height is not None: self._img = cv2.resize(self._img, dsize=(width, height)) return if width is not None: scale = float(width) / self._img.shape[1] elif height is not None: scale = float(height) / self._img.shape[0] else: raise ValueError('At least width or height should be provided.') self._img = cv2.resize( self._img, dsize=(0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
[docs] @classmethod def from_pyplot_fig(cls, fig, dpi=200): """Generate an ``Image`` instance from a pyplot figure instance. Args: fig (pyplot.figure): a pyplot figure instance dpi (int): resolution of the generated image Returns: Image: """ buf = io.BytesIO() fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight") buf.seek(0) img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) buf.close() img = cv2.imdecode(img_arr, 1) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return cls(img)
[docs] @classmethod def pack_image_nest(cls, imgs): """Given a nest of images, pack them into a larger image so that it has an area as small as possible. This problem is generally known as "rectangle packing" and its optimal solution is `NP-complete <https://en.wikipedia.org/wiki/Rectangle_packing>`_. Here we just rely on a third-party lib `rpack <https://pypi.org/project/rectangle-packer/>`_ that is used for building CSS sprites, for an approximate solution. Args: imgs (nested Image): a nest of ``Image`` instances Returns: Image: the big mosaic image """ assert rpack is not None, "You need to install rectangle-packer first!" imgs = nest.flatten(imgs) if len(imgs) == 0: return # first get all images' sizes (w,h) sizes = [(i.shape[1], i.shape[0]) for i in imgs] # call rpack for an approximate solution: [(x,y),...] positions positions = rpack.pack(sizes) # compute the height and width of the enclosing rectangle H, W = 0, 0 for size, pos in zip(sizes, positions): H = max(H, pos[1] + size[1]) W = max(W, pos[0] + size[0]) packed_img = np.full((H, W, 3), 255, dtype=np.uint8) for pos, img in zip(positions, imgs): packed_img[pos[1]:pos[1] + img.shape[0], pos[0]:pos[0] + img.shape[1], :] = img.data return cls(packed_img)
[docs] @classmethod def stack_images(cls, imgs, horizontal=True): """Given a list/tuple of images, stack them in order either horizontally or vertically. Args: imgs (list[Image]|tuple[Image]): a list/tuple of ``Image`` instances horizontal (bool): if True, stack images horizontally, otherwise vertically. Returns: Image: the stacked big image """ assert isinstance(imgs, (list, tuple)) if horizontal: H = max([i.shape[0] for i in imgs]) W = sum([i.shape[1] for i in imgs]) stacked_img = np.full((H, W, 3), 255, dtype=np.uint8) offset_w = 0 for i in imgs: stacked_img[:i.shape[0], offset_w:offset_w + i.shape[1], :] = i.data offset_w += i.shape[1] else: H = sum([i.shape[0] for i in imgs]) W = max([i.shape[1] for i in imgs]) stacked_img = np.full((H, W, 3), 255, dtype=np.uint8) offset_h = 0 for i in imgs: stacked_img[offset_h:offset_h + i.shape[0], :i.shape[1], :] = i.data offset_h += i.shape[0] return cls(stacked_img)
_rendering_enabled = False
[docs]def enable_rendering(flag=True): """Enable rendering by ``flag``. Args: flag (bool): True to enable, False to disable """ global _rendering_enabled _rendering_enabled = flag
[docs]def is_rendering_enabled(): """Return whether rendering is enabled.""" return _rendering_enabled
def _rendering_wrapper(rendering_func): """A wrapper function to gate the rendering function based on if rendering is enabled, and if yes generate a scoped rendering identifier before calling the rendering function. It re-uses the scope stack in ``alf.summary.summary_ops.py``. """ @functools.wraps(rendering_func) def wrapper(name, data, **kwargs): if is_rendering_enabled(): name = alf.summary.summary_ops._scope_stack[-1] + name return rendering_func(name, data, **kwargs) return wrapper def _convert_to_image(name, fig, dpi, height=None, width=None): """First putting the rendering identifier on top of the figure and then convert it to an instance of ``Image``. Also release the resources of ``fig``. Args: name (str): a scoped identifier fig (pyplot.figure): the figure holding the rendering dpi (int): resolution of each rendered image height (int): height of the output image width (int): width of the output image """ fig.suptitle(name) img = Image.from_pyplot_fig(fig, dpi=dpi) if height is not None and width is not None: img.resize(height=height, width=width) plt.close(fig) return img def _heatmap(data, row_ticks=None, col_ticks=None, row_labels=None, col_labels=None, ax=None, cbar_kw={}, cbarlabel="", **kwargs): """Create a heatmap from a numpy array and two lists of labels. (Code from `matplotlib documentation <https://matplotlib.org/stable/gallery/images_contours_and_fields/image_annotated_heatmap.html>`_) Args: data (np.ndarray): A 2D numpy array of shape ``[H, W]``. row_ticks (list[float]): List of row (y-axis) tick locations. col_ticks (list[float]): List of column (x-axis) tick locations. row_labels (list[str]): A list labels for the rows. Its length should be equal to that of ``row_ticks`` if ``row_ticks`` is not None. Otherwise, it should have a length of ``H``. col_labels (list[str]): A list of labels for the columns. Its length should be equal to that of ``col_ticks`` if ``col_ticks`` is not None. Otherwise, it should have a length of ``W``. ax (matplotlib.axes.Axes): instance to which the heatmap is plotted. If None, use current axes or create a new one. cbar_kw (dict): A dictionary with arguments to ``matplotlib.Figure.colorbar``. cbarlabel (str): The label for the colorbar. **kwargs: All other arguments that are forwarded to ``ax.imshow``. Returns: tuple: - matplotlib.image.AxesImage: the heatmap image - matplotlib.pyplot.colorbar: the colorbar of the heatmap """ if not ax: ax = plt.gca() # Plot the heatmap im = ax.imshow(data, **kwargs) # Create colorbar cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") if col_ticks is None: # show all the ticks by default col_ticks = np.arange(data.shape[1] + 1) - .5 ax.set_xticks(col_ticks, minor=True) if row_ticks is None: # show all the ticks by default row_ticks = np.arange(data.shape[0] + 1) - .5 ax.set_yticks(row_ticks, minor=True) # ... and label them with the respective list entries. if col_labels is not None: assert len(col_ticks) == len(col_labels), ( "'col_ticks' should have the " "same length as 'col_labels'") ax.set_xticklabels(col_labels) if row_labels is not None: assert len(row_ticks) == len(row_labels), ( "'row_ticks' should have the " "same length as 'row_labels'") ax.set_yticklabels(row_labels) # Let the horizontal axes labeling appear on top. ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) # Rotate the tick labels and set their alignment. plt.setp( ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") # Turn spines off and create white grid. ax.spines[:].set_visible(False) ax.grid(which="minor", color="w", linestyle='-', linewidth=3) ax.tick_params(which="minor", bottom=False, left=False) return im, cbar def _annotate_heatmap(im, valfmt="%.2f", textcolors=("black", "white"), threshold=None, **textkw): """A function to annotate a heatmap. (Code from `matplotlib documentation <https://matplotlib.org/stable/gallery/images_contours_and_fields/image_annotated_heatmap.html>`_) Args: im (matplotlib.image.AxesImage): The image to be labeled. valfmt (str): The format of the annotations inside the heatmap. This should either use the string format method, e.g. "%.2f", or be a ``matplotlib.ticker.Formatter``. textcolors (tuple[str]): A pair of colors. The first is used for values below a threshold, the second for those above. threshold (float): Value in data units according to which the colors from textcolors are applied. If None (the default) uses the middle of the colormap as separation. **textkw: All other arguments are forwarded to each call to ``text()`` used to create the text labels. """ data = im.get_array() # Normalize the threshold to the images color range. if threshold is not None: threshold = im.norm(threshold) else: threshold = im.norm(data.max()) / 2. # Set default alignment to center, but allow it to be # overwritten by textkw. kw = dict(horizontalalignment="center", verticalalignment="center") kw.update(textkw) # Loop over the data and create a `Text` for each "pixel". # Change the text's color depending on the data. for i in range(data.shape[0]): for j in range(data.shape[1]): kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) im.axes.text(j, i, valfmt % data[i, j], **kw)
[docs]@_rendering_wrapper def render_heatmap(name, data, val_label="", row_ticks=None, col_ticks=None, row_labels=None, col_labels=None, cbar_kw={}, annotate_format="%.2f", font_size=7, img_height=None, img_width=None, dpi=300, figsize=(2, 2), **kwargs): """Render a 2D tensor as a heatmap. Args: name (str): rendering identifier data (Tensor|np.ndarray): a tensor/np.array of shape ``[H, W]`` val_label (str): The label for the rendered values. row_ticks (list[float]): List of row (y-axis) tick locations. col_ticks (list[float]): List of column (x-axis) tick locations. row_labels (list[str]): A list labels for the rows. Its length should be equal to that of ``row_ticks`` if ``row_ticks`` is not None. Otherwise, it should have a length of ``H``. col_labels (list[str]): A list of labels for the columns. Its length should be equal to that of ``col_ticks`` if ``col_ticks`` is not None. Otherwise, it should have a length of ``W``. cbar_kw (dict): A dictionary with arguments to ``matplotlib.Figure.colorbar``. annotate_format (str): The format of the annotations on the heatmap to show the actual value represented by each heatmap cell. This should either use the string format method, e.g. "%.2f", or be a ``matplotlib.ticker.Formatter``. No annotation on the heatmap if this argument is ''. font_size (int): the font size of annotation on the heatmap img_height (int): height of the output image img_width (int): width of the output image dpi (int): resolution of each rendered image figsize (tuple[int]): figure size. For the relationship between ``dpi`` and ``figsize``, please refer to `this post <https://stackoverflow.com/questions/47633546/relationship-between-dpi-and-figure-size>`_. **kwargs: All other arguments that are forwarded to ``ax.imshow``. For example, to specify the value range on the heatmap, we can use ``vmin`` and ``vmax``. Returns: Image: an output image rendered for the tensor """ assert len(data.shape) == 2, "Must be a rank-2 tensor!" if not isinstance(data, np.ndarray): array = data.cpu().numpy() else: array = data fig, ax = plt.subplots(figsize=figsize) im, _ = _heatmap( array, row_ticks, col_ticks, row_labels, col_labels, ax, cbar_kw=cbar_kw, cbarlabel=val_label, **kwargs) if annotate_format != '': _annotate_heatmap(im, valfmt=annotate_format, size=font_size) return _convert_to_image(name, fig, dpi, img_height, img_width)
[docs]@_rendering_wrapper def render_contour(name, data, x_ticks=None, y_ticks=None, x_label=None, y_label=None, font_size=7, img_height=None, img_width=None, dpi=300, figsize=(2, 2), flip_y_axis=True, **kwargs): """Render a 2D tensor as a contour. Args: name (str): rendering identifier data (Tensor|np.ndarray): a tensor/np.array of shape ``[H,W]``. Note that the rows of ``data`` correspond to y (inverted) and columns correspond to x in the contour figure. x_ticks (np.array): A list of length ``W`` with x ticks. y_ticks (np.array): A list (from 0 to H-1) of length ``H`` with y ticks. x_label (str): label shown besides x-axis y_label (str): label shown besides y-axis font_size (int): font size for the numbers on the contour img_height (int): height of the output image img_width (int): width of the output image dpi (int): resolution of each rendered image figsize (tuple[int]): figure size. For the relationship between ``dpi`` and ``figsize``, please refer to `this post <https://stackoverflow.com/questions/47633546/relationship-between-dpi-and-figure-size>`_. flip_y_axis (bool): whether flip the y axis. Flipping makes this consistent with heatmap regarding y axis. **kargs: All other arguments that are forwarded to ``ax.contour``. Returns: Image: an output image rendered for the tensor """ assert len(data.shape) == 2, "Must be a rank-2 tensor!" if not isinstance(data, np.ndarray): array = data.cpu().numpy() else: array = data fig, ax = plt.subplots(figsize=figsize) # x must be dim 0 for ax.contour() array = np.transpose(array, (0, 1)) if x_ticks is None: x_ticks = np.arange(len(array)) if y_ticks is None: y_ticks = np.arange(len(array[0])) ct = ax.contour(x_ticks, y_ticks, array, **kwargs) ax.clabel(ct, inline=True, fontsize=font_size) if x_label: ax.set_xlabel(x_label) if y_label: ax.set_ylabel(y_label) if flip_y_axis: plt.gca().invert_yaxis() return _convert_to_image(name, fig, dpi, img_height, img_width)
[docs]@_rendering_wrapper def render_curve(name, data, x_range=None, y_range=None, x_label=None, y_label=None, legends=None, legend_kwargs={}, img_height=None, img_width=None, dpi=300, figsize=(2, 2), **kwargs): """Plot 1D curves. Args: name (stor): rendering identifier data (Tensor|np.ndarray): a rank-1 or rank-2 tensor/np.array. If rank-2, then each row represents an individual curve. x_range (tuple[float]): min/max for x values. If None, ``x`` is the index sequence of curve points. If provided, ``x`` is evenly spaced by ``(x_range[1] - x_range[0]) / (N - 1)``. y_range (tuple[float]): a tuple of ``(min_y, max_y)`` for showing on the figure. If None, then it will be decided according to the ``y`` values. Note that this range won't change ``y`` data; it's only used by matplotlib for drawing ``y`` limits. x_label (str): shown besides x-axis y_label (str): shown besides y-axis legends (list[str]): label for each curve. No legends are shown if None. legend_kwargs (dict): optional legend kwargs img_height (int): height of the output image img_width (int): width of the output image dpi (int): resolution of each rendered image figsize (tuple[int]): figure size. For the relationship between ``dpi`` and ``figsize``, please refer to `this post <https://stackoverflow.com/questions/47633546/relationship-between-dpi-and-figure-size>`_. **kwargs: all other arguments to ``ax.plot()``. Returns: Image: an output image rendered for the tensor """ assert len(data.shape) in (1, 2), "Must be rank-1 or rank-2!" if not isinstance(data, np.ndarray): array = data.cpu().numpy() else: array = data if len(array.shape) == 1: array = np.expand_dims(array, 0) fig, ax = plt.subplots(figsize=figsize) M, N = array.shape x = range(N) if x_range is not None: delta = (x_range[1] - x_range[0]) / float(N - 1) x = delta * x + x_range[0] for i in range(M): ax.plot(x, array[i], **kwargs) if legends is not None: ax.legend(legends, loc="best", **legend_kwargs) if y_range: ax.set_ylim(y_range) if x_label: ax.set_xlabel(x_label) if y_label: ax.set_ylabel(y_label) return _convert_to_image(name, fig, dpi, img_height, img_width)
[docs]@_rendering_wrapper def render_bar(name, data, width=0.8, y_range=None, x_ticks=None, x_label=None, y_label=None, legends=None, legend_kwargs={}, annotate_format="%.2f", img_height=None, img_width=None, dpi=300, figsize=(2, 2), **kwargs): """Render bar plots. Args: name (str): rendering identifier data (Tensor|np.ndarray): a rank-1 or rank-2 tensor/np.array. Each value is the height of a bar. If rank-2, each row represents an array of bars. Bars of multiple rows will stack on each other. width (float): bar width y_range (tuple[float]): a tuple of ``(min_y, max_y)`` for showing on the figure. If None, then it will be decided according to the ``y`` values. x_ticks (list[float]): x ticks shown along x axis x_label (str): shown besides x-axis y_label (str): shown besides y-axis legends (list[str]): label for each curve. No legends are shown if None. legend_kwargs (dict): optional legend kwargs annotate_format (str): The format of the annotations on the bars to show the actual value represented by each bar. This should either use the string format method, e.g. "%.2f", or be a ``matplotlib.ticker.Formatter``. img_height (int): height of the output image img_width (int): width of the output image dpi (int): resolution of each rendered image figsize (tuple[int]): figure size. For the relationship between ``dpi`` and ``figsize``, please refer to `this post <https://stackoverflow.com/questions/47633546/relationship-between-dpi-and-figure-size>`_. **kwargs: all other arguments to ``ax.bar()``. Returns: Image: an output image rendered for the tensor """ assert len(data.shape) in (1, 2), "Must be rank-1 or rank-2!" if not isinstance(data, np.ndarray): array = data.cpu().numpy() else: array = data if len(array.shape) == 1: array = np.expand_dims(array, 0) fig, ax = plt.subplots(figsize=figsize) M, N = array.shape x = range(N) for i in range(M): if legends: p = ax.bar(x, array[i], width, label=legends[i], **kwargs) else: p = ax.bar(x, array[i], width, **kwargs) ax.bar_label(p, label_type="center", fmt=annotate_format) ax.axhline(0, color='grey', linewidth=1) if legends: ax.legend(legends, loc="best", **legend_kwargs) if x_ticks is not None: ax.set_xticks(x_ticks) if y_range: ax.set_ylim(y_range) if x_label: ax.set_xlabel(x_label) if y_label: ax.set_ylabel(y_label) return _convert_to_image(name, fig, dpi, img_height, img_width)
[docs]@_rendering_wrapper def render_text(name: str, data: str, font_size: int = 10, fig_width_per_char: float = 0.1, fig_height: float = 0.4, img_height: int = None, img_width: int = None, dpi=200, **kwargs): """Render a text string. Args: name: name of the text data: the string to be rendered font_size: text font size fig_width_per_char: the width of each character measured by ``figsize`` of ``plt.subplots()``. fig_height: the height of the text label measured by ``figsize`` of ``plt.subplots()``. img_height (int): height of the output image img_width (int): width of the output image **kwargs: extra arguments forwarded to ``ax.text``. """ fig, ax = plt.subplots( figsize=(len(data) * fig_width_per_char, fig_height)) kwargs['fontsize'] = font_size ax.text(0, 0, data, **kwargs) ax.axis('off') return _convert_to_image(name, fig, dpi, img_height, img_width)
[docs]def render_action(name, action, action_spec, **kwargs): """An action renderer that plots agent's action at one time step in a bar plot. Args: name (str): rendering identifier action (nested Tensor): a nested tensor where each element is a rank-1 (discrete) or rank-2 (continuous) tensor of batch size 1. action_spec (nested TensorSpec): a nested tensor spec with the same structure with ``action``. **kwargs: all other arguments will be directed to ``render_bar()``. Returns: nested Image: a structure same with ``action`` """ def _render_action(path, act, spec): y_range = None if isinstance(spec, alf.tensor_specs.BoundedTensorSpec): bound = (np.min(spec.minimum), np.max(spec.maximum)) if all(map(np.isfinite, bound)): y_range = bound if spec.is_discrete: fmt = "%d" else: fmt = "%.2f" x_ticks = range(act.shape[-1]) name_ = name if path == '' else name + '/' + path return render_bar( name_, act, y_range=y_range, annotate_format=fmt, x_ticks=x_ticks, **kwargs) return nest.py_map_structure_with_path(_render_action, action, action_spec)
[docs]def render_action_distribution(name, act_dist, action_spec, n_samples=500, n_bins=20, **kwargs): """An action distribution renderer that plots agent's action distribution at one time step in a curve plot. Assuming action dims are independent, each action dim's 1D distribution corresponds to a separate curve in the plot. Args: name (str): rendering identifier act_dist (Distribution): a nested tensor where each element is a action distribution of batch size 1. action_spec (nested TensorSpec): a nested tensor spec with the same structure with ``act_dist`` n_samples (int): number of samples for approximation n_bins (int): how many histogram bins used for approximation **kwargs: all other arguments will be directed to ``render_curve()`` """ def _approximate_probs(dist, x_range): """Given a 1D continuous distribution, sample a bunch of points to form a histogram to approximate the distribution curve. The values of the histogram are densities (integral equal to 1 over the bin range). Args: dist (Distribution): action distribution whose param is rank-2 x_range (tuple[float]): a tuple of ``(min_x, max_x)`` for the domain of the distribution. Returns: np.array: a 2D matrix where each row is a prob hist for a dim """ mode = dist_utils.get_mode(dist) assert len( mode.shape) == 2, "Currently only support rank-2 distributions!" dim = mode.shape[-1] points = dist.sample(sample_shape=(n_samples, )).cpu().numpy() points = np.reshape(points, (-1, dim)) probs = [] for d in range(dim): hist, _ = np.histogram( points[:, d], bins=n_bins, density=True, range=x_range) probs.append(hist) return np.stack(probs) def _render_act_dist(path, dist, spec): if spec.is_discrete: assert isinstance(dist, td.categorical.Categorical) probs = dist.probs.reshape(-1).cpu().numpy() x_range, legends = None, None else: x_range = (np.min(spec.minimum), np.max(spec.maximum)) probs = _approximate_probs(dist, x_range) legends = ["d%s" % i for i in range(probs.shape[0])] name_ = name if path == '' else name + '/' + path return render_curve( name=name_, data=probs, legends=legends, x_range=x_range) return nest.py_map_structure_with_path(_render_act_dist, act_dist, action_spec)