import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from .utils import merge_duplicate_edges
# pylint: disable=no-member
[docs]class PositionWiseFeedForward(nn.Module):
[docs] def __init__(self, num_hidden, num_ff):
super().__init__()
self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
self.W_out = nn.Linear(num_ff, num_hidden, bias=True)
[docs] def forward(self, h_V):
h = F.relu(self.W_in(h_V))
h = self.W_out(h)
return h
[docs]class Normalize(nn.Module):
[docs] def __init__(self, features, epsilon=1e-6):
super().__init__()
self.gain = nn.Parameter(torch.ones(features))
self.bias = nn.Parameter(torch.zeros(features))
self.epsilon = epsilon
[docs] def forward(self, x, dim=-1):
mu = x.mean(dim, keepdim=True)
sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
gain = self.gain
bias = self.bias
# Reshape
if dim != -1:
shape = [1] * len(mu.size())
shape[dim] = self.gain.size()[0]
gain = gain.view(shape)
bias = bias.view(shape)
return gain * (x - mu) / (sigma + self.epsilon) + bias
[docs]class NodeMPNNLayer(nn.Module):
[docs] def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
super().__init__()
del num_heads
self.num_hidden = num_hidden
self.num_in = num_in
self.scale = scale
self.dropout = nn.Dropout(dropout)
self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])
self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
""" Parallel computation of full transformer layer """
# Concatenate h_V_i to h_E_ij
h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1)
h_EV = torch.cat([h_V_expand, h_E], -1)
h_message = self.W3(F.relu(self.W2(F.relu(self.W1(h_EV)))))
if mask_attend is not None:
h_message = mask_attend.unsqueeze(-1) * h_message
dh = torch.sum(h_message, -2) / self.scale
h_V = self.norm[0](h_V + self.dropout(dh))
# Position-wise feedforward
dh = self.dense(h_V)
h_V = self.norm[1](h_V + self.dropout(dh))
if mask_V is not None:
mask_V = mask_V.unsqueeze(-1)
h_V = mask_V * h_V
return h_V
[docs]class EdgeMPNNLayer(nn.Module):
[docs] def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
super().__init__()
del num_heads
self.num_hidden = num_hidden
self.num_in = num_in
self.scale = scale
self.dropout = nn.Dropout(dropout)
self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])
self.W1 = nn.Linear(num_in, num_hidden, bias=True)
self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_E, h_EV, E_idx, mask_E=None, mask_attend=None):
""" Parallel computation of full transformer layer """
dh = self.W3(F.relu(self.W2(F.relu(self.W1(h_EV)))))
dh = merge_duplicate_edges(dh, E_idx) # does this help?
if mask_attend is not None:
dh = mask_attend.unsqueeze(-1) * dh
h_E = self.norm[0](h_E + self.dropout(dh))
# Position-wise feedforward
dh = self.dense(h_E)
h_E = self.norm[1](h_E + self.dropout(dh))
if mask_E is not None:
mask_E = mask_E.unsqueeze(-1).unsqueeze(-1)
h_E = mask_E * h_E
return h_E
[docs]class NeighborAttention(nn.Module):
[docs] def __init__(self, num_hidden, num_in, num_heads=4):
super().__init__()
self.num_heads = num_heads
self.num_hidden = num_hidden
# Self-attention layers: {queries, keys, values, output}
self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
self.W_K = nn.Linear(num_in, num_hidden, bias=False)
self.W_V = nn.Linear(num_in, num_hidden, bias=False)
self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)
[docs] def _masked_softmax(self, attend_logits, mask_attend, dim=-1):
""" Numerically stable masked softmax """
negative_inf = np.finfo(np.float32).min
mask_attn_dev = mask_attend.device
attend_logits = torch.where(mask_attend > 0, attend_logits, torch.tensor(negative_inf).to(mask_attn_dev))
attend = F.softmax(attend_logits, dim)
attend = mask_attend * attend
return attend
[docs] def forward(self, h_V, h_E, mask_attend=None):
""" Self-attention, graph-structured O(Nk)
Args:
h_V: Node features [N_batch, N_nodes, N_hidden]
h_E: Neighbor features [N_batch, N_nodes, K, N_hidden]
mask_attend: Mask for attention [N_batch, N_nodes, K]
Returns:
h_V: Node update
"""
# Queries, Keys, Values
n_batch, n_nodes, n_neighbors = h_E.shape[:3]
n_heads = self.num_heads
d = int(self.num_hidden / n_heads)
Q = self.W_Q(h_V).view([n_batch, n_nodes, 1, n_heads, 1, d])
K = self.W_K(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d, 1])
V = self.W_V(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d])
# Attention with scaled inner product
attend_logits = torch.matmul(Q, K).view([n_batch, n_nodes, n_neighbors, n_heads]).transpose(-2, -1)
attend_logits = attend_logits / np.sqrt(d)
if mask_attend is not None:
# Masked softmax
mask = mask_attend.unsqueeze(2).expand(-1, -1, n_heads, -1)
attend = self._masked_softmax(attend_logits, mask)
else:
attend = F.softmax(attend_logits, -1)
# Attentive reduction
h_V_update = torch.matmul(attend.unsqueeze(-2), V.transpose(2, 3))
h_V_update = h_V_update.view([n_batch, n_nodes, self.num_hidden])
h_V_update = self.W_O(h_V_update)
return h_V_update
[docs] def step(self, t, h_V, h_E, E_idx, mask_attend=None):
""" Self-attention for a specific time step t
Args:
h_V: Node features [N_batch, N_nodes, N_hidden]
h_E: Neighbor features [N_batch, N_nodes, K, N_in]
E_idx: Neighbor indices [N_batch, N_nodes, K]
mask_attend: Mask for attention [N_batch, N_nodes, K]
Returns:
h_V_t: Node update
"""
# Dimensions
n_batch, _, n_neighbors = h_E.shape[:3]
n_heads = self.num_heads
d = self.num_hidden / n_heads
# Per time-step tensors
h_V_t = h_V[:, t, :]
h_E_t = h_E[:, t, :, :]
E_idx_t = E_idx[:, t, :]
# Single time-step
h_V_neighbors_t = gather_nodes_t(h_V, E_idx_t)
E_t = torch.cat([h_E_t, h_V_neighbors_t], -1)
# Queries, Keys, Values
Q = self.W_Q(h_V_t).view([n_batch, 1, n_heads, 1, d])
K = self.W_K(E_t).view([n_batch, n_neighbors, n_heads, d, 1])
V = self.W_V(E_t).view([n_batch, n_neighbors, n_heads, d])
# Attention with scaled inner product
attend_logits = torch.matmul(Q, K).view([n_batch, n_neighbors, n_heads]).transpose(-2, -1)
attend_logits = attend_logits / np.sqrt(d)
if mask_attend is not None:
# Masked softmax
# [N_batch, K] -=> [N_batch, N_heads, K]
mask_t = mask_attend[:, t, :].unsqueeze(1).expand(-1, n_heads, -1)
attend = self._masked_softmax(attend_logits, mask_t)
else:
attend = F.softmax(attend_logits / np.sqrt(d), -1)
# Attentive reduction
h_V_t_update = torch.matmul(attend.unsqueeze(-2), V.transpose(1, 2))
return h_V_t_update
[docs]class EdgeEndpointAttention(nn.Module):
[docs] def __init__(self, num_hidden, num_in, num_heads=4):
super().__init__()
self.num_heads = num_heads
self.num_hidden = num_hidden
# Self-attention layers: {queries, keys, values, output}
self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
self.W_K = nn.Linear(num_in, num_hidden, bias=False)
self.W_V = nn.Linear(num_in, num_hidden, bias=False)
self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)
[docs] def _masked_softmax(self, attend_logits, mask_attend, dim=-1):
""" Numerically stable masked softmax """
negative_inf = np.finfo(np.float32).min
mask_attn_dev = mask_attend.device
attend_logits = torch.where(mask_attend > 0, attend_logits, torch.tensor(negative_inf).to(mask_attn_dev))
attend = F.softmax(attend_logits, dim)
attend = mask_attend.float() * attend
return attend
[docs] def forward(self, h_E, h_EV, E_idx, mask_attend=None):
""" Self-attention, graph-structured O(Nk)
Args:
h_E: Edge features [N_batch, N_nodes, K, N_hidden]
h_EV: Edge + endpoint features [N_batch, N_nodes, K, N_hidden * 3]
mask_attend: Mask for attention [N_batch, N_nodes, K]
Returns:
h_E_update Edge update
"""
# Queries, Keys, Values
n_batch, n_nodes, k = h_E.shape[:-1]
n_heads = self.num_heads
assert self.num_hidden % n_heads == 0
d = self.num_hidden // n_heads
Q = self.W_Q(h_E).view([n_batch, n_nodes, k, n_heads, d]).transpose(2, 3)
K = self.W_K(h_EV).view([n_batch, n_nodes, k, n_heads, d]).transpose(2, 3)
V = self.W_V(h_EV).view([n_batch, n_nodes, k, n_heads, d]).transpose(2, 3)
# Attention with scaled inner product
attend_logits = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d)
if mask_attend is not None:
# we need to reshape the src key mask for edge-edge attention
# expand to num_heads
mask = mask_attend.unsqueeze(2).expand(-1, -1, n_heads, -1).unsqueeze(-1).double()
mask_t = mask.transpose(-2, -1)
# perform outer product
mask = mask @ mask_t
mask = mask.bool()
# Masked softmax
attend = self._masked_softmax(attend_logits, mask)
else:
attend = F.softmax(attend_logits, -1)
# Attentive reduction
h_E_update = torch.matmul(attend, V).transpose(2, 3).contiguous()
h_E_update = h_E_update.view([n_batch, n_nodes, k, self.num_hidden])
h_E_update = self.W_O(h_E_update)
# nondirected edges are actually represented as two directed edges in opposite directions
# to allow information flow, merge these duplicate edges
h_E_update = merge_duplicate_edges(h_E_update, E_idx)
return h_E_update