Source code for alf.networks.transformer_networks

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

import alf
from alf.networks import PreprocessorNetwork
from alf.networks.memory import FIFOMemory
from alf.nest.utils import NestConcat

import functools
import torch.nn.functional as F
from alf.initializers import variance_scaling_init
import alf.layers as layers


[docs]@alf.configurable class TransformerNetwork(PreprocessorNetwork): """A Network composed of Memory and TransformerBlock. The following is the pseudocode for the computation: .. code-block:: python for i in range(num_prememory_layers): core, inputs = T_i([core, inputs], [core, inputs]) for j in range(num_memory_layers): new_core, inputs = TM_j([memory_j, core, inputs], [core, inputs]) memory_j.write(core) core = new_core return core, new_memory_state where T_i denotes the ``TransformerBlock`` for the i-th prememory layers and TM_j denotes the ``TransformerBlock`` for the j-th memory layers. memory_j is an ``FIFOMemory`` object (not to be confused with the ``memory`` argument of ``TransformerBlock.forward() function``) The core embedding serves the same purpose of [CLS] in the BERT model in [1], which is to generate a fixed dimensional representation for downstream tasks. Different from BERT, which only has one [CLS] embedding, we allow the option of having multiple core embeddings. In addition to generating a fixed dimensional representation, the core embedding is also used to update the memory. [1]. Devlin et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding """ def __init__(self, input_tensor_spec, num_prememory_layers, num_attention_heads, d_ff=None, core_size=1, use_core_embedding=True, memory_size=0, num_memory_layers=0, return_core_only=True, centralized_memory=True, input_preprocessors=None, name="TransformerNetwork"): """ Args: input_tensor_spec (nested TensorSpec): the (nested) tensor spec of the input. If ``input_tensor_spec`` is not nested, it should represent a rank-2 tensor of shape ``[input_size, d_model]``, where ``input_size`` is the length of the input sequence, and ``d_model`` is the dimension of embedding. num_prememory_layers (int): number of TransformerBlock calculation without using memory num_attention_heads (int): number of attention heads for each ``TransformerBlock`` d_ff (int): the size of the hidden layer of the feedforward network in each ``TransformerBlock``. If None, ``TransformerBlock`` will calculate it as ``4*d_model``. memory_size (int): size of memory. num_memory_layers (int): number of TransformerBlock calculation using memory return_core_only (bool): If True, only return the core embedding. Otherwise, return all embeddings core_size (int): size of core (i.e. number of embeddings of core) use_core_embedding (bool): whether to use learnable core embedding. If True, will use additional learnable core embedding to augment the input. If False, the first ``core_size`` embeddings of the input are treated as core. centralized_memory (bool): if False, there will be a separate memory for each memory layers. if True, there will be a single memory for all the memroy layers and it is updated using the last core embeddings. input_preprocessors (nested Network|nn.Module): a nest of stateless preprocessor networks, each of which will be applied to the corresponding input. If not None, then it must have the same structure with ``input_tensor_spec``. If any element is None, then it will be treated as math_ops.identity. This arg is helpful if you want to have separate preprocessings for different inputs by configuring a gin file without changing the code. For example, embedding a discrete input before concatenating it to another continuous vector. The output_spec of each input preprocessor i should be [input_size_i, d_model]. The result of all the preprocessors will be concatenated as a Tensor of shape ``[batch_size, input_size, d_model]``, where ``input_size = sum_i input_size_i``. """ preprocessing_combiner = None if input_preprocessors is not None: preprocessing_combiner = NestConcat(dim=-2) super().__init__( input_tensor_spec, input_preprocessors, preprocessing_combiner=preprocessing_combiner, name=name) assert self._processed_input_tensor_spec.ndim == 2 input_size, d_model = self._processed_input_tensor_spec.shape if num_memory_layers > 0: assert memory_size > 0, ("memory_size needs to be set if " "num_memory_layers > 0") if centralized_memory: self._memories = [FIFOMemory(d_model, memory_size)] else: self._memories = [ FIFOMemory(d_model, memory_size) for _ in range(num_memory_layers) ] else: self._memories = [] self._centralized_memory = centralized_memory self._core_size = core_size if use_core_embedding: self._core_embedding = nn.Parameter( torch.Tensor(1, core_size, d_model)) nn.init.uniform_(self._core_embedding, -0.1, 0.1) else: self._core_embedding = None self._state_spec = [mem.state_spec for mem in self._memories] self._num_memory_layers = num_memory_layers self._num_prememory_layers = num_prememory_layers self._transformers = nn.ModuleList() for i in range(num_prememory_layers): self._transformers.append( alf.layers.TransformerBlock( d_model=d_model, d_ff=d_ff, num_heads=num_attention_heads, memory_size=input_size + core_size, positional_encoding='abs' if i == 0 else 'none')) for i in range(num_memory_layers): self._transformers.append( alf.layers.TransformerBlock( d_model=d_model, d_ff=d_ff, num_heads=num_attention_heads, memory_size=memory_size + input_size + core_size, positional_encoding='abs' if i == 0 else 'none')) self._return_core_only = return_core_only @property def state_spec(self): return self._state_spec
[docs] def forward(self, inputs, state=()): """ Args: inputs (nested Tensor): consistent with ``input_tensor_spec`` provided at ``__init__()`` state (nested Tensor): states Returns: - Tensor: shape is [B, core_size * d_model] if ``return_core_only``, and [B, core_size + input_size, d_model] if not ``return_core_only``, where ``input_size`` is the number of embeddings from the (processed) input. - nested Tensor: network states. """ z, _ = super().forward(inputs, state) batch_size = z.shape[0] if self._core_embedding is not None: core_embedding = self._core_embedding.expand(batch_size, -1, -1) query = torch.cat([core_embedding, z], dim=-2) else: query = z for i in range(self._num_prememory_layers): query = self._transformers[i].forward(query) if self._num_memory_layers > 0 and self._centralized_memory: memory = self._memories[0] memory.from_states(state[0]) mem = memory.memory() for i in range(self._num_memory_layers): transformer = self._transformers[self._num_prememory_layers + i] query = transformer.forward( memory=torch.cat([mem, query], dim=-2), query=query) memory.write(query[:, :self._core_size, :]) else: for i in range(self._num_memory_layers): memory = self._memories[i] memory.from_states(state[i]) transformer = self._transformers[self._num_prememory_layers + i] new_query = transformer.forward( memory=torch.cat([memory.memory(), query], dim=-2), query=query) memory.write(query[:, :self._core_size, :]) query = new_query new_state = [mem.states for mem in self._memories] if self._return_core_only: return query[:, :self._core_size, :].reshape(batch_size, -1), new_state else: return query, new_state
[docs]@alf.configurable class SocialAttentionNetwork(PreprocessorNetwork): """Simple graph encoding network, which takes as input a set of objects and outputs one encoded feature vector. Reference: Leurent et al "Social Attention for Autonomous Decision-Making in Dense Traffic", arXiv:1911.12250 """ def __init__(self, input_tensor_spec, input_preprocessors=None, preprocessing_combiner=None, fc_layer_params=(128, 128), activation=torch.relu_, kernel_initializer=None, use_fc_bn=False, num_of_heads=1, last_layer_size=None, last_activation=None, last_kernel_initializer=None, name="SocialAttentionNetwork"): """ Args: input_tensor_spec (nested TensorSpec): the (nested) tensor spec of the input. If nested, then ``preprocessing_combiner`` must not be None. input_preprocessors (nested InputPreprocessor): a nest of ``InputPreprocessor``, each of which will be applied to the corresponding input. If not None, then it must have the same structure with ``input_tensor_spec``. This arg is helpful if you want to have separate preprocessings for different inputs by configuring a gin file without changing the code. For example, embedding a discrete input before concatenating it to another continuous vector. preprocessing_combiner (NestCombiner): preprocessing called on complex inputs. Note that this combiner must also accept ``input_tensor_spec`` as the input to compute the processed tensor spec. For example, see ``alf.nest.utils.NestConcat``. This arg is helpful if you want to combine inputs by configuring a gin file without changing the code. fc_layer_params (tuple[int]): a tuple of integers representing FC layer sizes for generating embeddings. activation (nn.functional): activation used for all the layers but the last layer. kernel_initializer (Callable): initializer for all the layers but the last layer. If None, a variance_scaling_initializer will be used. use_fc_bn (bool): whether use Batch Normalization for fc layers. num_of_heads (int): number of heads for the mult-head attention last_layer_size (None): nt used; for interface compatibility last_activation (None): not used; for interface compatibility last_kernel_initializer (None): not used; for interface compatibility last_use_fc_bn (None): not used; for interface compatibility name (str): """ super().__init__( input_tensor_spec, input_preprocessors, preprocessing_combiner=preprocessing_combiner, name=name) if kernel_initializer is None: kernel_initializer = functools.partial( variance_scaling_init, mode='fan_in', distribution='truncated_normal', nonlinearity=activation) embedding_layers = nn.ModuleList() assert self._processed_input_tensor_spec.ndim == 2, ( "expect the " "processed spec to have the shape of [entity_num, feature_dim]") input_size = self._processed_input_tensor_spec.shape[-1] for size in fc_layer_params: embedding_layers.append( layers.FC( input_size, size, activation=activation, use_bn=use_fc_bn, kernel_initializer=kernel_initializer)) input_size = size self._embedding_layers = embedding_layers fea_dim = input_size assert fea_dim % num_of_heads == 0, "improper value for num_of_heads" self._num_of_heads = num_of_heads self._fea_dim_per_head = fea_dim // num_of_heads # attention related layers self._value_proj = layers.FC( fea_dim, fea_dim, use_bias=False, kernel_initializer=kernel_initializer) self._key_proj = layers.FC( fea_dim, fea_dim, use_bias=False, kernel_initializer=kernel_initializer) self._query_proj = layers.FC( fea_dim, fea_dim, use_bias=False, kernel_initializer=kernel_initializer) self._simple_attention = alf.layers.SimpleAttention()
[docs] def forward(self, inputs, state=()): """ Args: inputs (Tensor): with the shape of [B, N, d], where B denotes batch size, N the number of entities, and d the feature dimension state (nested Tensor): states Returns: - Tensor: shape is [B, d'], where d' denotes the output dimension of the last layer specified by fc_layer_params (i.e. fc_layer_params[-1]) """ x, _ = super().forward(inputs, state) B, N, d = x.shape x = x.reshape(B * N, -1) # forward through embedding layers shared across all entities for i, net in enumerate(self._embedding_layers): x = net(x) # [B, N, d'] (batch, entities, fea_dim) X = x.reshape(B, N, -1) key = X[:, 0] # [B, head * d'] -> [B, 1, head, d'] query = self._query_proj(key).reshape(B, 1, self._num_of_heads, self._fea_dim_per_head) key = self._key_proj(X).reshape(B, N, self._num_of_heads, self._fea_dim_per_head) value = self._value_proj(X).reshape(B, N, self._num_of_heads, self._fea_dim_per_head) # [B, N, head, d'] -> [B, head, N, d'] query = query.permute(0, 2, 1, 3) key = key.permute(0, 2, 1, 3) value = value.permute(0, 2, 1, 3) v, _ = self._simple_attention(query=query, key=key, value=value) out = v.reshape(B, -1) return out, state