Source code for terminator.models.layers.s2s_modules

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from .utils import merge_duplicate_edges

# pylint: disable=no-member


[docs]class PositionWiseFeedForward(nn.Module):
[docs] def __init__(self, num_hidden, num_ff): super().__init__() self.W_in = nn.Linear(num_hidden, num_ff, bias=True) self.W_out = nn.Linear(num_ff, num_hidden, bias=True)
[docs] def forward(self, h_V): h = F.relu(self.W_in(h_V)) h = self.W_out(h) return h
[docs]class Normalize(nn.Module):
[docs] def __init__(self, features, epsilon=1e-6): super().__init__() self.gain = nn.Parameter(torch.ones(features)) self.bias = nn.Parameter(torch.zeros(features)) self.epsilon = epsilon
[docs] def forward(self, x, dim=-1): mu = x.mean(dim, keepdim=True) sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon) gain = self.gain bias = self.bias # Reshape if dim != -1: shape = [1] * len(mu.size()) shape[dim] = self.gain.size()[0] gain = gain.view(shape) bias = bias.view(shape) return gain * (x - mu) / (sigma + self.epsilon) + bias
[docs]class NodeTransformerLayer(nn.Module):
[docs] def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.1): super().__init__() self.num_heads = num_heads self.num_hidden = num_hidden self.num_in = num_in self.dropout = nn.Dropout(dropout) self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)]) self.attention = NeighborAttention(num_hidden, num_in, num_heads) self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_V, h_E, mask_V=None, mask_attend=None): """ Parallel computation of full transformer layer """ # Self-attention dh = self.attention(h_V, h_E, mask_attend) h_V = self.norm[0](h_V + self.dropout(dh)) # Position-wise feedforward dh = self.dense(h_V) h_V = self.norm[1](h_V + self.dropout(dh)) if mask_V is not None: mask_V = mask_V.unsqueeze(-1) h_V = mask_V * h_V return h_V
[docs] def step(self, t, h_V, h_E, mask_V=None, mask_attend=None): """ Sequential computation of step t of a transformer layer """ # Self-attention h_V_t = h_V[:, t, :] dh_t = self.attention.step(t, h_V, h_E, mask_attend) h_V_t = self.norm[0](h_V_t + self.dropout(dh_t)) # Position-wise feedforward dh_t = self.dense(h_V_t) h_V_t = self.norm[1](h_V_t + self.dropout(dh_t)) if mask_V is not None: mask_V_t = mask_V[:, t].unsqueeze(-1) h_V_t = mask_V_t * h_V_t return h_V_t
[docs]class EdgeTransformerLayer(nn.Module):
[docs] def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.1): super().__init__() self.num_heads = num_heads self.num_hidden = num_hidden self.num_in = num_in self.dropout = nn.Dropout(dropout) self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)]) self.attention = EdgeEndpointAttention(num_hidden, num_in, num_heads) self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_E, h_EV, E_idx, mask_E=None, mask_attend=None): """ Parallel computation of full transformer layer """ # Self-attention dh = self.attention(h_E, h_EV, E_idx, mask_attend) h_E = self.norm[0](h_E + self.dropout(dh)) # Position-wise feedforward dh = self.dense(h_E) h_E = self.norm[1](h_E + self.dropout(dh)) if mask_E is not None: mask_E = mask_E.unsqueeze(-1).unsqueeze(-1) h_E = mask_E * h_E return h_E
[docs]class NodeMPNNLayer(nn.Module):
[docs] def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30): super().__init__() del num_heads self.num_hidden = num_hidden self.num_in = num_in self.scale = scale self.dropout = nn.Dropout(dropout) self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)]) self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True) self.W2 = nn.Linear(num_hidden, num_hidden, bias=True) self.W3 = nn.Linear(num_hidden, num_hidden, bias=True) self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_V, h_E, mask_V=None, mask_attend=None): """ Parallel computation of full transformer layer """ # Concatenate h_V_i to h_E_ij h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1) h_EV = torch.cat([h_V_expand, h_E], -1) h_message = self.W3(F.relu(self.W2(F.relu(self.W1(h_EV))))) if mask_attend is not None: h_message = mask_attend.unsqueeze(-1) * h_message dh = torch.sum(h_message, -2) / self.scale h_V = self.norm[0](h_V + self.dropout(dh)) # Position-wise feedforward dh = self.dense(h_V) h_V = self.norm[1](h_V + self.dropout(dh)) if mask_V is not None: mask_V = mask_V.unsqueeze(-1) h_V = mask_V * h_V return h_V
[docs]class EdgeMPNNLayer(nn.Module):
[docs] def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30): super().__init__() del num_heads self.num_hidden = num_hidden self.num_in = num_in self.scale = scale self.dropout = nn.Dropout(dropout) self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)]) self.W1 = nn.Linear(num_in, num_hidden, bias=True) self.W2 = nn.Linear(num_hidden, num_hidden, bias=True) self.W3 = nn.Linear(num_hidden, num_hidden, bias=True) self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_E, h_EV, E_idx, mask_E=None, mask_attend=None): """ Parallel computation of full transformer layer """ dh = self.W3(F.relu(self.W2(F.relu(self.W1(h_EV))))) dh = merge_duplicate_edges(dh, E_idx) # does this help? if mask_attend is not None: dh = mask_attend.unsqueeze(-1) * dh h_E = self.norm[0](h_E + self.dropout(dh)) # Position-wise feedforward dh = self.dense(h_E) h_E = self.norm[1](h_E + self.dropout(dh)) if mask_E is not None: mask_E = mask_E.unsqueeze(-1).unsqueeze(-1) h_E = mask_E * h_E return h_E
[docs]class NeighborAttention(nn.Module):
[docs] def __init__(self, num_hidden, num_in, num_heads=4): super().__init__() self.num_heads = num_heads self.num_hidden = num_hidden # Self-attention layers: {queries, keys, values, output} self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False) self.W_K = nn.Linear(num_in, num_hidden, bias=False) self.W_V = nn.Linear(num_in, num_hidden, bias=False) self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)
[docs] def _masked_softmax(self, attend_logits, mask_attend, dim=-1): """ Numerically stable masked softmax """ negative_inf = np.finfo(np.float32).min mask_attn_dev = mask_attend.device attend_logits = torch.where(mask_attend > 0, attend_logits, torch.tensor(negative_inf).to(mask_attn_dev)) attend = F.softmax(attend_logits, dim) attend = mask_attend * attend return attend
[docs] def forward(self, h_V, h_E, mask_attend=None): """ Self-attention, graph-structured O(Nk) Args: h_V: Node features [N_batch, N_nodes, N_hidden] h_E: Neighbor features [N_batch, N_nodes, K, N_hidden] mask_attend: Mask for attention [N_batch, N_nodes, K] Returns: h_V: Node update """ # Queries, Keys, Values n_batch, n_nodes, n_neighbors = h_E.shape[:3] n_heads = self.num_heads d = int(self.num_hidden / n_heads) Q = self.W_Q(h_V).view([n_batch, n_nodes, 1, n_heads, 1, d]) K = self.W_K(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d, 1]) V = self.W_V(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d]) # Attention with scaled inner product attend_logits = torch.matmul(Q, K).view([n_batch, n_nodes, n_neighbors, n_heads]).transpose(-2, -1) attend_logits = attend_logits / np.sqrt(d) if mask_attend is not None: # Masked softmax mask = mask_attend.unsqueeze(2).expand(-1, -1, n_heads, -1) attend = self._masked_softmax(attend_logits, mask) else: attend = F.softmax(attend_logits, -1) # Attentive reduction h_V_update = torch.matmul(attend.unsqueeze(-2), V.transpose(2, 3)) h_V_update = h_V_update.view([n_batch, n_nodes, self.num_hidden]) h_V_update = self.W_O(h_V_update) return h_V_update
[docs] def step(self, t, h_V, h_E, E_idx, mask_attend=None): """ Self-attention for a specific time step t Args: h_V: Node features [N_batch, N_nodes, N_hidden] h_E: Neighbor features [N_batch, N_nodes, K, N_in] E_idx: Neighbor indices [N_batch, N_nodes, K] mask_attend: Mask for attention [N_batch, N_nodes, K] Returns: h_V_t: Node update """ # Dimensions n_batch, _, n_neighbors = h_E.shape[:3] n_heads = self.num_heads d = self.num_hidden / n_heads # Per time-step tensors h_V_t = h_V[:, t, :] h_E_t = h_E[:, t, :, :] E_idx_t = E_idx[:, t, :] # Single time-step h_V_neighbors_t = gather_nodes_t(h_V, E_idx_t) E_t = torch.cat([h_E_t, h_V_neighbors_t], -1) # Queries, Keys, Values Q = self.W_Q(h_V_t).view([n_batch, 1, n_heads, 1, d]) K = self.W_K(E_t).view([n_batch, n_neighbors, n_heads, d, 1]) V = self.W_V(E_t).view([n_batch, n_neighbors, n_heads, d]) # Attention with scaled inner product attend_logits = torch.matmul(Q, K).view([n_batch, n_neighbors, n_heads]).transpose(-2, -1) attend_logits = attend_logits / np.sqrt(d) if mask_attend is not None: # Masked softmax # [N_batch, K] -=> [N_batch, N_heads, K] mask_t = mask_attend[:, t, :].unsqueeze(1).expand(-1, n_heads, -1) attend = self._masked_softmax(attend_logits, mask_t) else: attend = F.softmax(attend_logits / np.sqrt(d), -1) # Attentive reduction h_V_t_update = torch.matmul(attend.unsqueeze(-2), V.transpose(1, 2)) return h_V_t_update
[docs]class EdgeEndpointAttention(nn.Module):
[docs] def __init__(self, num_hidden, num_in, num_heads=4): super().__init__() self.num_heads = num_heads self.num_hidden = num_hidden # Self-attention layers: {queries, keys, values, output} self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False) self.W_K = nn.Linear(num_in, num_hidden, bias=False) self.W_V = nn.Linear(num_in, num_hidden, bias=False) self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)
[docs] def _masked_softmax(self, attend_logits, mask_attend, dim=-1): """ Numerically stable masked softmax """ negative_inf = np.finfo(np.float32).min mask_attn_dev = mask_attend.device attend_logits = torch.where(mask_attend > 0, attend_logits, torch.tensor(negative_inf).to(mask_attn_dev)) attend = F.softmax(attend_logits, dim) attend = mask_attend.float() * attend return attend
[docs] def forward(self, h_E, h_EV, E_idx, mask_attend=None): """ Self-attention, graph-structured O(Nk) Args: h_E: Edge features [N_batch, N_nodes, K, N_hidden] h_EV: Edge + endpoint features [N_batch, N_nodes, K, N_hidden * 3] mask_attend: Mask for attention [N_batch, N_nodes, K] Returns: h_E_update Edge update """ # Queries, Keys, Values n_batch, n_nodes, k = h_E.shape[:-1] n_heads = self.num_heads assert self.num_hidden % n_heads == 0 d = self.num_hidden // n_heads Q = self.W_Q(h_E).view([n_batch, n_nodes, k, n_heads, d]).transpose(2, 3) K = self.W_K(h_EV).view([n_batch, n_nodes, k, n_heads, d]).transpose(2, 3) V = self.W_V(h_EV).view([n_batch, n_nodes, k, n_heads, d]).transpose(2, 3) # Attention with scaled inner product attend_logits = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d) if mask_attend is not None: # we need to reshape the src key mask for edge-edge attention # expand to num_heads mask = mask_attend.unsqueeze(2).expand(-1, -1, n_heads, -1).unsqueeze(-1).double() mask_t = mask.transpose(-2, -1) # perform outer product mask = mask @ mask_t mask = mask.bool() # Masked softmax attend = self._masked_softmax(attend_logits, mask) else: attend = F.softmax(attend_logits, -1) # Attentive reduction h_E_update = torch.matmul(attend, V).transpose(2, 3).contiguous() h_E_update = h_E_update.view([n_batch, n_nodes, k, self.num_hidden]) h_E_update = self.W_O(h_E_update) # nondirected edges are actually represented as two directed edges in opposite directions # to allow information flow, merge these duplicate edges h_E_update = merge_duplicate_edges(h_E_update, E_idx) return h_E_update