# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various Network containers."""
import copy
import torch.nn as nn
from typing import Callable
import alf
from alf.nest import (flatten, flatten_up_to, get_field, map_structure,
map_structure_up_to, pack_sequence_as)
from alf.nest.utils import get_nested_field
from alf.utils.spec_utils import is_same_spec
from .network import Network, get_input_tensor_spec, wrap_as_network
from alf.layers import make_parallel_spec
[docs]def Sequential(*modules,
output='',
input_tensor_spec=None,
name="Sequential",
**named_modules):
"""Network composed of a sequence of torch.nn.Module or alf.nn.Network.
All the modules provided through ``modules`` and ``named_modules`` are calculated
sequentially in the same order as they appear in the call to ``Sequential``.
Typically, each module takes the result of the previous module as its input
(or the input to the Sequential if it is the first module), and the result of
the last module is the output of the Sequential. But we also allow more
flexibilities as shown in example 2.
Example 1:
.. code-block:: python
net = Sequential(module1, module2)
y, new_state = net(x, state)
is equivalent to the following:
.. code-block:: python
z, new_state1 = module1(x, state[0])
y, new_state2 = module2(z, state[1])
new_state = (new_state1, new_state2)
Example 2:
.. code-block:: python
net = Sequential(
module1, a=module2, b=(('input', 'a'), module3), output=('a', 'b'))
output, new_state = net(input, state)
is equivalent to the following:
.. code-block:: python
_, new_state1 = module1(input, state[0])
a, new_state2 = module2(_, state[1])
b, new_state3 = module3((input, a), state[2])
new_state = (new_state1, new_state2, new_state3)
output = (a, b)
Args:
modules (Callable | (nested str, Callable)):
The ``Callable`` can be a ``torch.nn.Module``, ``alf.nn.Network``
or plain ``Callable``. Optionally, their inputs can be specified
by the first element of the tuple. If input is not provided, it is
assumed to be the result of the previous module (or input to this
``Sequential`` for the first module). If input is provided, it
should be a nested str. It will be used to retrieve results from
the dictionary of the current ``named_results``. For modules
specified by ``modules``, because no ``named_modules`` has been
invoked, ``named_results`` is ``{'input': input}``.
named_modules (Callable | (nested str, Callable)):
The ``Callable`` can be a ``torch.nn.Module``, ``alf.nn.Network``
or plain ``Callable``. Optionally, their inputs can be specified
by the first element of the tuple. If input is not provided, it is
assumed to be the result of the previous module (or input to this
``Sequential`` for the first module). If input is provided, it
should be a nested str. It will be used to retrieve results from
the dictionary of the current ``named_results``. ``named_results``
is updated once the result of a named module is calculated.
output (nested str): if not provided, the result from the last module
will be used as output. Otherwise, it will be used to retrieve
results from ``named_results`` after the results of all modules
have been calculated.
input_tensor_spec (TensorSpec): the tensor spec of the input. It must
be specified if it cannot be inferred from ``modules[0]``.
name (str):
"""
# The reason that we use a wrapper function for _Sequential is that Network
# does not allow *args for __init__() (see _NetworkMeta.__new__()). And we
# want to use *modules here to make the interface consistent with
# torch.nn.Sequential and alf.layers.Sequential to avoid confusion.
return _Sequential(
modules,
named_modules,
output=output,
input_tensor_spec=input_tensor_spec,
name=name)
class _Sequential(Network):
def __init__(self,
elements=(),
element_dict={},
output='',
input_tensor_spec=None,
name='Sequential'):
state_spec = []
modules = []
inputs = []
outputs = []
simple = True
named_elements = list(zip([''] * len(elements), elements)) + list(
element_dict.items())
is_nested_str = lambda s: all(
map(lambda x: type(x) == str, flatten(s)))
for i, (out, element) in enumerate(named_elements):
input = ''
if isinstance(element, tuple) and len(element) == 2:
input, module = element
else:
module = element
if not (isinstance(module, Callable) and is_nested_str(input)):
raise ValueError(
"Argument %s is not in the form of Callable "
"or (nested str, Callable): %s" % (out or str(i), element))
if isinstance(module, type):
raise ValueError(
"module should not be a type. Did you forget "
"to include '()' after it to contruct the layer? module=%s"
% str(module))
if isinstance(module, Network):
state_spec.append(module.state_spec)
else:
state_spec.append(())
inputs.append(input)
outputs.append(out)
modules.append(module)
if out or input:
simple = False
if output:
simple = False
assert is_nested_str(output), (
"output should be a nested str: %s" % output)
if len(flatten(state_spec)) == 0:
state_spec = ()
if input_tensor_spec is None and not inputs[0]:
input_tensor_spec = get_input_tensor_spec(modules[0])
assert input_tensor_spec is not None, (
"input_tensor_spec needs to be provided")
super().__init__(input_tensor_spec, state_spec=state_spec, name=name)
self._networks = modules
# pytorch nn.Moddule needs to use ModuleList to keep track of parameters
self._nets = nn.ModuleList(
filter(lambda m: isinstance(m, nn.Module), modules))
if simple:
self.forward = self._forward_simple
else:
self.forward = self._forward_complex
self._output = output
self._inputs = inputs
self._outputs = outputs
def _forward_simple(self, input, state=()):
x = input
if self._state_spec == ():
for net in self._networks:
if isinstance(net, Network):
x = net(x)[0]
else:
x = net(x)
return x, state
else:
new_state = [()] * len(self._networks)
for i, net in enumerate(self._networks):
if isinstance(net, Network):
x, new_state[i] = net(x, state[i])
else:
x = net(x)
return x, new_state
def _forward_complex(self, input, state=()):
x = input
var_dict = {'input': x}
if self._state_spec == ():
for i, net in enumerate(self._networks):
if self._inputs[i]:
x = get_nested_field(var_dict, self._inputs[i])
if isinstance(net, Network):
x = net(x)[0]
else:
x = net(x)
if self._outputs[i]:
var_dict[self._outputs[i]] = x
new_state = state
else:
new_state = [()] * len(self._networks)
for i, net in enumerate(self._networks):
if self._inputs[i]:
x = get_nested_field(var_dict, self._inputs[i])
if isinstance(net, Network):
x, new_state[i] = net(x, state[i])
else:
x = net(x)
if self._outputs[i]:
var_dict[self._outputs[i]] = x
if self._output:
x = get_nested_field(var_dict, self._output)
return x, new_state
def __getitem__(self, i):
return self._networks[i]
def make_parallel(self, n: int):
"""Create a parallelized version of this network.
Args:
n (int): the number of copies
Returns:
the parallelized version of this network
"""
new_networks = []
new_named_networks = {}
for net, input, output in zip(self._networks, self._inputs,
self._outputs):
pnet = alf.layers.make_parallel_net(net, n)
if not output:
new_networks.append((input, pnet))
else:
new_named_networks[output] = (input, pnet)
input_spec = make_parallel_spec(self._input_tensor_spec, n)
return _Sequential(new_networks, new_named_networks, self._output,
input_spec, "parallel_" + self.name)
[docs]class Parallel(Network):
"""Apply each Network in the nest of Network to the corresponding input.
Example:
.. code-block:: python
net = Parallel((module1, module2))
y, new_state = net(x, state)
is equivalent to the following:
.. code-block:: python
y0, new_state0 = module1(x[0], state[0])
y1, new_state1 = module2(x[1], state[1])
y = (y0, y1)
new_state = (new_state0, new_state1)
"""
def __init__(self, modules, input_tensor_spec=None, name="Parallel"):
"""
Args:
modules (nested nn.Module): a nest of ``torch.nn.Module`` or ``alf.nn.Network``.
input_tensor_spec (nested TensorSpec): must be provided if it cannot
be inferred from ``modules``.
name (str):
"""
if input_tensor_spec is None:
input_tensor_spec = map_structure(get_input_tensor_spec, modules)
specified = all(
map(lambda s: s is not None, flatten(input_tensor_spec)))
assert specified, (
"input_tensor_spec needs "
"to be specified if it cannot be infered from elements of "
"networks")
alf.nest.assert_same_structure_up_to(modules, input_tensor_spec)
networks = map_structure_up_to(modules, wrap_as_network, modules,
input_tensor_spec)
state_spec = map_structure(lambda net: net.state_spec, networks)
if len(flatten(state_spec)) == 0:
state_spec = ()
super().__init__(input_tensor_spec, state_spec=state_spec, name=name)
self._networks = networks
if alf.nest.is_nested(networks):
# make it a nn.Module so its parameters can be picked up by the framework
self._nets = alf.nest.utils.make_nested_module(networks)
[docs] def forward(self, inputs, state=()):
if self._state_spec == ():
output = map_structure_up_to(
self._networks, lambda net, input: net(input)[0],
self._networks, inputs)
else:
output_and_state = map_structure_up_to(
self._networks, lambda net, input, s: net(input, s),
self._networks, inputs, state)
output = map_structure_up_to(self._networks, lambda os: os[0],
output_and_state)
state = map_structure_up_to(self._networks, lambda os: os[1],
output_and_state)
return output, state
@property
def networks(self):
return self._networks
[docs] def make_parallel(self, n: int):
"""Create a parallelized version of this network.
Args:
n (int): the number of copies
Returns:
the parallelized version of this network
"""
networks = map_structure(
lambda net: alf.layers.make_parallel_net(net, n), self._networks)
input_spec = make_parallel_spec(self._input_tensor_spec, n)
return Parallel(networks, input_spec, 'parallel_' + self.name)
[docs]def Branch(*modules, input_tensor_spec=None, name="Branch", **named_modules):
"""Apply multiple networks on the same input.
Example:
.. code-block:: python
net = Branch((module1, module2))
y, new_state = net(x, state)
is equivalent to the following:
.. code-block:: python
y0, new_state0 = module1(x, state[0])
y1, new_state1 = module2(x, state[1])
y = (y0, y1)
new_state = (new_state0, new_state1)
Args:
modules (nested nn.Module | Callable): a nest of ``torch.nn.Module``
``alf.nn.Network`` or ``Callable``. Note that ``Branch(module_a, module_b)``
is equivalent to ``Branch((module_a, module_b))``
named_modules (nn.Module | Callable): a simpler way of specifying
a dict of modules. ``Branch(a=model_a, b=module_b)``
is equivalent to ``Branch(dict(a=module_a, b=module_b))``
input_tensor_spec (nested TensorSpec): must be provided if it cannot
be inferred from any one of ``modules``
name (str):
"""
# The reason that we use a wrapper function for _Branch is that Network
# does not allow *args for __init__() (see _NetworkMeta.__new__()).
return _Branch(
modules, named_modules, input_tensor_spec=input_tensor_spec, name=name)
class _Branch(Network):
def __init__(self,
modules,
named_modules,
input_tensor_spec=None,
name="Branch"):
if modules:
assert not named_modules
if len(modules) == 1:
modules = modules[0]
else:
modules = named_modules
if input_tensor_spec is None:
specs = list(map(get_input_tensor_spec, alf.nest.flatten(modules)))
specs = list(filter(lambda s: s is not None, specs))
assert specs, ("input_tensor_spec needs to be specified since it "
"cannot be inferred from any one of modules")
for spec in specs:
assert alf.utils.spec_utils.is_same_spec(spec, specs[0]), (
"modules have inconsistent input_tensor_spec: %s vs %s" %
(spec, specs[0]))
input_tensor_spec = specs[0]
networks = map_structure(
lambda net: wrap_as_network(net, input_tensor_spec), modules)
state_spec = map_structure(lambda net: net.state_spec, networks)
if len(flatten(state_spec)) == 0:
state_spec = ()
super().__init__(input_tensor_spec, state_spec=state_spec, name=name)
self._networks = networks
self._networks_flattened = flatten(networks)
if alf.nest.is_nested(networks):
# make it a nn.Module so its parameters can be picked up by the framework
self._nets = alf.nest.utils.make_nested_module(networks)
def forward(self, inputs, state=()):
if self._state_spec == ():
output = list(
map(lambda net: net(inputs)[0], self._networks_flattened))
output = pack_sequence_as(self._networks, output)
else:
state = flatten_up_to(self._networks, state)
output_state = list(
map(lambda net, s: net(inputs, s), self._networks_flattened,
state))
output = pack_sequence_as(self._networks,
[o for o, s in output_state])
state = pack_sequence_as(self._networks,
[s for o, s in output_state])
return output, state
@property
def networks(self):
return self._networks
def make_parallel(self, n: int):
"""Create a parallelized version of this network.
Args:
n (int): the number of copies
Returns:
the parallelized version of this network
"""
networks = map_structure(
lambda net: alf.layers.make_parallel_net(net, n), self._networks)
input_spec = make_parallel_spec(self._input_tensor_spec, n)
return Branch(
networks,
input_tensor_spec=input_spec,
name='parallel_' + self.name)
[docs]class Echo(Network):
"""Echo network.
Echo network uses part of the output of ``block`` of current step as part of
the input of ``block`` for the next step. In particular, if the input of ``block``
is a dictionary, it should contains two keys 'input' and 'echo', and 'echo'
will be taken from the output of the previous step. If the input of ``block``
is a tuple, the second input will be taken from the output of the previous step.
If the output is a dictionary, it should contains two keys 'output' and 'echo',
and 'echo' will be used as the input for the next step. If the output is a tuple,
the second output will be used as the input for the next step.
Note that ``block`` itself can be a recurrent network with state.
Examples:
.. code-block:: python
echo = Echo(block)
output, state = echo(real_input, state)
is equivalent to the following if the input and output of block are dicts:
.. code-block:: python
block_state, echo_input = state
block_output, block_state = block(dict(input=real_input, echo=echo_input), block_state)
output = block_output['output']
echo_output = block_output['echo']
state = (block_state, echo_output)
and is equivalent to the following if the input and output of block are tuples:
.. code-block:: python
block_state, echo_input = state
block_output, block_state = block((real_input, echo_input), block_state)
output, echo_output = block_output
state = (block_state, echo_output)
"""
def __init__(self, block, input_tensor_spec=None):
"""
Args:
block (Network): the module for performing the actual computation
input_tensor_spec (nested TensorSpec): If provided, it must match
the ``block.input_tensor_spec[0]`` or ``block.input_tensor_spec['input']``
"""
assert isinstance(
block, Network), ("block must be an instance of "
"alf.networks.Network. Got %s" % type(block))
if (isinstance(block.input_tensor_spec, tuple)
and len(block.input_tensor_spec) == 2):
self._is_tuple_input = True
real_input_spec, echo_input_spec = block.input_tensor_spec
elif (isinstance(block.input_tensor_spec, dict)
and len(block.input_tensor_spec) == 2
and 'input' in block.input_tensor_spec
and 'echo' in block.input_tensor_spec):
self._is_tuple_input = False
real_input_spec = block.input_tensor_spec['input']
echo_input_spec = block.input_tensor_spec['echo']
else:
raise ValueError(
"block.input_tensor_spec should be a tuple with "
"two elements or a dict with two keys 'input' and 'echo': %s" %
block.input_tensor_spec)
if (isinstance(block.output_spec, tuple)
and len(block.output_spec) == 2):
self._is_tuple_output = True
echo_output_spec = block.output_spec[1]
elif (isinstance(block.output_spec, dict)
and len(block.output_spec) == 2 and 'output' in block.output_spec
and 'echo' in block.output_spec):
self._is_tuple_output = False
echo_output_spec = block.output_spec['echo']
else:
raise ValueError(
"block.output_spec should be a tuple with "
"two elements or a dict with two keys 'output' and 'echo': %s"
% block.output_spec)
assert is_same_spec(echo_input_spec, echo_output_spec), (
"echo input and echo output should have same spec: %s vs. %s" %
(echo_input_spec, echo_output_spec))
if input_tensor_spec is not None:
assert is_same_spec(real_input_spec, input_tensor_spec), (
"input_tensor_spec is not same as real_input_spec: %s vs. %s" %
(input_tensor_spec, real_input_spec))
state_spec = (block.state_spec, echo_input_spec)
super().__init__(
input_tensor_spec=real_input_spec, state_spec=state_spec)
self._block = block
[docs] def forward(self, input, state):
block_state, echo_state = state
if self._is_tuple_input:
block_input = (input, echo_state)
else:
block_input = dict(input=input, echo=echo_state)
block_output, block_state = self._block(block_input, block_state)
if self._is_tuple_output:
real_output, echo_output = block_output
else:
real_output = block_output['output']
echo_output = block_output['echo']
return real_output, (block_state, echo_output)
[docs] def make_parallel(self, n: int):
"""Create a parallelized version of this network.
Args:
n (int): the number of copies
Returns:
the parallelized version of this network
"""
return Echo(
alf.layers.make_parallel_net(self._block),
make_parallel_spec(self._input_tensor_spec, n))