Source code for alf.utils.visualizer

# Copyright (c) 2021 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various functions related to visualizations of networks etc."""

import torch
import alf.nest as nest


def critic_network_visualizer(net,
                              observation,
                              action_upper_left,
                              action_upper_right,
                              action_lower_left,
                              H=20,
                              W=20,
                              batch_size=None):
    """Generate a batched network response image over a region of actions.

    The network is probed on an ``H x W`` grid of actions that bilinearly
    interpolates the rectangular range (referred to as the probing region)
    specified by ``action_upper_left``, ``action_upper_right`` and
    ``action_lower_left`` as shown below::

        ``action_upper_left`` -----> ``action_upper_right``
                |                           |
                |                           |
                v                           |
        ``action_lower_left`` ------ ``action_lower_right``

    where ``action_lower_right`` is computed from the three provided points
    as the following because of the rectangular assumption:

    ``action_lower_right = (action_upper_right + action_lower_left - action_upper_left)``

    Every observation in the batch is paired with every grid action: row
    ``b * H * W + i`` of the network input holds observation ``b`` and the
    ``i``-th grid action (row-major over ``[H, W]``).

    Example usage:

    .. code-block:: python

        # assume a case where the dimensionality of action is 4

        # the action for the upper-left point of the probing region
        action_upper_left = torch.Tensor([1, -1, 0, 0])
        # the action for the upper-right point of the probing region
        action_upper_right = torch.Tensor([1, 1, 0, 0])
        # the action for the lower-left point of the probing region
        action_lower_left = torch.Tensor([-1, -1, 0, 0])

        # define a network function
        def net_func(net_input):
            critics, _ = self._critic_networks(
                net_input)  # [B, replicas * reward_dim]
            critics = critics.reshape(  # [B, replicas, reward_dim]
                -1, self._num_critic_replicas, *self._reward_spec.shape)
            critics = critics.min(dim=1)[0]
            return critics

        img = critic_network_visualizer(net_func, inputs.observation,
                                        action_upper_left,
                                        action_upper_right,
                                        action_lower_left, 20, 20)

        # visualize the first response image in the batch
        data = img[0, ...].squeeze(0)
        data = data.cpu().numpy()
        import alf.summary.render as render
        val_img = render.render_heatmap(name="val_img", data=data)

    Args:
        net (Callable): a callable that is called as
            ``net((observation, actions))``, where ``observation`` has been
            expanded to ``[B * H * W, ...]`` and ``actions`` has the shape
            of ``[B * H * W, action_dim]``. Its output must be a tensor
            that can be reshaped to ``[B, H, W, K]``.
        observation (nested Tensor): [B, ...]
        action_upper_left (Tensor): a tensor representing the upper-left
            point of the probing region, with the shape of [action_dim].
        action_upper_right (Tensor): a tensor representing the upper-right
            point of the probing region, with the shape of [action_dim].
        action_lower_left (Tensor): a tensor representing the lower-left
            point of the probing region, with the shape of [action_dim].
            Corner tensors that are not floating point (e.g. created with
            ``torch.tensor([1, -1, 0, 0])``) are converted to the default
            floating point dtype before interpolation.
        H (int): number of samples to be used for creating visualization
            along the direction of ``action_lower_left - action_upper_left``.
        W (int): number of samples to be used for creating visualization
            along the direction of ``action_upper_right - action_upper_left``.
            The total number of samples is H * W.
        batch_size (int): the batch size of the input ``observation``. If
            None, will be inferred from the input ``observation`` with
            ``nest.get_nest_size(observation, dim=0)``.

    Returns:
        Tensor: the network response image of the shape [B, K, H, W], where
        K denotes the dimensionality of the network output for the
        non-batch dimension.
    """
    if batch_size is None:
        batch_size = nest.get_nest_size(observation, dim=0)

    # total number of anchors for probing the network
    num_anchors = H * W

    # expand observation from [B, ...] to [B * num_anchors, ...]; each
    # observation is repeated ``num_anchors`` times consecutively
    ext_obs = nest.map_structure(
        lambda obs: torch.repeat_interleave(obs, num_anchors, dim=0),
        observation)

    # Bilinear interpolation expects floating point input. Promote integer or
    # bool corners up front so that the grid is not rejected or quantized;
    # floating point (and complex) inputs are passed through untouched.
    float_dtype = torch.get_default_dtype()
    action_upper_left, action_upper_right, action_lower_left = (
        corner if (corner.is_floating_point() or corner.is_complex()) else
        corner.to(float_dtype)
        for corner in (action_upper_left, action_upper_right,
                       action_lower_left))

    # the fourth corner follows from the rectangular assumption
    action_lower_right = (
        action_upper_right + action_lower_left - action_upper_left)

    # create the mini-image: [action_dim, 2, 2]
    # [action_upper_left, action_upper_right]
    # [action_lower_left, action_lower_right]
    left_column = torch.stack((action_upper_left, action_lower_left), dim=1)
    right_column = torch.stack((action_upper_right, action_lower_right),
                               dim=1)
    mini_image = torch.stack((left_column, right_column), dim=2)

    # [action_dim, 2, 2] -> [1, action_dim, 2, 2] -> [1, action_dim, H, W]
    # ``align_corners=True`` makes the four grid corners coincide exactly
    # with the four corner actions.
    mini_image = torch.nn.functional.interpolate(
        mini_image.unsqueeze(0),
        size=(H, W),
        mode='bilinear',
        align_corners=True)

    # [action_dim, H, W]
    mini_image = mini_image.squeeze(0)

    # [action_dim, H, W] -> [action_dim, H * W] -> [H * W, action_dim]
    anchors = mini_image.reshape(-1, num_anchors).transpose(0, 1)

    # expand action from [H * W, action_dim] to [B * H * W, action_dim];
    # the whole grid is tiled ``batch_size`` times, which lines up with the
    # consecutive repetition used for the observations above
    other_dims = [1] * (anchors.dim() - 1)
    ext_anchors = anchors.repeat(batch_size, *other_dims)

    network_inputs = (ext_obs, ext_anchors)
    out = net(network_inputs)

    # q_values: [B * H * W, K] -> [B, H, W, K], where K denotes the
    # dimensionality of the network output (non-batch dim)
    value_image = out.reshape(batch_size, H, W, -1)
    # [B, H, W, K] -> [B, K, H, W]
    value_image = value_image.permute(0, 3, 1, 2)

    return value_image