# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Summary related functions."""
import functools
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from typing import Callable
try:
# If tensorflow has been installed, pytorch might use tensorflow's
# tensorboard. In this case, gfile needs to be redirected if embedding
# projector is to be used.
# https://github.com/pytorch/pytorch/issues/30966#issuecomment-582747929
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
except:
pass
_summary_enabled = False
_summarize_output = False
_default_writer: SummaryWriter = None
_global_counter = np.array(0, dtype=np.int64)
_scope_stack = ['']
_record_if_stack = [
lambda: True,
]
_summary_writer_stack = [None]
# The default number of bins for histogram
_default_bins = 30
[docs]class scope(object):
"""A context manager for prefixing summary names.
Example:
```
with alf.summary.scope("root"):
alf.summary.scalar("val", 1) # tag is "root/val"
with alf.summary.scope("train"):
alf.summary.scalar("val", 1) # tag is "root/train/val"
```
"""
def __init__(self, name: str):
"""Create the context manager.
Args:
name (str): name of the scope
"""
name.strip('/')
self._name = name
@property
def name(self):
"""Get the name of the scope."""
return self._name
def __enter__(self):
scope_name = _scope_stack[-1] + self._name + '/'
_scope_stack.append(scope_name)
return scope_name
def __exit__(self, type, value, traceback):
_scope_stack.pop()
_SUMMARY_DATA_BUFFER = {}
def _summary_wrapper(summary_func):
"""Summary wrapper
Wrapper summary function to reduce cost for data computation
"""
@functools.wraps(summary_func)
def wrapper(name,
data,
average_over_summary_interval=False,
step=None,
**kwargs):
"""
Args:
average_over_summary_interval: if True, the average value of data during a
summary interval will be written to summary. If data is None,
it will be ignored for calculating the average. Note that providing
a "None" value for data is different from not calling the summary
function at all. A "None" value for data will cause the summary
to be generated if ``should_record_summaries()`` returns True
at the moment.
"""
if average_over_summary_interval:
if isinstance(data, torch.Tensor):
data = data.detach()
if name.startswith('/'):
name = name[1:]
else:
name = _scope_stack[-1] + name
if data is not None:
if name in _SUMMARY_DATA_BUFFER:
data_sum, counter = _SUMMARY_DATA_BUFFER[name]
_SUMMARY_DATA_BUFFER[name] = data_sum + data, counter + 1
else:
_SUMMARY_DATA_BUFFER[name] = data, 1
if should_record_summaries() and name in _SUMMARY_DATA_BUFFER:
data_sum, counter = _SUMMARY_DATA_BUFFER[name]
del _SUMMARY_DATA_BUFFER[name]
data = data_sum / counter
if step is None:
step = _global_counter
summary_func(name, data, step, **kwargs)
else:
if should_record_summaries():
if isinstance(data, torch.Tensor):
data = data.detach()
if step is None:
step = _global_counter
if name.startswith('/'):
name = name[1:]
else:
name = _scope_stack[-1] + name
summary_func(name, data, step, **kwargs)
return wrapper
[docs]def scope_name():
"""Get the full name of the current summary scope."""
return _scope_stack[-1]
[docs]@_summary_wrapper
def images(name, data, step=None, dataformat='NCHW', walltime=None):
"""Add image data to summary.
Args:
name (str): Data identifier
data (Tensor | numpy.array): image data
step (int): Global step value to record. None for using ``get_global_counter()``
dataformat (str): one of ('NCHW', 'NHWC', 'CHW', 'HWC', 'HW', 'WH')
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
"""
_summary_writer_stack[-1].add_images(
name, data, step, walltime=walltime, dataformats=dataformat)
[docs]@_summary_wrapper
def text(name, data, step=None, walltime=None):
"""Add text data to summary.
Note that the actual tag will be `name + "/text_summary"` because torch
adds "/text_summary to tag. See
https://github.com/pytorch/pytorch/blob/877ab3afe33eeaa797296d2794317b59e5ac90f4/torch/utils/tensorboard/summary.py#L477
Args:
name (str): Data identifier
data (str): String to save
step (int): Global step value to record. None for using get_global_counter()
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
"""
_summary_writer_stack[-1].add_text(name, data, step, walltime=walltime)
[docs]@_summary_wrapper
def scalar(name, data, step=None, walltime=None):
"""Addd scalar data to summary.
Note that data will be changed to float value (i.e. possible loss of
precision). See
https://github.com/pytorch/pytorch/blob/877ab3afe33eeaa797296d2794317b59e5ac90f4/torch/utils/tensorboard/summary.py#L175
Args:
name (str): Data identifier
data (float): Value to save
step (int): Global step value to record. None for using get_global_counter()
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
"""
_summary_writer_stack[-1].add_scalar(name, data, step, walltime=walltime)
[docs]@_summary_wrapper
def histogram(name, data, step=None, bins=None, walltime=None, max_bins=None):
"""Add histogram to summary.
Args:
name (str): Data identifier
data (Tensor | numpy.array | str/blobname): Values to build histogram
step (int): Global step value to record. None for using get_global_counter()
bins (int|str): Number of buckets or one of {‘tensorflow’,’auto’, ‘fd’, …}.
This determines how the bins are made. You can find other options in:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
"""
if bins is None:
bins = _default_bins
_summary_writer_stack[-1].add_histogram(
name, data, step, bins=bins, walltime=walltime, max_bins=max_bins)
[docs]@_summary_wrapper
def embedding(name, data, step=None, class_labels=None, label_imgs=None):
"""Add embeddings to summary. The potentially high-dimensional embeddings
will be projected down to either 2D or 3D for visualization, with several
projection techniques to choose from in Tensorboard.
The visualized embeddings can be seen in the "PROJECTOR" page of Tensorboard.
Note: if this function is called multiple times, on the page there will be
multiple visualizations, each for every call.
Args:
name (str): data identifier
data (Tensor | numpy.array): a matrix of shape ``[N, D]``, where ``D``
is the dimensionality of the embedding.
step (int): global step value to record. None for using
``get_global_counter()``.
class_labels (list[str]): an optional list of class labels of length
``N`` can be provided, where each label corresponds to an embedding.
label_imgs (Tensor): an optional tensor of shape ``[N, C, H, W]``. Each
label img corresponds to an embedding. Use this if you want to
associate each embedding with an image for visualization.
"""
_summary_writer_stack[-1].add_embedding(
tag=name,
mat=data,
metadata=class_labels,
label_img=label_imgs,
global_step=step)
[docs]def should_record_summaries():
"""Whether summary should be recorded.
Returns:
bool: False means that all calls to scalar(), text(), histogram() etc
are not recorded.
"""
return (_summary_writer_stack[-1] and is_summary_enabled()
and _record_if_stack[-1]())
[docs]def get_global_counter():
"""Get the global counter
Returns:
the global int64 Tensor counter
"""
return _global_counter
[docs]def reset_global_counter():
"""Reset the global counter to zero."""
_global_counter.fill(0)
[docs]def increment_global_counter():
global _global_counter
_global_counter += 1
[docs]def set_global_counter(counter):
global _global_counter
_global_counter.fill(counter)
[docs]class record_if(object):
"""Context manager to set summary recording on or off according to `cond`."""
def __init__(self, cond: Callable):
"""Create the context manager.
Args:
cond (Callable): a function which returns whether summary should be
recorded.
"""
self._cond = cond
def __enter__(self):
_record_if_stack.append(self._cond)
def __exit__(self, type, value, traceback):
_record_if_stack.pop()
[docs]def create_summary_writer(summary_dir, flush_secs=10, max_queue=10):
"""Ceates a SummaryWriter that will write out events to the event file.
Args:
summary_dir (str) – Save directory location.
flush_secs (int) – How often, in seconds, to flush the pending events
and summaries to disk. Default is every 10 seconds.
max_queue (int) – Size of the queue for pending events and summaries
before one of the ‘add’ calls forces a flush to disk.
Default is ten items.
Returns:
SummaryWriter
"""
return SummaryWriter(
log_dir=summary_dir, flush_secs=flush_secs, max_queue=max_queue)
[docs]def set_default_writer(writer):
"""Set the default summary writer."""
_summary_writer_stack[0] = writer
[docs]def enable_summary(flag=True):
"""Enable summary.
Args:
flag (bool): True to enable, False to disable
"""
global _summary_enabled
_summary_enabled = flag
[docs]def disable_summary():
"""Disable summary."""
global _summary_enabled
_summary_enabled = False
[docs]def is_summary_enabled():
"""Return whether summary is enabled."""
return _summary_enabled
[docs]def should_summarize_output(flag=None):
"""Get or set summarize output flag.
Args:
flag (bool or None): when provided, sets the flag, otherwise, return
the stored _summarize_output flag.
Returns:
bool for getter or None for setter.
"""
global _summarize_output
if flag is None:
return _summarize_output and should_record_summaries()
else:
_summarize_output = bool(flag)
[docs]class push_summary_writer(object):
def __init__(self, writer):
self._writer = writer
def __enter__(self):
_summary_writer_stack.append(self._writer)
def __exit__(self, type, value, traceback):
_summary_writer_stack.pop()
[docs]def enter_summary_scope(method):
"""A decorator to run the wrapped method in a new summary scope.
The class the method belongs to must have attribute '_name' and it
will be used as the name of the summary scope.
Instead of using ``with alf.summary.scope(self._name):`` inside a class method,
we can use ``@alf.summary.enter_summary_scope`` to decorate the method to
have the benefit of cleaner code.
"""
@functools.wraps(method)
def wrapped(self, *args, **kwargs):
# The first argument to the method is going to be ``self``, i.e. the
# instance that the method belongs to.
assert hasattr(self,
'_name'), "self is expected to have attribute '_name'"
scope_name = _scope_stack[-1] + self._name + '/'
_scope_stack.append(scope_name)
ret = method(self, *args, **kwargs)
_scope_stack.pop()
return ret
return wrapped
[docs]class EnsureSummary(object):
"""Ensure summaries are generated in an infrequent code block.
Sometime, a code block runs infrequently or with different frequencey compared
to the summary_interval. This can lead to the problem that the summaries in
this code block are not generated or generated rarely. This class is a helper
to solve this problem.
.. code-block:: python
# initialization. For example, in __init__
self.ensure_summary = EnsureSummary()
# Add the following line at somewhere where it can be reached at very global step
self.ensure_summary.tick()
# Run the infrequent code block in the ensure_summary context:
with self.ensure_summary:
# the infrequent code block
"""
def __init__(self):
self._need_to_summarize = False
[docs] def tick(self):
if should_record_summaries():
self._need_to_summarize = True
def __enter__(self):
_record_if_stack.append(lambda: self._need_to_summarize)
def __exit__(self, type, value, traceback):
_record_if_stack.pop()
self._need_to_summarize = False