alf.nest#

alf.nest.nest#

Functions for handling nests.

assert_same_length(seq1, seq2)[source]#
assert_same_structure(nest1, nest2)[source]#

(C++) Asserts that two structures are nested in the same way.

assert_same_structure_up_to(shallow_nest, deep_nest)[source]#

(C++) Asserts that deep_nest has the same structure as shallow_nest up to the depths of shallow_nest. Every sub-nest of deep_nest beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf.

Examples:

assert_same_structure_up_to(([2], None), ([1], [1, 2, 3]))
# success

assert_same_structure_up_to(([2], []), ([1], [1, 2, 3]))
# failure
Parameters
  • shallow_nest (nest) – a shallow nested structure.

  • deep_nest (nest) – a nested structure that may be deeper than shallow_nest.
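
The two examples above can be reproduced with a small pure-Python sketch. This is an illustration of the rule only, not ALF's C++ implementation; the names is_nested_sketch and check_up_to are hypothetical, and the sketch assumes a nest is a list, tuple, or dict.

```python
def is_nested_sketch(value):
    # Assumption for this sketch: a nest is a list, a tuple, or a dict.
    return isinstance(value, (list, tuple, dict))


def check_up_to(shallow, deep, path="root"):
    """Raise AssertionError if deep does not match shallow down to shallow's depth."""
    if not is_nested_sketch(shallow):
        return  # a leaf in the shallow nest accepts any sub-nest in the deep nest
    assert type(shallow) is type(deep), f"type mismatch at {path}"
    assert len(shallow) == len(deep), f"length mismatch at {path}"
    if isinstance(shallow, dict):
        assert set(shallow) == set(deep), f"key mismatch at {path}"
        for key in shallow:
            check_up_to(shallow[key], deep[key], f"{path}.{key}")
    else:
        for i, (s, d) in enumerate(zip(shallow, deep)):
            check_up_to(s, d, f"{path}.{i}")


# None is a leaf, so the deeper list [1, 2, 3] is accepted.
check_up_to(([2], None), ([1], [1, 2, 3]))

# [] is an empty sub-nest, so its length must match the deep counterpart.
try:
    check_up_to(([2], []), ([1], [1, 2, 3]))
    failed = False
except AssertionError:
    failed = True
```

In this sketch the second call fails because an empty list is still a nest, so it is compared structurally rather than treated as a leaf.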

assert_same_type(value1, value2)[source]#
batch_nested_tensor(nested_tensor)[source]#

Unsqueeze a zero (batch) dimension for each entry in nested_tensor.

extract_any_leaf_from_nest(nest)[source]#

Extract an arbitrary leaf from a nest. Should be faster than doing flatten(nest)[0] because this function has short circuit.

Parameters

nest (nest) – a nested structure

Returns

A Tensor if there exists a leaf; otherwise None.
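
The short-circuit idea can be sketched in plain Python as follows. The helper any_leaf is hypothetical and treats any non-container value as a leaf, whereas the function documented above returns a Tensor.

```python
def any_leaf(nest):
    """Return the first leaf found in depth-first order, or None if there is none."""
    if isinstance(nest, dict):
        children = nest.values()
    elif isinstance(nest, (list, tuple)):
        children = nest
    else:
        return nest  # not a container, so it is a leaf
    for child in children:
        leaf = any_leaf(child)
        if leaf is not None:
            return leaf  # short circuit: remaining children are never visited
    return None


example = dict(a=[], b=(dict(c=3.0), [4.0, 5.0]))
first = any_leaf(example)           # 3.0, found without visiting [4.0, 5.0]
empty = any_leaf(dict(a=[], b=()))  # None, because there is no leaf at all
```

Unlike flatten(nest)[0], this never builds the full flat list, which is where the speed-up would come from.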

extract_fields_from_nest(nest)[source]#

Extract fields and the corresponding values from a nest if it’s either a namedtuple or dict.

Parameters

nest (nest) – a nested structure

Returns

an iterator that generates (field, value) pairs. The fields are sorted before being returned.

Return type

Iterable

Raises

AssertionError – if the nest is neither namedtuple nor dict.

fast_map_structure(func, *structure)[source]#

map_structure using pack_sequence_as().

fast_map_structure_flatten(func, structure, *flat_structure)[source]#

Applies func to entries of flat_structure and returns a packed structure according to structure.

find_field(nest, name, ignore_empty=True)[source]#

Find fields with given name.

Examples

nest = dict(a=1, b=dict(a=dict(a=2, b=3), b=2))
find_field(nest, 'a')
# you would get [1, {"a": 2, "b": 3}]
Parameters
  • nest (nest) – a nest structure

  • name (str) – name of the field

  • ignore_empty (bool) – ignore the field if it is None or empty.

Returns

list
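
The example above suggests that the search does not descend into a value once its field name matches. A minimal sketch under that assumption (find_field_sketch is a hypothetical stand-in, not ALF's code):

```python
def find_field_sketch(nest, name, ignore_empty=True):
    """Collect values stored under name, without descending into a match."""
    if isinstance(nest, dict):
        items = list(nest.items())
    elif isinstance(nest, tuple) and hasattr(nest, "_fields"):  # namedtuple
        items = list(zip(nest._fields, nest))
    elif isinstance(nest, (list, tuple)):
        items = [(None, v) for v in nest]  # unnamed children have no field name
    else:
        return []
    found = []
    for field, value in items:
        if field == name:
            is_empty = value is None or (
                isinstance(value, (list, tuple, dict)) and len(value) == 0)
            if not (ignore_empty and is_empty):
                found.append(value)
        else:
            found.extend(find_field_sketch(value, name, ignore_empty))
    return found


nest = dict(a=1, b=dict(a=dict(a=2, b=3), b=2))
result = find_field_sketch(nest, "a")
```

Here result is [1, {"a": 2, "b": 3}]: the inner a=2 is not reported because it lives inside a value that already matched.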

flatten(nest)[source]#

(C++) Returns a flat list from a given nested structure.

flatten_up_to(shallow_nest, nest)[source]#

(C++) Flatten nest up to the depths of shallow_nest. Every sub-nest of nest beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf that stops flattening downwards.
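
A rough pure-Python illustration of flattening that stops at the leaves of shallow_nest. The helper name is hypothetical, and visiting dict fields in sorted key order is an assumption of this sketch.

```python
def flatten_up_to_sketch(shallow, nest):
    """Flatten nest, stopping wherever shallow has a leaf."""
    if isinstance(shallow, dict):
        out = []
        for key in sorted(shallow):  # assumption: dict fields in sorted order
            out.extend(flatten_up_to_sketch(shallow[key], nest[key]))
        return out
    if isinstance(shallow, (list, tuple)):
        out = []
        for s, n in zip(shallow, nest):
            out.extend(flatten_up_to_sketch(s, n))
        return out
    return [nest]  # shallow leaf: keep the whole sub-nest as one item


deep = dict(obs=[1, 2], action=(3, dict(x=4)))
shallow = dict(obs=None, action=(None, None))
flat = flatten_up_to_sketch(shallow, deep)
```

With these inputs flat is [3, {"x": 4}, [1, 2]]: the list under obs and the dict under action stay intact because shallow has leaves at those positions.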

get_field(nested, field)[source]#

Get the field from nested.

field is a string separated by “.”. get_field(nested, "a.b") is equivalent to nested.a.b if nested is constructed using namedtuple or nested['a']['b'] if nested is constructed using dict. If nested is constructed using list or unnamed tuple, get_field(nested, "1.2") is equivalent to nested[1][2].

Parameters
  • nested (nest) – a nested structure

  • field (str) – indicates the path to the field, with ‘.’ separating the field names at different levels. None or ‘’ means the whole nest.

Returns

value of the field corresponding to field

Return type

nest

Raises

LookupError – if field cannot be found.
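
The dotted-path lookup described above can be sketched as below. This is an illustrative re-implementation under the stated rules; get_field_sketch and the Pair namedtuple are hypothetical names.

```python
import collections


def get_field_sketch(nested, field):
    """Follow a dotted path through namedtuples, dicts, lists and tuples."""
    if not field:  # None or '' means the whole nest
        return nested
    node = nested
    for part in field.split("."):
        try:
            if isinstance(node, dict):
                node = node[part]
            elif hasattr(node, "_fields"):  # namedtuple
                node = getattr(node, part)
            else:  # list or unnamed tuple: the path component is an index
                node = node[int(part)]
        except (KeyError, AttributeError, IndexError, ValueError, TypeError) as e:
            raise LookupError(f"cannot find '{part}' in path '{field}'") from e
    return node


Pair = collections.namedtuple("Pair", "a b")
nested = Pair(a=dict(b=[10, 20, 30]), b=None)
value = get_field_sketch(nested, "a.b.2")  # nested.a['b'][2]
```

Mixed containers are handled one path component at a time, so a single path can cross a namedtuple, a dict, and a list.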

get_nest_batch_size(nest)[source]#

Get the batch size (dim=0) of a nest, assuming batch-major.

Parameters

nest (nest) – a nested structure

Returns

batch size

Return type

int

get_nest_shape(nest)[source]#

Get the shape of a nest leaf. It assumes that all nodes of the nest share the same shape. For efficiency we don’t do a check here.

Parameters

nest (nest) – a nested structure

Returns

Return type

torch.Size

get_nest_size(nest, dim)[source]#

Get the size of dimension dim from a nest. It assumes that all nodes of the nest share the same size.

Parameters
  • nest (nest) – a nested structure

  • dim (int) – the dimension to get the size for

Returns

size of dim

Return type

int

is_namedtuple(value)[source]#

Whether the value is a namedtuple instance.

Parameters

value (Object) –

Returns

True if the value is a namedtuple instance.

is_nested(value)[source]#

Returns true if the input is one of: list, unnamedtuple, dict, or namedtuple. Note that this definition is different from tf’s is_nested where all types that are collections.abc.Sequence are defined to be nested.

is_unnamedtuple(value)[source]#

Whether the value is an unnamedtuple instance.

map_structure(func, *nests)[source]#

(C++) Applies func to each entry in structure and returns a new structure.

map_structure_up_to(shallow_nest, func, *nests)[source]#

(C++) Applies a function to nests up to the depths of shallow_nest. Every sub-nest of each of nests beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf and input to func.

Examples (taken from tensorflow.nest.map_structure_up_to):

shallow_nest = [None, None]
inp_val = [[1], 2]
out = map_structure_up_to(shallow_nest, lambda x: 2 * x, inp_val)
# Output is: [[1, 1], 4]

ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)
# Output is: ab_tuple(a=6, b=15)

data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ['evens', ['odds', 'primes']]
out = map_structure_up_to(
    name_list,
    lambda name, sec: "first_{}_{}".format(len(sec), name),
    name_list, data_list)
# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
Parameters
  • shallow_nest (nest) – a shallow nested structure.

  • func (Callable) – callable which will be applied to nests.

  • *nests (nest) – a variable length of nested structures.

Returns

a nested structure of results that has the same depth as shallow_nest.

Return type

nest

map_structure_without_check(func, *nests)[source]#

(C++) Applies func to each entry in structure and returns a new structure. This function doesn’t do any check for efficiency.

nest_top_level(nested)[source]#

Given a nest, return its top-level structure, where the values are set to None.

Parameters

nested (Any) – a nested structure

pack_sequence_as(nest, flat_seq)[source]#

(C++) Returns a given flattened sequence packed into a given structure.
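
Flattening and packing are inverse operations, which is what makes a map built on pack_sequence_as() possible. A minimal pure-Python sketch of the round trip follows; the helper names are hypothetical, and the sorted dict-key order is an assumption of the sketch.

```python
def flatten_sketch(nest):
    """Return the leaves of nest as a flat list."""
    if isinstance(nest, dict):
        return [x for k in sorted(nest) for x in flatten_sketch(nest[k])]
    if isinstance(nest, (list, tuple)):
        return [x for child in nest for x in flatten_sketch(child)]
    return [nest]


def pack_sketch(nest, flat_seq):
    """Rebuild the structure of nest, consuming leaves from flat_seq in order."""
    it = iter(flat_seq)

    def build(node):
        if isinstance(node, dict):
            return {k: build(node[k]) for k in sorted(node)}
        if isinstance(node, tuple) and hasattr(node, "_fields"):  # namedtuple
            return type(node)(*(build(c) for c in node))
        if isinstance(node, (list, tuple)):
            return type(node)(build(c) for c in node)
        return next(it)  # leaf position: take the next flat value

    return build(nest)


structure = dict(b=[1, 2], a=(3,))
flat = flatten_sketch(structure)
doubled = pack_sketch(structure, [2 * x for x in flat])
```

Here flat is [3, 1, 2] and doubled is {"a": (6,), "b": [2, 4]}, i.e. a map over the structure done by flatten, transform, then pack.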

prune_nest_like(nest, slim_nest, value_to_match=None)[source]#

(C++) Prune a nested structure referring to another slim nest. Generally, for every corresponding node, we only keep the fields that are contained in slim_nest. In addition, if a field of slim_nest contains a value of value_to_match, then the corresponding field of nest will also be updated to this value.

Note

If a node is a list or unnamedtuple, then the corresponding nodes in nest and slim_nest are required to have equal lengths.

Examples

x = dict(a=1, b=2)
y = dict(a=TensorSpec(()))
z = prune_nest_like(x, y) # z is dict(a=1)

y2 = dict(a=TensorSpec(()), b=())
z2 = prune_nest_like(x, y2, value_to_match=()) # z2 is dict(a=1, b=())
Parameters
  • nest (nest) – a nested structure

  • slim_nest (nest) – a slim nested structure. It’s required that at every node, its fields are a subset of those of nest.

  • value_to_match (nest) – a value that indicates the paired field of slim_nest should be updated in nest. Can be set to the default value of a namedtuple.

Returns

the pruned nest that has the same set of fields as slim_nest.

Return type

nest

py_assert_same_structure(nest1, nest2)[source]#

Asserts that two structures are nested in the same way.

py_flatten(nest)[source]#

Returns a flat list from a given nested structure.

py_flatten_up_to(shallow_nest, nest)[source]#

Flatten nest up to the depths of shallow_nest. Every sub-nest of nest beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf that stops flattening downwards.

py_map_structure(func, *nests)[source]#

Applies func to each entry in structure and returns a new structure.

py_map_structure_up_to(shallow_nest, func, *nests)[source]#

Applies a function to nests up to the depths of shallow_nest. Every sub-nest of each of nests beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf and input to func.

Examples (taken from tensorflow.nest.map_structure_up_to):

shallow_nest = [None, None]
inp_val = [[1], 2]
out = map_structure_up_to(shallow_nest, lambda x: 2 * x, inp_val)
# Output is: [[1, 1], 4]

ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)
# Output is: ab_tuple(a=6, b=15)

data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ['evens', ['odds', 'primes']]
out = map_structure_up_to(
    name_list,
    lambda name, sec: "first_{}_{}".format(len(sec), name),
    name_list, data_list)
# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
Parameters
  • shallow_nest (nest) – a shallow nested structure.

  • func (Callable) – callable which will be applied to nests.

  • *nests (nest) – a variable length of nested structures.

Returns

a nested structure of results that has the same depth as shallow_nest.

Return type

nest

py_map_structure_with_path(func, *nests)[source]#

Applies func to each entry in structure and returns a new structure. This function gives func access to one additional parameter as its first argument: the symbolic string of the path to the element currently supplied. List elements will be indexed by the ordinal position of the element in the list.
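
The sketch below shows what a path-aware map might look like for a single nest. It is not ALF's implementation: map_with_path_sketch is a hypothetical name, it handles one nest instead of *nests, and the '.'-joined path format is an assumption.

```python
import collections


def map_with_path_sketch(func, nest, path=""):
    """Apply func(path, leaf) to every leaf of a single nest."""

    def child_path(key):
        # Assumption: path components are joined with '.'
        return f"{path}.{key}" if path else str(key)

    if isinstance(nest, dict):
        return {k: map_with_path_sketch(func, v, child_path(k))
                for k, v in nest.items()}
    if isinstance(nest, tuple) and hasattr(nest, "_fields"):  # namedtuple
        return type(nest)(*(map_with_path_sketch(func, v, child_path(k))
                            for k, v in zip(nest._fields, nest)))
    if isinstance(nest, (list, tuple)):
        # List and tuple elements are addressed by their ordinal position.
        return type(nest)(map_with_path_sketch(func, v, child_path(i))
                          for i, v in enumerate(nest))
    return func(path, nest)


Point = collections.namedtuple("Point", "x y")
nest = dict(a=[1, 2], b=Point(x=3, y=4))
labeled = map_with_path_sketch(lambda p, v: f"{p}={v}", nest)
```

In this sketch labeled["a"] is ["a.0=1", "a.1=2"] and labeled["b"] is Point(x="b.x=3", y="b.y=4").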

py_pack_sequence_as(nest, flat_seq)[source]#

Returns a given flattened sequence packed into a given structure.

py_prune_nest_like(nest, slim_nest, value_to_match=None)[source]#

Prune a nested structure referring to another slim nest. Generally, for every corresponding node, we only keep the fields that are contained in slim_nest. In addition, if a field of slim_nest contains a value of value_to_match, then the corresponding field of nest will also be updated to this value.

Note

If a node is a list or unnamedtuple, then the corresponding nodes in nest and slim_nest are required to have equal lengths.

Examples

x = dict(a=1, b=2)
y = dict(a=TensorSpec(()))
z = prune_nest_like(x, y) # z is dict(a=1)

y2 = dict(a=TensorSpec(()), b=())
z2 = prune_nest_like(x, y2, value_to_match=()) # z2 is dict(a=1, b=())
Parameters
  • nest (nest) – a nested structure

  • slim_nest (nest) – a slim nested structure. It’s required that at every node, its fields are a subset of those of nest.

  • value_to_match (nest) – a value that indicates the paired field of slim_nest should be updated in nest. Can be set to the default value of a namedtuple.

Returns

the pruned nest that has the same set of fields as slim_nest.

Return type

nest

set_field(nested, field, new_value)[source]#

Set the field in nested to new_value.

field is a string separated by “.”. set_field(nested, "a.b", v) is equivalent to nested._replace(a=nested.a._replace(b=v)) if nested is constructed using namedtuple.

Parameters
  • nested (nest) – a nested structure

  • field (str) – indicates the path to the field, with ‘.’ separating the field names at different levels

  • new_value (any) – the new value for the field

Returns

a nest that is the same as nested except that the field field is replaced by new_value

Return type

nest
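
The equivalence with chained _replace calls can be checked with a small sketch. set_field_sketch is a hypothetical stand-in; its support for dicts, lists, and unnamed tuples is an assumption beyond the namedtuple case described above.

```python
import collections


def set_field_sketch(nested, field, new_value):
    """Return a copy of nested with the node at dotted path field replaced."""
    if not field:
        return new_value
    head, _, rest = field.partition(".")
    if isinstance(nested, dict):  # assumption: dicts are copied, not mutated
        updated = dict(nested)
        updated[head] = set_field_sketch(nested[head], rest, new_value)
        return updated
    if hasattr(nested, "_fields"):  # namedtuple
        child = set_field_sketch(getattr(nested, head), rest, new_value)
        return nested._replace(**{head: child})
    items = list(nested)  # list or unnamed tuple, indexed by position
    index = int(head)
    items[index] = set_field_sketch(items[index], rest, new_value)
    return type(nested)(items)


Inner = collections.namedtuple("Inner", "b c")
Outer = collections.namedtuple("Outer", "a d")
nested = Outer(a=Inner(b=1, c=2), d=3)
updated = set_field_sketch(nested, "a.b", 10)
expected = nested._replace(a=nested.a._replace(b=10))
```

The original nested is left untouched; only the nodes along the path are rebuilt.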

sum_nest(nested)[source]#

Sum all elements in a nest.

Parameters

nested (Any) – a nested structure

transform_nest(nested, field, func)[source]#

Transform the node of a nested structure indicated by field using func.

This function can be used to update our namedtuple structure conveniently, comparing the following two methods:

info = info._replace(rl=info.rl._replace(sac=info.rl.sac * 0.5))

vs.

info = transform_nest(info, 'rl.sac', lambda x: x * 0.5)

The second method is usually shorter, more intuitive, and less error-prone when field is a long string.

Parameters
  • nested (nested Tensor) – the structure to which the transformation is applied.

  • field (str) – If a string, it’s the field to be transformed, multi-level path denoted by “A.B.C”. Levels can also be integers (e.g., “0.2”), in which case the nest is expected to be tuples or lists at those levels. If None, then the root object is transformed.

  • func (Callable) – transform func, the function will be called as func(nested) and should return a new nest.

Returns

transformed nest

transform_nests(nests, field, func)[source]#
Transform the node of each nest in nests indicated by field using func.

This function can be used to transform multiple nests and to perform transformations with inter-nest interactions.

res1, res2 = transform_nests([nest1, nest2], 'a.b',
                    lambda x: (x[0] * x[1], x[0] + x[1]))

where x[0] denotes the value from nest1 and x[1] is from nest2.

Parameters
  • nests ([nested Tensor]) – the structures to which the transformation is applied.

  • field (str) – If a string, it’s the field to be transformed, multi-level path denoted by “A.B.C”. If None, then the root object is transformed.

  • func (Callable) – transform func, the function will be called as func(nested) and should return a new nest.

Returns

list of transformed nests, with its length the same as the input nests

transpose(nested, shallow_nest=None, new_shallow_nest=None)[source]#

Given a nest A and its shallow nest a, assuming that each child of a has the same nest structure B, this function returns a new nest whose shallow nest b is a shallow nest of B, and each child of b has a shallow nest a.

An illustrative graph shows the transpose operation:

A = a-B = a-b-C (transpose->) b-a-C

where C is every (same) child of b (could be empty).

Note

ALF defines the “shallow nest” of a nest as the subtree that starts from the nest root and contains at least all the direct children of the nest. It can optionally contain more descendants of the nest.

For example,

x = [(0, 1), (2, 3), (4, 5)]
y = transpose(x, shallow_nest=[None, None, None])
# y will be ``([0, 2, 4], [1, 3, 5])``
y1 = transpose(x)
# y1 will be the same as y

x = NTuple(a=dict(x=3, y=1), b=[dict(x=5, y=10)])
shallow_nest = NTuple(a=None, b=[False])
y = transpose(x, shallow_nest)
# y will be ``dict(x=NTuple(a=3, b=[5]), y=NTuple(a=1, b=[10]))``

x = NTuple(a=dict(x=3, y=dict(n=1, m=2)),
           b=dict(x=5, y=dict(n=1, m=3)))
transposed_nest1 = nest.transpose(x)
self.assertEqual(transposed_nest1,
                 dict(x=NTuple(a=3, b=5), y=NTuple(a=dict(n=1, m=2),
                                                   b=dict(n=1, m=3))))
Parameters
  • nested (Any) – a nested structure

  • shallow_nest (Optional[Any]) – a nested structure indicating the first “axis” for the transpose. If None, then nest_top_level(nested) will be used.

  • new_shallow_nest (Optional[Any]) – a nested structure indicating the second “axis” for the transpose. Note that this shallow nest is w.r.t. each child B. If not provided, then nest_top_level(B) will be used.

Returns

a transposed nested structure

Return type

nested

unbatch_nested_tensor(nested_tensor)[source]#

Squeeze the first (batch) dimension of each entry in nested_tensor.

alf.nest.utils#

Some nest utils functions.

class NestCombiner(name, batch_dims=1)[source]#

Bases: abc.ABC, torch.nn.modules.module.Module

A base class for combining all elements in a nested structure.

Parameters
  • name (str) – name of the combiner

  • batch_dims (int) – number of batch dims (default 1). This argument is only necessary for combiners that are not batch-dim invariant (combined results depending on the definition of batch dims, e.g., outer product).

training: bool#
class NestConcat(nest_mask=None, dim=-1, name='NestConcat')[source]#

Bases: alf.nest.utils.NestCombiner

A combiner for selecting from the tensors in a nest and then concatenating them along a specified axis. If nest_mask is None, then all the tensors from the nest will be selected. It assumes that all the selected tensors have the same tensor spec. Can be used as a preprocessing combiner of a network.

Note that batch dimension is not considered for concat. This means that dim=0 means the first dimension after batch dimension.

Parameters
  • nest_mask (nest|None) – nest-structured mask indicating which of the tensors in the nest are selected, indicated by a value of True/False (1/0). Note that the structure of the mask should be the same as the nest of data to apply this operator on. If None, then all the tensors from the nest will be selected.

  • dim (int) – the dim along which the tensors are concatenated

  • name (str) –

make_parallel(n)[source]#

Create a NestConcat layer to handle parallel batch.

It is assumed that a parallel batch has shape [B, n, …] and both the batch dimension and replica dimension are not considered for concat.

Parameters

n (int) – the number of replicas.

Returns

a NestConcat layer to handle parallel batch.

training: bool#
class NestMultiply(activation=None, name='NestMultiply')[source]#

Bases: alf.nest.utils.NestCombiner

Element-wise multiply all tensors in a nest. It assumes that all tensors have the same shape. Can be used as a preprocessing combiner of a network.

Parameters
  • activation (Callable) – optional activation function applied after the multiplication.

  • name (str) –

make_parallel(n)[source]#
training: bool#
class NestOuterProduct(activation=None, batch_dims=1, padding=False, name='NestOuterProduct')[source]#

Bases: alf.nest.utils.NestCombiner

Perform outer-product operations across a nested structure. Can be used as a preprocessing combiner of a network.

Sometimes combining tensors using outer product might be more expressive than concatenating, e.g., when one tensor is one-hot. See the discussions in

"STOCHASTIC NEURAL NETWORKS FOR HIERARCHICAL REINFORCEMENT LEARNING",
Florensa, et al., ICLR 2017, https://arxiv.org/pdf/1704.03012.pdf.

In this implementation, we also support padding 1s to the tensors before doing the outer product, essentially combining outer product and concatenation together in one combiner.

Warning

Due to outer product, this combiner might result in a very long output vector. Make sure to do the calculation before using it.

Parameters
  • activation (Optional[Callable]) – optional activation function applied after the outer product.

  • batch_dims (int) – number of batch dims. Defaults to 1. If the total input dim is N, then the last N-batch_dims will be flattened for outer product.

  • padding (bool) – if True, each tensor will be padded by 1 before performing outer product. When this flag is enabled, essentially it has the effect of concatenation of all tensors in the output tensor.

  • name (str) – name of the combiner
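
To see why padding with 1 combines outer product and concatenation, consider the following plain-Python sketch on unbatched 1-D vectors. It is illustrative only: outer_product_flat is a hypothetical helper, and prepending the 1 (rather than appending it) is an assumption.

```python
def outer_product_flat(vectors, padding=False):
    """Flattened outer product of 1-D vectors given as Python lists."""
    result = [1.0]
    for vec in vectors:
        if padding:
            vec = [1.0] + list(vec)  # assumption: the pad value 1 is prepended
        # Multiply every entry accumulated so far with every entry of vec.
        result = [r * v for r in result for v in vec]
    return result


a = [2.0, 3.0]
b = [5.0]
plain = outer_product_flat([a, b])                 # [10.0, 15.0]
padded = outer_product_flat([a, b], padding=True)  # [1.0, 5.0, 2.0, 10.0, 3.0, 15.0]
```

The padded result contains the original entries of a (2, 3) and b (5) alongside the pairwise products (10, 15). Its length is the product of (d_i + 1) over the input sizes d_i, which is the growth the warning above refers to.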

make_parallel(n)[source]#
training: bool#
class NestSum(average=False, activation=None, name='NestSum')[source]#

Bases: alf.nest.utils.NestCombiner

Add all tensors in a nest together. It assumes that all tensors have the same tensor shape. Can be used as a preprocessing combiner of a network.

Parameters
  • average (bool) – If True, the tensors are averaged instead of summed.

  • activation (Callable) – activation function.

  • name (str) –

make_parallel(n)[source]#
training: bool#
convert_device(nests, device=None)[source]#
Convert the device of the tensors in nests to the specified device or to the default device.

Parameters
  • nests (nested Tensors) – Nested list/tuple/dict of Tensors.

  • device (None|str) – the target device, should either be cuda or cpu. If None, then the default device will be used as the target device.

Returns

Nested list/tuple/dict of Tensors after device conversion.

Return type

nests (nested Tensors)

Raises

  • NotImplementedError – if the target device is not one of None, cpu, or cuda when cuda is available.

  • AssertionError – if the target device is cuda but cuda is unavailable.

get_nested_field(nested, nest_fields)[source]#

Get nested fields from a nest.

Example

x = get_nested_field(nest, ('a.b', 'c'))
y = (get_field(nest, 'a.b'), get_field(nest, 'c'))
# y and x are the same

Parameters
  • nested (nest) – a nested structure

  • nest_fields (nested str) – nested strings. Each string indicates a path to retrieve the value from nest

Returns

a nest with same structure as nest_fields.

get_outer_rank(tensors, specs)[source]#

Compares tensors to specs to determine the number of batch dimensions.

For each tensor, it checks the dimensions with respect to specs and returns the number of batch dimensions if all nested tensors and specs agree with each other.

Parameters
  • tensors (nested Tensors) – Nested list/tuple/dict of Tensors.

  • specs (nested TensorSpecs) – Nested list/tuple/dict of TensorSpecs, describing the shape of unbatched tensors.

Returns

The number of outer dimensions for all tensors (zero if all are unbatched or empty).

Return type

int

Raises

AssertionError – If the shape of Tensors are not compatible with specs, a mix of batched and unbatched tensors are provided, or the tensors are batched but have an incorrect number of outer dims.
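
The check can be illustrated on plain shape tuples, without tensors. In this sketch outer_rank_sketch is a hypothetical helper that takes already-flattened lists of shapes; the documented function works on nests of Tensors and TensorSpecs.

```python
def outer_rank_sketch(shapes, spec_shapes):
    """Return the common number of leading batch dims of shapes w.r.t. spec_shapes.

    Both arguments are flat lists of tuples: full tensor shapes and the
    corresponding unbatched spec shapes.
    """
    ranks = set()
    for shape, spec in zip(shapes, spec_shapes):
        outer = len(shape) - len(spec)
        assert outer >= 0, f"{shape} has fewer dims than spec {spec}"
        # The trailing dims must match the unbatched spec exactly.
        assert tuple(shape[outer:]) == tuple(spec), (
            f"{shape} is incompatible with spec {spec}")
        ranks.add(outer)
    assert len(ranks) <= 1, f"inconsistent outer dims: {sorted(ranks)}"
    return ranks.pop() if ranks else 0


# Two tensors batched as [B=8, T=4, ...] against unbatched specs (3,) and ().
rank = outer_rank_sketch([(8, 4, 3), (8, 4)], [(3,), ()])
```

Here rank is 2. Mixing a batched shape with an unbatched one, such as [(8, 3), (3,)] against specs [(3,), (3,)], trips the final assertion.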

grad(nested, objective, retain_graph=False)[source]#

Compute the gradients of an objective w.r.t each variable in nested. It will simply call torch.autograd.grad after flattening the nest, and then pack the flat list back to a structure like nested.

Parameters
  • nested (nest) – a nest of variables that require grads.

  • objective (Tensor) – a tensor whose gradients will be computed.

  • retain_graph (bool) – if True, after autograd the computational graph won’t be freed

make_nested_module(nested, ignore_non_module_element=True)[source]#

Convert a nest of modules to nn.Module using nn.ModuleList or nn.ModuleDict.

The reason to use this function is that a plain nest of Modules will not be trained or checkpointed. We need to use nn.ModuleList or nn.ModuleDict to hold the individual modules in the nest.

Parameters
  • nested (nested nn.Module) – a nest of nn.Module

  • ignore_non_module_element (bool) – If True, will ignore the non-module element and replace them with None. If False, will raise error if there are any non-module elements.

Returns

nn.Module

stack_nests(nests, dim=0)[source]#

Stack tensors to a sequence.

All the nests should have the same structure and shape. In the resulting nest, each tensor has shape \([T,...]\) and is the stack of all the corresponding tensors in nests.

Parameters
  • nests (list[nest]) – list of nests with same structure and shape.

  • dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated tensors (inclusive)

Returns

a nest with same structure as nests[0].

zeros_like(nested)[source]#

Create a new nest with all zeros like the reference nested.

Parameters

nested (nested Tensor) – a nested structure

Returns

a nest with all zeros

Return type

nested Tensor