alf.nest
alf.nest.nest
Functions for handling nests.
- assert_same_structure(nest1, nest2)
(C++) Asserts that two structures are nested in the same way.
- assert_same_structure_up_to(shallow_nest, deep_nest)
(C++) Asserts that deep_nest has the same structure as shallow_nest up to the depths of shallow_nest. Every sub-nest of deep_nest beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf.
Examples:
assert_same_structure_up_to(([2], None), ([1], [1, 2, 3]))  # success
assert_same_structure_up_to(([2], []), ([1], [1, 2, 3]))  # failure
- Parameters
shallow_nest (nest) – a shallow nested structure.
deep_nest (nest) – a nested structure to check against shallow_nest.
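To make the "treated as a leaf" rule concrete, here is a minimal pure-Python sketch. It is not ALF's C++ implementation: check_same_structure_up_to is a hypothetical name, and the sketch assumes that only lists, tuples and dicts are containers and that matching containers must have exactly the same type. It reproduces the two examples above.

```python
def _is_nested(value):
    # Containers treated as nests in this sketch: list, tuple (incl. namedtuple), dict.
    return isinstance(value, (list, tuple, dict))


def check_same_structure_up_to(shallow_nest, deep_nest):
    """Hypothetical sketch: raise AssertionError if deep_nest does not match
    shallow_nest down to the depth of shallow_nest."""
    if not _is_nested(shallow_nest):
        # A leaf of shallow_nest accepts anything: the deeper part is a leaf.
        return
    assert type(shallow_nest) is type(deep_nest), (
        f"type mismatch: {type(shallow_nest).__name__} vs {type(deep_nest).__name__}")
    if isinstance(shallow_nest, dict):
        assert sorted(shallow_nest) == sorted(deep_nest), "dict keys differ"
        for key in shallow_nest:
            check_same_structure_up_to(shallow_nest[key], deep_nest[key])
    else:
        assert len(shallow_nest) == len(deep_nest), (
            f"length mismatch: {len(shallow_nest)} vs {len(deep_nest)}")
        for s, d in zip(shallow_nest, deep_nest):
            check_same_structure_up_to(s, d)


check_same_structure_up_to(([2], None), ([1], [1, 2, 3]))  # success
try:
    check_same_structure_up_to(([2], []), ([1], [1, 2, 3]))
except AssertionError as err:
    # [] is a nest of length 0, while [1, 2, 3] has length 3.
    print("failure:", err)
```

The second call fails because an empty list in shallow_nest is still a container, so it is not a leaf and its length must match.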
- batch_nested_tensor(nested_tensor)
Unsqueeze a zero (batch) dimension for each entry in nested_tensor.
- extract_any_leaf_from_nest(nest)
Extract an arbitrary leaf from a nest. This should be faster than flatten(nest)[0] because this function short-circuits as soon as it finds a leaf.
- Parameters
nest (nest) – a nested structure
- Returns
A Tensor if there exists a leaf; otherwise None.
- extract_fields_from_nest(nest)
Extract fields and the corresponding values from a nest if it's either a namedtuple or dict.
- Parameters
nest (nest) – a nested structure
- Returns
an iterator that generates (field, value) pairs. The fields are sorted before being returned.
- Return type
Iterable
- Raises
AssertionError – if the nest is neither namedtuple nor dict.
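The documented behavior (sorted (field, value) pairs for a dict or namedtuple, AssertionError otherwise) can be sketched in plain Python as follows. The function name fields_of is hypothetical and this is an illustration of the description above, not ALF's code.

```python
from collections import namedtuple


def fields_of(nest):
    """Hypothetical sketch: yield (field, value) pairs sorted by field name."""
    if isinstance(nest, dict):
        items = list(nest.items())
    elif isinstance(nest, tuple) and hasattr(nest, "_fields"):
        # A namedtuple: pair each field name with its value.
        items = list(zip(nest._fields, nest))
    else:
        raise AssertionError(f"{type(nest).__name__} is neither namedtuple nor dict")
    for field, value in sorted(items, key=lambda item: item[0]):
        yield field, value


Point = namedtuple("Point", ["y", "x"])
print(list(fields_of(Point(y=1, x=2))))  # [('x', 2), ('y', 1)]
print(list(fields_of({"b": 1, "a": 2})))  # [('a', 2), ('b', 1)]
```

Because this sketch is a generator, the AssertionError is raised when iteration starts rather than at call time.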
- fast_map_structure_flatten(func, structure, *flat_structure)
Applies func to entries of flat_structure and returns a packed structure according to structure.
- find_field(nest, name, ignore_empty=True)
Find fields with the given name.
Examples
nest = dict(a=1, b=dict(a=dict(a=2, b=3), b=2))
find_field(nest, 'a')  # you would get [1, {"a": 2, "b": 3}]
- Parameters
nest (nest) – a nested structure
name (str) – name of the field
ignore_empty (bool) – ignore the field if it is None or empty.
- Returns
list
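One traversal that reproduces the example above is sketched below: a matched value is collected as-is and not searched further, while non-matching values are searched recursively. The name find_field_sketch is hypothetical, only dicts, lists and tuples are handled, and treating None and empty containers as "empty" is this sketch's reading of ignore_empty.

```python
def find_field_sketch(nest, name, ignore_empty=True):
    """Hypothetical sketch: collect the values stored under `name` in nest."""
    found = []
    if isinstance(nest, dict):
        for key, value in nest.items():
            if key == name:
                is_empty = value is None or (
                    isinstance(value, (list, tuple, dict)) and len(value) == 0)
                if not (ignore_empty and is_empty):
                    found.append(value)  # a matched value is not searched further
            else:
                found.extend(find_field_sketch(value, name, ignore_empty))
    elif isinstance(nest, (list, tuple)):
        for value in nest:
            found.extend(find_field_sketch(value, name, ignore_empty))
    return found


nest = dict(a=1, b=dict(a=dict(a=2, b=3), b=2))
print(find_field_sketch(nest, "a"))  # [1, {'a': 2, 'b': 3}]
```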
- flatten_up_to(shallow_nest, nest)
(C++) Flatten nest up to the depths of shallow_nest. Every sub-nest of nest beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf that stops flattening downwards.
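A minimal pure-Python sketch of "flatten up to" follows. It is not ALF's C++ code; flatten_up_to_sketch is a hypothetical name, and visiting dict keys in sorted order is an assumption made here for a deterministic output.

```python
def flatten_up_to_sketch(shallow_nest, nest):
    """Hypothetical sketch: flatten `nest` only as deep as `shallow_nest` goes."""
    if isinstance(shallow_nest, dict):
        flat = []
        for key in sorted(shallow_nest):  # assumption: dict keys visited in sorted order
            flat.extend(flatten_up_to_sketch(shallow_nest[key], nest[key]))
        return flat
    if isinstance(shallow_nest, (list, tuple)):
        assert len(shallow_nest) == len(nest), "lengths differ"
        flat = []
        for s, n in zip(shallow_nest, nest):
            flat.extend(flatten_up_to_sketch(s, n))
        return flat
    # shallow_nest is a leaf here, so whatever sits in nest is kept whole.
    return [nest]


print(flatten_up_to_sketch([None, None], [[1], 2]))  # [[1], 2]
print(flatten_up_to_sketch((None, {"b": None, "a": None}),
                           ([1, 2], {"a": [3], "b": 4})))  # [[1, 2], [3], 4]
```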
- get_field(nested, field)
Get the field from nested.
field is a string separated by ".". get_field(nested, "a.b") is equivalent to nested.a.b if nested is constructed using namedtuple, or nested['a']['b'] if nested is constructed using dict. If nested is constructed using list or unnamed tuple, get_field(nested, "1.2") is equivalent to nested[1][2].
- Parameters
nested (nest) – a nested structure
field (str) – the path to the field, with '.' separating the field names at different levels. None or '' means the whole nest.
- Returns
value of the field corresponding to field
- Return type
nest
- Raises
LookupError – if field cannot be found.
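The path semantics described above can be sketched in plain Python. get_field_sketch is a hypothetical name, not ALF's implementation; it follows the rules stated here (attribute access for namedtuples, key lookup for dicts, integer indices for lists and unnamed tuples, LookupError when the path cannot be resolved).

```python
from collections import namedtuple


def get_field_sketch(nested, field):
    """Hypothetical sketch: follow a '.'-separated path through a nest."""
    if not field:  # None or '' selects the whole nest
        return nested
    node = nested
    for part in field.split("."):
        if isinstance(node, dict):
            if part not in node:
                raise LookupError(f"{part!r} not found while resolving {field!r}")
            node = node[part]
        elif isinstance(node, tuple) and hasattr(node, "_fields"):
            if part not in node._fields:
                raise LookupError(f"{part!r} not found while resolving {field!r}")
            node = getattr(node, part)
        elif isinstance(node, (list, tuple)):
            try:
                node = node[int(part)]
            except (ValueError, IndexError) as err:
                raise LookupError(f"bad index {part!r} in {field!r}") from err
        else:
            raise LookupError(f"cannot descend into a leaf at {part!r}")
    return node


AB = namedtuple("AB", ["a", "b"])
nested = AB(a={"x": [10, 20, 30]}, b=1)
print(get_field_sketch(nested, "a.x.2"))  # 30
```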
- get_nest_batch_size(nest)
Get the batch size (dim=0) of a nest, assuming batch-major.
- Parameters
nest (nest) – a nested structure
- Returns
batch size
- Return type
int
- get_nest_shape(nest)
Get the shape of a nest leaf. It assumes that all leaves of the nest share the same shape; for efficiency, this is not checked.
- Parameters
nest (nest) – a nested structure
- Returns
the shape of a leaf of nest
- Return type
torch.Size
- get_nest_size(nest, dim)
Get the size of dimension dim from a nest. It assumes that all leaves of the nest share the same size.
- Parameters
nest (nest) – a nested structure
dim (int) – the dimension to get the size for
- Returns
size of dim
- Return type
int
- is_namedtuple(value)
Whether the value is a namedtuple instance.
- Parameters
value (Object) – the value to check
- Returns
True if the value is a namedtuple instance.
- is_nested(value)
Returns True if the input is one of: list, unnamed tuple, dict, or namedtuple. Note that this definition is different from tf's is_nested, where all types that are collections.abc.Sequence are defined to be nested.
- map_structure(func, *nests)
(C++) Applies func to each entry in structure and returns a new structure.
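For intuition about what "applies func to each entry" means with several nests, here is a pure-Python sketch. It is not ALF's C++ implementation; map_structure_sketch is a hypothetical name, and it assumes all nests share exactly the same structure (same container types, lengths and dict keys).

```python
from collections import namedtuple


def map_structure_sketch(func, *nests):
    """Hypothetical sketch: apply func to corresponding leaves of nests that
    share one structure, returning a nest with that same structure."""
    first = nests[0]
    if isinstance(first, dict):
        for other in nests[1:]:
            assert isinstance(other, dict) and other.keys() == first.keys(), "dict keys differ"
        return {k: map_structure_sketch(func, *(n[k] for n in nests)) for k in first}
    if isinstance(first, (list, tuple)):
        for other in nests[1:]:
            assert type(other) is type(first) and len(other) == len(first), "structures differ"
        mapped = [map_structure_sketch(func, *children) for children in zip(*nests)]
        if hasattr(first, "_fields"):
            return type(first)(*mapped)  # namedtuple: rebuild from positional values
        return type(first)(mapped)  # list or plain tuple
    return func(*nests)  # all inputs are leaves


Pair = namedtuple("Pair", ["x", "y"])
print(map_structure_sketch(lambda a, b: a + b,
                           {"p": Pair(1, 2), "q": [3]},
                           {"p": Pair(10, 20), "q": [30]}))
# {'p': Pair(x=11, y=22), 'q': [33]}
```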
- map_structure_up_to(shallow_nest, func, *nests)
(C++) Applies a function to nests up to the depths of shallow_nest. Every sub-nest of each of nests beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf and input to func.
Examples (taken from tensorflow.nest.map_structure_up_to):
shallow_nest = [None, None]
inp_val = [[1], 2]
out = map_structure_up_to(shallow_nest, lambda x: 2 * x, inp_val)
# Output is: [[1, 1], 4]
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
# Output is: ab_tuple(a=6, b=15)
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ['evens', ['odds', 'primes']]
out = map_structure_up_to(
    name_list,
    lambda name, sec: "first_{}_{}".format(len(sec), name),
    name_list, data_list)
# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
- Parameters
shallow_nest (nest) – a shallow nested structure.
func (Callable) – callable which will be applied to nests.
*nests (nest) – a variable number of nested structures.
- Returns
a nested structure with the same depth as shallow_nest.
- Return type
nest
- map_structure_without_check(func, *nests)
(C++) Applies func to each entry in structure and returns a new structure. For efficiency, this function doesn't perform any checks.
- nest_top_level(nested)
Given a nest, return its top-level structure, where the values are set to None.
- Parameters
nested (Any) – a nested structure
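A small sketch of the "top-level structure" idea, which is also what transpose uses by default for its shallow nests. top_level_sketch is a hypothetical name, not ALF's code; raising TypeError for a leaf input is this sketch's own choice, since the documentation does not specify that case.

```python
from collections import namedtuple


def top_level_sketch(nested):
    """Hypothetical sketch: keep only the top-level container, values set to None."""
    if isinstance(nested, dict):
        return {key: None for key in nested}
    if isinstance(nested, tuple) and hasattr(nested, "_fields"):
        return type(nested)(*([None] * len(nested)))  # namedtuple
    if isinstance(nested, (list, tuple)):
        return type(nested)([None] * len(nested))
    # This sketch does not define a result for leaves.
    raise TypeError(f"expected a nest, got {type(nested).__name__}")


NTuple = namedtuple("NTuple", ["a", "b"])
print(top_level_sketch(NTuple(a=dict(x=3), b=[5])))  # NTuple(a=None, b=None)
print(top_level_sketch([(0, 1), (2, 3)]))  # [None, None]
```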
- pack_sequence_as(nest, flat_seq)
(C++) Returns a given flattened sequence packed into a given structure.
- prune_nest_like(nest, slim_nest, value_to_match=None)
(C++) Prune a nested structure with reference to another, slimmer nest. Generally, for every corresponding node, we only keep the fields that are contained in slim_nest. In addition, if a field of slim_nest contains a value of value_to_match, then the corresponding field of nest will also be updated to this value.
Note
If a node is a list or unnamed tuple, then their lengths are required to be equal.
Examples
x = dict(a=1, b=2)
y = dict(a=TensorSpec(()))
z = prune_nest_like(x, y)  # z is dict(a=1)
y2 = dict(a=TensorSpec(()), b=())
z2 = prune_nest_like(x, y2, value_to_match=())  # z2 is dict(a=1, b=())
- Parameters
nest (nest) – a nested structure
slim_nest (nest) – a slim nested structure. At every node, its fields are required to be a subset of those of nest.
value_to_match (nest) – a value indicating that the paired field of slim_nest should be updated in nest. Can be set to the default value of a namedtuple.
- Returns
the pruned nest, which has the same set of fields as slim_nest.
- Return type
nest
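The pruning rule can be sketched in plain Python by following the description above literally. This is not ALF's implementation: prune_like_sketch is a hypothetical name, the string "spec" stands in for TensorSpec(()) so the sketch needs no ALF import, and the exact handling of edge cases (for example a None inside slim_nest with the default value_to_match) may differ in the real function.

```python
from collections import namedtuple


def prune_like_sketch(nest, slim_nest, value_to_match=None):
    """Hypothetical sketch of the pruning rule described above (not ALF's code)."""
    if slim_nest == value_to_match:
        # slim_nest holds value_to_match here, so the result takes that value.
        return value_to_match
    if isinstance(slim_nest, dict):
        # Keep only the keys present in slim_nest (each must exist in nest).
        return {k: prune_like_sketch(nest[k], v, value_to_match)
                for k, v in slim_nest.items()}
    if isinstance(slim_nest, tuple) and hasattr(slim_nest, "_fields"):
        return type(slim_nest)(*(
            prune_like_sketch(getattr(nest, f), getattr(slim_nest, f), value_to_match)
            for f in slim_nest._fields))
    if isinstance(slim_nest, (list, tuple)):
        assert len(nest) == len(slim_nest), "lists/tuples must have equal lengths"
        return type(slim_nest)(prune_like_sketch(n, s, value_to_match)
                               for n, s in zip(nest, slim_nest))
    return nest  # slim_nest is a leaf: keep nest's value unchanged


SPEC = "spec"  # stand-in for TensorSpec(()) in this sketch
x = dict(a=1, b=2)
print(prune_like_sketch(x, dict(a=SPEC)))  # {'a': 1}
print(prune_like_sketch(x, dict(a=SPEC, b=()), value_to_match=()))  # {'a': 1, 'b': ()}
```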
- py_assert_same_structure(nest1, nest2)
Asserts that two structures are nested in the same way.
- py_flatten_up_to(shallow_nest, nest)
Flatten nest up to the depths of shallow_nest. Every sub-nest of nest beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf that stops flattening downwards.
- py_map_structure(func, *nests)
Applies func to each entry in structure and returns a new structure.
- py_map_structure_up_to(shallow_nest, func, *nests)
Applies a function to nests up to the depths of shallow_nest. Every sub-nest of each of nests beyond the depth of the corresponding sub-nest in shallow_nest will be treated as a leaf and input to func.
Examples (taken from tensorflow.nest.map_structure_up_to):
shallow_nest = [None, None]
inp_val = [[1], 2]
out = map_structure_up_to(shallow_nest, lambda x: 2 * x, inp_val)
# Output is: [[1, 1], 4]
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
# Output is: ab_tuple(a=6, b=15)
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ['evens', ['odds', 'primes']]
out = map_structure_up_to(
    name_list,
    lambda name, sec: "first_{}_{}".format(len(sec), name),
    name_list, data_list)
# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
- Parameters
shallow_nest (nest) – a shallow nested structure.
func (Callable) – callable which will be applied to nests.
*nests (nest) – a variable number of nested structures.
- Returns
a nested structure with the same depth as shallow_nest.
- Return type
nest
- py_map_structure_with_path(func, *nests)
Applies func to each entry in structure and returns a new structure. This function gives func access to one additional parameter as its first argument: the symbolic string of the path to the element currently supplied. List elements will be indexed by the ordinal position of the element in the list.
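A pure-Python sketch of mapping with paths is shown below for a single nest. map_with_path_sketch is a hypothetical name; the path format ('.'-joined keys, with list elements indexed by position, e.g. "a.0") is an assumption chosen to match the field strings used by get_field, and the real function's exact path format may differ.

```python
from collections import namedtuple


def map_with_path_sketch(func, nest, path=""):
    """Hypothetical sketch: call func(path, leaf) for every leaf of one nest."""
    def child_path(key):
        # Join keys/indices with '.', e.g. "a.0"; the root path is "".
        return f"{path}.{key}" if path else str(key)

    if isinstance(nest, dict):
        return {k: map_with_path_sketch(func, v, child_path(k)) for k, v in nest.items()}
    if isinstance(nest, tuple) and hasattr(nest, "_fields"):
        return type(nest)(*(map_with_path_sketch(func, v, child_path(f))
                            for f, v in zip(nest._fields, nest)))
    if isinstance(nest, (list, tuple)):
        # List elements are indexed by their ordinal position.
        return type(nest)(map_with_path_sketch(func, v, child_path(i))
                          for i, v in enumerate(nest))
    return func(path, nest)


print(map_with_path_sketch(lambda p, x: f"{p}={x}", {"a": [1, 2], "b": 3}))
# {'a': ['a.0=1', 'a.1=2'], 'b': 'b=3'}
```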
- py_pack_sequence_as(nest, flat_seq)
Returns a given flattened sequence packed into a given structure.
- py_prune_nest_like(nest, slim_nest, value_to_match=None)
Prune a nested structure with reference to another, slimmer nest. Generally, for every corresponding node, we only keep the fields that are contained in slim_nest. In addition, if a field of slim_nest contains a value of value_to_match, then the corresponding field of nest will also be updated to this value.
Note
If a node is a list or unnamed tuple, then their lengths are required to be equal.
Examples
x = dict(a=1, b=2)
y = dict(a=TensorSpec(()))
z = prune_nest_like(x, y)  # z is dict(a=1)
y2 = dict(a=TensorSpec(()), b=())
z2 = prune_nest_like(x, y2, value_to_match=())  # z2 is dict(a=1, b=())
- Parameters
nest (nest) – a nested structure
slim_nest (nest) – a slim nested structure. At every node, its fields are required to be a subset of those of nest.
value_to_match (nest) – a value indicating that the paired field of slim_nest should be updated in nest. Can be set to the default value of a namedtuple.
- Returns
the pruned nest, which has the same set of fields as slim_nest.
- Return type
nest
- set_field(nested, field, new_value)
Set the field in nested to new_value.
field is a string separated by ".". set_field(nested, "a.b", v) is equivalent to nested._replace(a=nested.a._replace(b=v)) if nested is constructed using namedtuple.
- Parameters
nested (nest) – a nested structure
field (str) – the path to the field, with '.' separating the field names at different levels
new_value (any) – the new value for the field
- Returns
a nest that is the same as nested except that field is replaced by new_value
- Return type
nest
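The equivalence stated above (set_field versus chained _replace calls) can be checked with a small non-mutating sketch. set_field_sketch is a hypothetical name, not ALF's implementation; it rebuilds each container along the path and leaves the input nest unchanged.

```python
from collections import namedtuple


def set_field_sketch(nested, field, new_value):
    """Hypothetical sketch: return a copy of nested with `field` replaced."""
    if not field:
        return new_value  # end of the path: this node is replaced
    head, _, rest = field.partition(".")
    if isinstance(nested, dict):
        updated = dict(nested)
        updated[head] = set_field_sketch(nested[head], rest, new_value)
        return updated
    if isinstance(nested, tuple) and hasattr(nested, "_fields"):
        child = set_field_sketch(getattr(nested, head), rest, new_value)
        return nested._replace(**{head: child})
    if isinstance(nested, (list, tuple)):
        items = list(nested)
        index = int(head)
        items[index] = set_field_sketch(items[index], rest, new_value)
        return type(nested)(items)
    raise LookupError(f"cannot descend into a leaf while setting {field!r}")


Inner = namedtuple("Inner", ["b", "c"])
Outer = namedtuple("Outer", ["a"])
n = Outer(a=Inner(b=1, c=2))
print(set_field_sketch(n, "a.b", 9))  # Outer(a=Inner(b=9, c=2))
```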
- transform_nest(nested, field, func)
Transform the node of a nested structure indicated by field using func.
This function can be used to update a namedtuple structure conveniently. Compare the following two methods:
info = info._replace(rl=info.rl._replace(sac=info.rl.sac * 0.5))
vs.
info = transform_nest(info, 'rl.sac', lambda x: x * 0.5)
The second method is usually shorter, more intuitive, and less error-prone when field is a long string.
- Parameters
nested (nested Tensor) – the structure to which the transformation is applied.
field (str) – If a string, it's the field to be transformed, with a multi-level path denoted by "A.B.C". Levels can also be integers (e.g., "0.2"), in which case the nest is expected to be tuples or lists at those levels. If None, then the root object is transformed.
func (Callable) – the transform function. It is called with the node selected by field (info.rl.sac in the example above) and should return a new nest for that node.
- Returns
transformed nest
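The two approaches compared above can be checked against each other with a namedtuple-only sketch. transform_nest_sketch is a hypothetical name, not ALF's implementation, and it only handles namedtuple levels.

```python
from collections import namedtuple


def transform_nest_sketch(nested, field, func):
    """Hypothetical sketch (namedtuples only): replace the node at the
    '.'-separated `field` by func(node) and return the rebuilt nest."""
    if not field:
        return func(nested)  # None or '' transforms the root object
    head, _, rest = field.partition(".")
    child = transform_nest_sketch(getattr(nested, head), rest, func)
    return nested._replace(**{head: child})


RL = namedtuple("RL", ["sac", "ppo"])
Info = namedtuple("Info", ["rl"])
info = Info(rl=RL(sac=4.0, ppo=1.0))
manual = info._replace(rl=info.rl._replace(sac=info.rl.sac * 0.5))
helper = transform_nest_sketch(info, "rl.sac", lambda x: x * 0.5)
print(helper == manual, helper)  # True Info(rl=RL(sac=2.0, ppo=1.0))
```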
- transform_nests(nests, field, func)
Transform the node indicated by field in each of the nests using func.
This function can be used to transform multiple nests and to perform transformations with inter-nest interactions. For example:
res1, res2 = transform_nests([nest1, nest2], 'a.b', lambda x: (x[0] * x[1], x[0] + x[1]))
where x[0] denotes the value from nest1 and x[1] is the value from nest2.
- Parameters
nests ([nested Tensor]) – the structures to which the transformation is applied.
field (str) – If a string, it's the field to be transformed, with a multi-level path denoted by "A.B.C". If None, then the root object is transformed.
func (Callable) – the transform function. It is called with the nodes selected by field (one from each nest) and should return the new nodes, one for each nest.
- Returns
list of transformed nests, with the same length as the input nests
- transpose(nested, shallow_nest=None, new_shallow_nest=None)
Given a nest A and its shallow nest a, and assuming that each child of a has the same nest structure B, this function returns a new nest whose shallow nest b is a shallow nest of B, and each child of b has shallow nest a.
An illustrative graph shows the transpose operation:
A = a-B = a-b-C  (transpose->)  b-a-C
where C is every (same) child of b (could be empty).
Note
ALF defines the "shallow nest" of a nest as the subtree that starts from the nest root and contains at least all the direct children of the nest. It can optionally contain more descendants of the nest.
For example,
x = [(0, 1), (2, 3), (4, 5)]
y = transpose(x, shallow_nest=[None, None, None])
# y will be ([0, 2, 4], [1, 3, 5])
y1 = transpose(x)
# y1 will be the same as y
x = NTuple(a=dict(x=3, y=1), b=[dict(x=5, y=10)])
shallow_nest = NTuple(a=None, b=[False])
y = transpose(x, shallow_nest)
# y will be dict(x=NTuple(a=3, b=[5]), y=NTuple(a=1, b=[10]))
x = NTuple(a=dict(x=3, y=dict(n=1, m=2)), b=dict(x=5, y=dict(n=1, m=3)))
transposed_nest1 = transpose(x)
# transposed_nest1 will be dict(x=NTuple(a=3, b=5), y=NTuple(a=dict(n=1, m=2), b=dict(n=1, m=3)))
- Parameters
nested (Any) – a nested structure
shallow_nest (Optional[Any]) – a nested structure indicating the first "axis" for the transpose. If None, then nest_top_level(nested) will be used.
new_shallow_nest (Optional[Any]) – a nested structure indicating the second "axis" for the transpose. Note that this shallow nest is w.r.t. each child B. If not provided, then nest_top_level(B) will be used.
- Returns
a transposed nested structure
- Return type
nested
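For the default case (both shallow nests left as None), transpose swaps the top two levels, so that result[j][i] equals nested[i][j]. The sketch below illustrates only that case; transpose_sketch is a hypothetical name, it supports dicts, lists and plain tuples (not namedtuples), and it assumes nested is non-empty and all children share one structure.

```python
def transpose_sketch(nested):
    """Hypothetical sketch of transpose(nested) with default shallow nests:
    result[j][i] == nested[i][j]."""
    def keys(container):
        return list(container) if isinstance(container, dict) else list(range(len(container)))

    def rebuild(like, values):
        # Container of the same type (and keys) as `like`, filled from values[key].
        if isinstance(like, dict):
            return {k: values[k] for k in like}
        return type(like)(values[i] for i in range(len(like)))

    outer_keys = keys(nested)
    first_child = nested[outer_keys[0]]  # every child must share this structure
    return rebuild(first_child, {
        j: rebuild(nested, {i: nested[i][j] for i in outer_keys})
        for j in keys(first_child)
    })


print(transpose_sketch([(0, 1), (2, 3), (4, 5)]))  # ([0, 2, 4], [1, 3, 5])
print(transpose_sketch({"a": {"x": 3, "y": 1}, "b": {"x": 5, "y": 10}}))
# {'x': {'a': 3, 'b': 5}, 'y': {'a': 1, 'b': 10}}
```

The first call matches the first documented example: the outer list becomes the inner container and the inner tuple becomes the outer one.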
alf.nest.utils
Utility functions for nests.
- class NestCombiner(name, batch_dims=1)
Bases: abc.ABC, torch.nn.modules.module.Module
A base class for combining all elements in a nested structure.
- Parameters
name (str) – name of the combiner
batch_dims (int) – number of batch dims (default 1). This argument is only necessary for combiners that are not batch-dim invariant, i.e., whose combined results depend on the definition of batch dims (e.g., outer product).
- training: bool
- class NestConcat(nest_mask=None, dim=-1, name='NestConcat')
Bases: alf.nest.utils.NestCombiner
A combiner for selecting from the tensors in a nest and then concatenating them along a specified axis. If nest_mask is None, then all the tensors from the nest will be selected. It assumes that all the selected tensors have the same tensor spec. Can be used as a preprocessing combiner of a network.
Note that the batch dimension is not counted for concatenation: dim=0 refers to the first dimension after the batch dimension.
- Parameters
nest_mask (nest|None) – a nest-structured mask indicating whether each tensor in the nest is selected, by a value of True/False (1/0). Note that the structure of the mask should be the same as that of the data nest this operator is applied to. If None, then all the tensors from the nest will be selected.
dim (int) – the dim along which the tensors are concatenated
name (str) – name of the combiner
- make_parallel(n)
Create a NestConcat layer to handle a parallel batch.
It is assumed that a parallel batch has shape [B, n, …] and that neither the batch dimension nor the replica dimension is counted for concatenation.
- Parameters
n (int) – the number of replicas.
- Returns
a NestConcat layer to handle a parallel batch.
- training: bool
- class NestMultiply(activation=None, name='NestMultiply')
Bases: alf.nest.utils.NestCombiner
Element-wise multiply all tensors in a nest. It assumes that all tensors have the same shape. Can be used as a preprocessing combiner of a network.
- Parameters
activation (Callable) – optional activation function applied after the multiplication.
name (str) – name of the combiner
- training: bool
- class NestOuterProduct(activation=None, batch_dims=1, padding=False, name='NestOuterProduct')
Bases: alf.nest.utils.NestCombiner
Perform outer-product operations across a nested structure. Can be used as a preprocessing combiner of a network.
Sometimes combining tensors using an outer product might be more expressive than concatenating them, e.g., when one tensor is one-hot. See the discussions in "Stochastic Neural Networks for Hierarchical Reinforcement Learning", Florensa et al., ICLR 2017, https://arxiv.org/pdf/1704.03012.pdf.
In this implementation, we also support padding 1s to the tensors before doing the outer product, essentially combining outer product and concatenation in one combiner.
Warning
Due to the outer product, this combiner might produce a very long output vector. Make sure to compute the output size before using it.
- Parameters
activation (Optional[Callable]) – optional activation function applied after the outer product.
batch_dims (int) – number of batch dims. Defaults to 1. If the total number of input dims is N, then the last N - batch_dims dims will be flattened for the outer product.
padding (bool) – if True, each tensor will be padded by 1 before performing the outer product. When this flag is enabled, the output tensor essentially also contains the concatenation of all tensors.
name (str) – name of the combiner
- training: bool
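The effect of padding can be seen with plain Python lists for a single sample (no batch dims). outer_product_flat is a hypothetical helper that illustrates the idea only; it is not NestOuterProduct's implementation and does not use torch.

```python
import math
from itertools import product


def outer_product_flat(vectors, padding=False):
    """Hypothetical sketch: the flattened outer product of 1-D vectors."""
    if padding:
        # Pad each vector with a constant 1 before taking the outer product.
        vectors = [list(v) + [1] for v in vectors]
    # One output entry per combination, taking one entry from each vector.
    return [math.prod(combo) for combo in product(*vectors)]


u, v = [2, 3], [5, 7, 11]
print(outer_product_flat([u, v]))  # [10, 14, 22, 15, 21, 33]
print(outer_product_flat([u, v], padding=True))
# [10, 14, 22, 2, 15, 21, 33, 3, 5, 7, 11, 1]
```

With padding, the output has (2 + 1) * (3 + 1) = 12 entries: all pairwise products, plus the entries of u and v on their own (each multiplied by the other vector's padded 1), plus a constant 1. This is why padding has the effect of including a concatenation, and why the output length, a product of per-tensor sizes, can grow quickly.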
- class NestSum(average=False, activation=None, name='NestSum')
Bases: alf.nest.utils.NestCombiner
Add all tensors in a nest together. It assumes that all tensors have the same tensor shape. Can be used as a preprocessing combiner of a network.
- Parameters
average (bool) – If True, the tensors are averaged instead of summed.
activation (Callable) – activation function.
name (str) – name of the combiner
- training: bool
- convert_device(nests, device=None)
Convert the device of the tensors in nests to the specified device or to the default device.
- Parameters
nests (nested Tensors) – Nested list/tuple/dict of Tensors.
device (None|str) – the target device, which should be either cuda or cpu. If None, then the default device will be used as the target device.
- Returns
Nested list/tuple/dict of Tensors after device conversion.
- Return type
nests (nested Tensors)
- Raises
NotImplementedError – if the target device is not one of None, cpu, or cuda (when cuda is available).
AssertionError – if the target device is cuda but cuda is unavailable.
- get_nested_field(nested, nest_fields)
Get nested fields from a nest.
Example
x = get_nested_field(nest, ('a.b', 'c'))
y = (get_field(nest, 'a.b'), get_field(nest, 'c'))
# y and x are the same
- Parameters
nested (nest) – a nested structure
nest_fields (nested str) – nested strings. Each string indicates a path to retrieve the value from nest
- Returns
a nest with the same structure as nest_fields.
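The equivalence in the example above can be sketched as "replace every path string in nest_fields by the value it selects". Both get_path (a dict-only stand-in for get_field) and get_nested_field_sketch are hypothetical names, not ALF functions.

```python
def get_path(nest, path):
    """Hypothetical dict-only stand-in for get_field."""
    for key in path.split("."):
        nest = nest[key]
    return nest


def get_nested_field_sketch(nest, nest_fields):
    """Hypothetical sketch: replace every path string in nest_fields by the
    value it selects in nest, keeping the structure of nest_fields."""
    if isinstance(nest_fields, str):
        return get_path(nest, nest_fields)
    if isinstance(nest_fields, dict):
        return {k: get_nested_field_sketch(nest, v) for k, v in nest_fields.items()}
    return type(nest_fields)(get_nested_field_sketch(nest, f) for f in nest_fields)


nest = {"a": {"b": 1}, "c": 2}
print(get_nested_field_sketch(nest, ("a.b", "c")))  # (1, 2)
print(get_nested_field_sketch(nest, {"x": "a.b", "y": ["c"]}))  # {'x': 1, 'y': [2]}
```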
- get_outer_rank(tensors, specs)
Compares tensors to specs to determine the number of batch dimensions.
For each tensor, it checks the dimensions with respect to specs and returns the number of batch dimensions if all nested tensors and specs agree with each other.
- Parameters
tensors (nested Tensors) – Nested list/tuple/dict of Tensors.
specs (nested TensorSpecs) – Nested list/tuple/dict of TensorSpecs, describing the shape of unbatched tensors.
- Returns
The number of outer dimensions for all tensors (zero if all are unbatched or empty).
- Return type
int
- Raises
AssertionError – If the shapes of the tensors are not compatible with specs, a mix of batched and unbatched tensors is provided, or the tensors are batched but have an incorrect number of outer dims.
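The comparison described above can be sketched with plain shape tuples: each tensor's outer rank is its number of dims minus the spec's, the trailing dims must equal the spec shape, and all tensors must agree. outer_rank_sketch is a hypothetical name; it takes already-flattened lists of shapes rather than nests of tensors and TensorSpecs.

```python
def outer_rank_sketch(shapes, spec_shapes):
    """Hypothetical sketch: infer the number of outer (batch) dims.

    shapes: shapes of the flattened tensors; spec_shapes: the matching
    unbatched spec shapes, in the same order. Shapes are plain tuples here.
    """
    assert len(shapes) == len(spec_shapes), "tensors and specs differ in number"
    ranks = set()
    for shape, spec in zip(shapes, spec_shapes):
        outer = len(shape) - len(spec)
        # The trailing dims must equal the spec shape.
        assert outer >= 0 and tuple(shape[outer:]) == tuple(spec), (
            f"shape {shape} is not compatible with spec {spec}")
        ranks.add(outer)
    # All tensors must agree, e.g. no mix of batched and unbatched tensors.
    assert len(ranks) <= 1, f"inconsistent outer ranks: {sorted(ranks)}"
    return ranks.pop() if ranks else 0


print(outer_rank_sketch([(32, 3), (32, 2, 2)], [(3,), (2, 2)]))  # 1
print(outer_rank_sketch([(3,), (2, 2)], [(3,), (2, 2)]))  # 0
```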
- grad(nested, objective, retain_graph=False)
Compute the gradients of objective w.r.t. each variable in nested. It simply calls torch.autograd.grad after flattening the nest, and then packs the flat list back into a structure like nested.
- Parameters
nested (nest) – a nest of variables that require grads.
objective (Tensor) – a tensor whose gradients will be computed.
retain_graph (bool) – if True, the computational graph won't be freed after autograd.
- make_nested_module(nested, ignore_non_module_element=True)
Convert a nest of modules to an nn.Module using nn.ModuleList or nn.ModuleDict.
This function is needed because modules held in a plain nest will not be trained or checkpointed. We need to use nn.ModuleList or nn.ModuleDict to hold the individual modules in the nest.
- Parameters
nested (nested nn.Module) – a nest of nn.Module
ignore_non_module_element (bool) – If True, non-module elements will be ignored and replaced with None. If False, an error will be raised if there are any non-module elements.
- Returns
nn.Module
- stack_nests(nests, dim=0)
Stack tensors into a sequence.
All the nests should have the same structure and shape. In the resulting nest, each tensor is the stack of the corresponding tensors in nests; with the default dim=0, it has shape [T, ...], where T is the number of nests.
- Parameters
nests (list[nest]) – list of nests with the same structure and shape.
dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of each input tensor (inclusive)
- Returns
a nest with the same structure as nests[0].
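The structural part of stacking (for dim=0) can be illustrated without torch. stack_nests_sketch is a hypothetical name; it uses Python lists as stand-ins for tensors, so lists are treated as leaves here, and only dicts and plain tuples are treated as containers.

```python
def stack_nests_sketch(nests):
    """Hypothetical sketch of stack_nests with dim=0, using Python lists as
    stand-ins for tensors (so lists are leaves, not containers)."""
    first = nests[0]
    if isinstance(first, dict):
        return {k: stack_nests_sketch([n[k] for n in nests]) for k in first}
    if isinstance(first, tuple):
        # Plain tuples only; zip(*nests) groups the i-th element of every nest.
        return tuple(stack_nests_sketch(list(group)) for group in zip(*nests))
    # Leaves: the stacked "tensor" gains a new leading axis of size T = len(nests).
    return list(nests)


steps = [{"obs": [1, 2], "r": (0,)}, {"obs": [3, 4], "r": (1,)}, {"obs": [5, 6], "r": (2,)}]
print(stack_nests_sketch(steps))
# {'obs': [[1, 2], [3, 4], [5, 6]], 'r': ([0, 1, 2],)}
```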