Source code for terminator.models.layers.term.graph.s2s

""" TERM MPNN modules

This file contains Attention and Message Passing implementations
of the TERM MPNN. """

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from terminator.models.layers.s2s_modules import (Normalize, PositionWiseFeedForward)
from terminator.models.layers.utils import (gather_term_nodes, cat_term_neighbors_nodes, cat_term_edge_endpoints,
                                            merge_duplicate_term_edges)

# pylint: disable=no-member


[docs]class TERMNeighborAttention(nn.Module): """ TERM Neighbor Attention A module which computes a node update using self-attention over all neighboring TERM residues and the edges connecting them. Attributes ---------- W_Q : nn.Linear Projection matrix for querys W_K : nn.Linear Projection matrix for keys W_V : nn.Linear Projection matrix for values W_O : nn.Linear Output layer """
[docs] def __init__(self, num_hidden, num_in, num_heads=4): """ Args ---- num_hidden : int Hidden dimension, and dimensionality of querys num_in : int Dimensionality of keys and values num_heads : int, default=4 Number of heads to use in Attention """ super().__init__() self.num_heads = num_heads self.num_hidden = num_hidden self.num_in = num_in # Self-attention layers: {queries, keys, values, output} self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False) self.W_K = nn.Linear(num_in, num_hidden, bias=False) self.W_V = nn.Linear(num_in, num_hidden, bias=False) self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)
[docs] def _masked_softmax(self, attend_logits, mask_attend, dim=-1): """ Numerically stable masked softmax Args ---- attend_logits : torch.Tensor Attention logits mask_attend: torch.ByteTensor Mask on Attention logits dim : int, default=-1 Dimension to perform softmax along Returns ------- attend : torch.Tensor Softmaxed :code:`attend_logits` """ negative_inf = np.finfo(np.float32).min mask_attn_dev = mask_attend.device attend_logits = torch.where(mask_attend > 0, attend_logits, torch.tensor(negative_inf).to(mask_attn_dev)) attend = F.softmax(attend_logits, dim) attend = mask_attend.float() * attend return attend
[docs] def forward(self, h_V, h_EV, mask_attend=None): """ Self-attention update over nodes of a TERM graph Args ---- h_V: torch.Tensor Central node features Shape: n_batch x n_terms x n_nodes x n_hidden h_EV: torch.Tensor Neighbor features, which includes the node vector concatenated onto the edge connecting the central node to the neighbor node Shape: n_batch x n_terms x n_nodes x n_neighbors x n_in mask_attend: torch.ByteTensor or None Mask for attention regarding neighbors Shape: n_batch x n_terms x n_nodes x k Returns ------- h_V_update: torch.Tensor Node embedding update Shape: n_batch x n_terms x n_nodes x n_hidden """ # Queries, Keys, Values n_batch, n_terms, n_nodes, n_neighbors = h_EV.shape[:4] n_heads = self.num_heads d = int(self.num_hidden / n_heads) Q = self.W_Q(h_V).view([n_batch, n_terms, n_nodes, 1, n_heads, 1, d]) K = self.W_K(h_EV).view([n_batch, n_terms, n_nodes, n_neighbors, n_heads, d, 1]) V = self.W_V(h_EV).view([n_batch, n_terms, n_nodes, n_neighbors, n_heads, d]) # Attention with scaled inner product attend_logits = torch.matmul(Q, K).view([n_batch, n_terms, n_nodes, n_neighbors, n_heads]).transpose(-2, -1) attend_logits = attend_logits / np.sqrt(d) if mask_attend is not None: mask = mask_attend.unsqueeze(3).expand(-1, -1, -1, n_heads, -1) attend = self._masked_softmax(attend_logits, mask) else: attend = F.softmax(attend_logits, -1) # Attentive reduction h_V_update = torch.matmul(attend.unsqueeze(-2), V.transpose(3, 4)) h_V_update = h_V_update.view([n_batch, n_terms, n_nodes, self.num_hidden]) h_V_update = self.W_O(h_V_update) return h_V_update
[docs]class TERMNodeTransformerLayer(nn.Module): """ TERM Node Transformer Layer A TERM Node Transformer Layer that updates nodes via TERMNeighborAttention Attributes ---------- attention: TERMNeighborAttention Transformer Attention mechanism dense: PositionWiseFeedForward Transformer position-wise FFN """
[docs] def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.1): """ Args ---- num_hidden : int Hidden dimension, and dimensionality of querys in TERMNeighborAttention num_in : int Dimensionality of keys and values num_heads : int, default=4 Number of heads to use in TERMNeighborAttention dropout : float, default=0.1 Dropout rate """ super().__init__() self.num_heads = num_heads self.num_hidden = num_hidden self.dropout = nn.Dropout(dropout) self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)]) self.attention = TERMNeighborAttention(num_hidden, num_in, num_heads) self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_V, h_VE, mask_V=None, mask_attend=None): """ Apply one Transformer update on nodes in a TERM graph Args ---- h_V: torch.Tensor Central node features Shape: n_batch x n_terms x n_nodes x n_hidden h_VE: torch.Tensor Neighbor features, which includes the node vector concatenated onto the edge connecting the central node to the neighbor node Shape: n_batch x n_terms x n_nodes x n_neighbors x n_in mask_V : torch.ByteTensor or None Mask for attention regarding TERM residues Shape : n_batch x n_terms x n_nodes mask_attend: torch.ByteTensor or None Mask for attention regarding neighbors Shape: n_batch x n_terms x n_nodes x k Returns ------- h_V: torch.Tensor Updated node embeddings Shape: n_batch x n_terms x n_nodes x n_hidden """ # Self-attention dh = self.attention(h_V, h_VE, mask_attend) h_V = self.norm[0](h_V + self.dropout(dh)) # Position-wise feedforward dh = self.dense(h_V) h_V = self.norm[1](h_V + self.dropout(dh)) # Apply node mask if mask_V is not None: mask_V = mask_V.unsqueeze(-1) h_V = mask_V * h_V return h_V
[docs]class TERMEdgeEndpointAttention(nn.Module): """ TERM Edge Endpoint Attention A module which computes an edge update using self-attention over all edges that it share a 'home residue' with, as well as the nodes that form those edges. Attributes ---------- W_Q : nn.Linear Projection matrix for querys W_K : nn.Linear Projection matrix for keys W_V : nn.Linear Projection matrix for values W_O : nn.Linear Output layer """
[docs] def __init__(self, num_hidden, num_in, num_heads=4): """ Args ---- num_hidden : int Hidden dimension, and dimensionality of querys num_in : int Dimensionality of keys and values num_heads : int, default=4 Number of heads to use in Attention """ super().__init__() self.num_heads = num_heads self.num_hidden = num_hidden # Self-attention layers: {queries, keys, values, output} self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False) self.W_K = nn.Linear(num_in, num_hidden, bias=False) self.W_V = nn.Linear(num_in, num_hidden, bias=False) self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)
[docs] def _masked_softmax(self, attend_logits, mask_attend, dim=-1): """ Numerically stable masked softmax Args ---- attend_logits : torch.Tensor Attention logits mask_attend: torch.ByteTensor Mask on Attention logits dim : int, default=-1 Dimension to perform softmax along Returns ------- attend : torch.Tensor Softmaxed :code:`attend_logits` """ negative_inf = np.finfo(np.float32).min mask_attn_dev = mask_attend.device attend_logits = torch.where(mask_attend > 0, attend_logits, torch.tensor(negative_inf).to(mask_attn_dev)) attend = F.softmax(attend_logits, dim) attend = mask_attend.float() * attend return attend
[docs] def forward(self, h_E, h_EV, E_idx, mask_attend=None): """ Self-attention update over edges in a TERM graph Args ---- h_E: torch.Tensor Edge features in kNN dense form Shape: n_batch x n_terms x n_nodes x k x n_hidden h_EV: torch.Tensor 'Neighbor' edge features, or all edges which share a 'central residue' with that edge, as well as the node features for both nodes that compose that edge. Shape: n_batch x n_terms x n_nodes x k x n_in mask_attend: torch.ByteTensor or None Mask for attention regarding neighbors Shape: n_batch x n_terms x n_nodes x k Returns ------- h_E_update: torch.Tensor Update for edge embeddings Shape: n_batch x n_terms x n_nodes x k x n_hidden """ # Queries, Keys, Values n_batch, n_terms, n_aa, n_neighbors = h_E.shape[:-1] n_heads = self.num_heads assert self.num_hidden % n_heads == 0 d = self.num_hidden // n_heads Q = self.W_Q(h_E).view([n_batch, n_terms, n_aa, n_neighbors, n_heads, d]).transpose(3, 4) K = self.W_K(h_EV).view([n_batch, n_terms, n_aa, n_neighbors, n_heads, d]).transpose(3, 4) V = self.W_V(h_EV).view([n_batch, n_terms, n_aa, n_neighbors, n_heads, d]).transpose(3, 4) # Attention with scaled inner product attend_logits = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d) if mask_attend is not None: # we need to reshape the src key mask for edge-edge attention # expand to num_heads mask = mask_attend.unsqueeze(3).expand(-1, -1, -1, n_heads, -1).unsqueeze(-1).double() mask_t = mask.transpose(-2, -1) # perform outer product mask = mask @ mask_t mask = mask.bool() # Masked softmax attend = self._masked_softmax(attend_logits, mask) else: attend = F.softmax(attend_logits, -1) # Attentive reduction h_E_update = torch.matmul(attend, V).transpose(3, 4).contiguous() h_E_update = h_E_update.view([n_batch, n_terms, n_aa, n_neighbors, self.num_hidden]) h_E_update = self.W_O(h_E_update) # nondirected edges are actually represented as two directed edges in opposite directions # to allow information flow, merge these duplicate edges h_E_update = merge_duplicate_term_edges(h_E_update, E_idx) return h_E_update
[docs]class TERMEdgeTransformerLayer(nn.Module): """ TERM Edge Transformer Layer A TERM Edge Transformer Layer that updates edges via TERMEdgeEndpointAttention Attributes ---------- attention: TERMEdgeEndpointAttention Transformer Attention mechanism dense: PositionWiseFeedForward Transformer position-wise FFN """
[docs] def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.1): """ Args ---- num_hidden : int Hidden dimension, and dimensionality of querys in TERMNeighborAttention num_in : int Dimensionality of keys and values num_heads : int, default=4 Number of heads to use in TERMNeighborAttention dropout : float, default=0.1 Dropout rate """ super().__init__() self.num_heads = num_heads self.num_hidden = num_hidden self.num_in = num_in self.dropout = nn.Dropout(dropout) self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)]) self.attention = TERMEdgeEndpointAttention(num_hidden, num_in, num_heads) self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_E, h_EV, E_idx, mask_E=None, mask_attend=None): """ Apply one Transformer update on edges in a TERM graph Args ---- h_E: torch.Tensor Edge features in kNN dense form Shape: n_batch x n_terms x n_nodes x k x n_hidden h_EV: torch.Tensor 'Neighbor' edge features, or all edges which share a 'central residue' with that edge, as well as the node features for both nodes that compose that edge. Shape: n_batch x n_terms x n_nodes x k x n_in mask_E : torch.ByteTensor or None Mask for attention regarding TERM edges Shape : n_batch x n_terms x n_nodes mask_attend: torch.ByteTensor or None Mask for attention regarding 'neighbor' edges Shape: n_batch x n_terms x n_nodes x k Returns ------- h_E: torch.Tensor Updated edge embeddings Shape: n_batch x n_terms x n_nodes x k x n_hidden """ # Self-attention dh = self.attention(h_E, h_EV, E_idx, mask_attend) h_E = self.norm[0](h_E + self.dropout(dh)) # Position-wise feedforward dh = self.dense(h_E) h_E = self.norm[1](h_E + self.dropout(dh)) if mask_E is not None: mask_E = mask_E.unsqueeze(-1) h_E = mask_E * h_E return h_E
[docs]class TERMNodeMPNNLayer(nn.Module): """ TERM Node MPNN Layer A TERM Node MPNN Layer that updates nodes via generating messages and feeding the update through a feedforward network Attributes ---------- W1, W2, W3: nn.Linear Layers for message computation dense: PositionWiseFeedForward Transformer position-wise FFN """ # pylint: disable=unused-argument # num_heads is not used, but exists for compatibility with options for the Attention equivalent
[docs] def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=None): """ Args ---- num_hidden : int Hidden dimension, and dimensionality of querys in TERMNeighborAttention num_in : int Dimensionality of keys and values num_heads : int, default=4 Number of heads to use in TERMNeighborAttention dropout : float, default=0.1 Dropout rate scale : int or None, default=None Scaling integer by which to divde the sum of computed messages. If None, the mean of the messages will be used instead. """ super().__init__() self.num_hidden = num_hidden self.num_in = num_in self.scale = scale self.dropout = nn.Dropout(dropout) self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)]) self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True) self.W2 = nn.Linear(num_hidden, num_hidden, bias=True) self.W3 = nn.Linear(num_hidden, num_hidden, bias=True) self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_V, h_E, mask_V=None, mask_attend=None): """ Apply one MPNN update on nodes in a TERM graph Args ---- h_V: torch.Tensor Central node features Shape: n_batch x n_terms x n_nodes x n_hidden h_VE: torch.Tensor Neighbor features, which includes the node vector concatenated onto the edge connecting the central node to the neighbor node Shape: n_batch x n_terms x n_nodes x n_neighbors x n_in mask_V : torch.ByteTensor or None Mask for message-passing regarding TERM residues Shape : n_batch x n_terms x n_nodes mask_attend: torch.ByteTensor or None Mask for message-passing regarding neighbors Shape: n_batch x n_terms x n_nodes x k Returns ------- h_V: torch.Tensor Updated node embeddings Shape: n_batch x n_terms x n_nodes x n_hidden """ # Concatenate h_V_i to h_E_ij h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, -1, h_E.size(-2), -1) h_EV = torch.cat([h_V_expand, h_E], -1) h_message = self.W3(F.relu(self.W2(F.relu(self.W1(h_EV))))) if mask_attend is not None: h_message = mask_attend.unsqueeze(-1) * h_message # note: this inherently decreases the magnitudes of messages for smaller TERMs # that wasn't intentional, but maybe that's a good thing if self.scale is None: dh = torch.mean(h_message, dim=-2) else: dh = torch.sum(h_message, dim=-2) / self.scale h_V = self.norm[0](h_V + self.dropout(dh)) # Position-wise feedforward dh = self.dense(h_V) h_V = self.norm[1](h_V + self.dropout(dh)) if mask_V is not None: mask_V = mask_V.unsqueeze(-1) h_V = mask_V * h_V return h_V
[docs]class TERMEdgeMPNNLayer(nn.Module): """ TERM Edge MPNN Layer A TERM Edge MPNN Layer that updates edges via generating messages and feeding the update through a feedforward network Attributes ---------- W1, W2, W3: nn.Linear Layers for message computation dense: PositionWiseFeedForward Transformer position-wise FFN """ # pylint: disable=unused-argument # num_heads is not used, but exists for compatibility with options for the Attention equivalent
[docs] def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None): """ Args ---- num_hidden : int Hidden dimension, and dimensionality of querys in TERMNeighborAttention num_in : int Dimensionality of keys and values num_heads : int, default=4 Number of heads to use in TERMNeighborAttention dropout : float, default=0.1 Dropout rate """ super().__init__() self.num_hidden = num_hidden self.num_in = num_in self.dropout = nn.Dropout(dropout) self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)]) self.W1 = nn.Linear(num_in, num_hidden, bias=True) self.W2 = nn.Linear(num_hidden, num_hidden, bias=True) self.W3 = nn.Linear(num_hidden, num_hidden, bias=True) self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
[docs] def forward(self, h_E, h_EV, E_idx, mask_E=None, mask_attend=None): """ Apply one MPNN update on edges in a TERM graph Args ---- h_E: torch.Tensor Edge features in kNN dense form Shape: n_batch x n_terms x n_nodes x k x n_hidden h_EV: torch.Tensor 'Neighbor' edge features, or all edges which share a 'central residue' with that edge, as well as the node features for both nodes that compose that edge. Shape: n_batch x n_terms x n_nodes x k x n_in mask_E : torch.ByteTensor or None Mask for message-passing regarding TERM edges Shape : n_batch x n_terms x n_nodes mask_attend: torch.ByteTensor or None Mask for message-passing regarding 'neighbor' edges Shape: n_batch x n_terms x n_nodes x k Returns ------- h_E: torch.Tensor Updated edge embeddings Shape: n_batch x n_terms x n_nodes x k x n_hidden """ dh = self.W3(F.relu(self.W2(F.relu(self.W1(h_EV))))) dh = merge_duplicate_term_edges(dh, E_idx) if mask_attend is not None: dh = mask_attend.unsqueeze(-1) * dh h_E = self.norm[0](h_E + self.dropout(dh)) # Position-wise feedforward dh = self.dense(h_E) h_E = self.norm[1](h_E + self.dropout(dh)) if mask_E is not None: mask_E = mask_E.unsqueeze(-1) h_E = mask_E * h_E return h_E
[docs]class TERMGraphTransformerEncoder(nn.Module): """ TERM Graph Transformer Encoder Alternating node and edge update layers to update the represenation of TERM graphs Attributes ---------- W_v : nn.Linear Embedding layer for nodes W_e : nn.Linear Embedding layer for edges node_encoder : nn.ModuleList of TERMNodeTransformerLayer or TERMNodeMPNNLayer Update layers for nodes edge_encoder : nn.ModuleList of TERMEdgeTransformerLayer or TERMEdgeMPNNLayer Update layers for edges W_out : nn.Linear Output layer """
[docs] def __init__(self, hparams): """ Args ---- hparams : dict Dictionary of model hparams (see :code:`~/scripts/models/train/default_hparams.json` for more info) """ super().__init__() self.hparams = hparams node_features = hparams['term_hidden_dim'] edge_features = hparams['term_hidden_dim'] hidden_dim = hparams['term_hidden_dim'] num_heads = hparams['term_heads'] dropout = hparams['transformer_dropout'] num_encoder_layers = hparams['term_layers'] # Hyperparameters self.node_features = node_features self.edge_features = edge_features self.input_dim = hidden_dim self.hidden_dim = hidden_dim self.output_dim = hidden_dim # Embedding layers self.W_v = nn.Linear(node_features, hidden_dim, bias=True) self.W_e = nn.Linear(edge_features, hidden_dim, bias=True) edge_layer = TERMEdgeTransformerLayer if not hparams['term_use_mpnn'] else TERMEdgeMPNNLayer node_layer = TERMNodeTransformerLayer if not hparams['term_use_mpnn'] else TERMNodeMPNNLayer # Encoder layers self.edge_encoder = nn.ModuleList([ edge_layer(hidden_dim, hidden_dim * 3 + (2 * hidden_dim if hparams['contact_idx'] else 0), dropout=dropout) for _ in range(num_encoder_layers) ]) self.node_encoder = nn.ModuleList([ node_layer(hidden_dim, num_in=hidden_dim * 2 + (2 * hidden_dim if hparams['contact_idx'] else 0), num_heads=num_heads, dropout=dropout) for _ in range(num_encoder_layers) ]) self.W_out = nn.Linear(hidden_dim, hidden_dim, bias=True) # Initialization for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p)
[docs] def forward(self, V, E, E_idx, mask, contact_idx=None): """ Refine TERM graph representations Args ---- V : torch.Tensor Node embeddings Shape: n_batches x n_terms x max_term_len x n_hidden E : torch.Tensor Edge embeddings in kNN dense form Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden E_idx : torch.LongTensor Edge indices Shape: n_batches x n_terms x max_term_len x max_term_len mask : torch.ByteTensor Mask for TERM resides Shape: n_batches x n_terms x max_term_len contact_idx : torch.Tensor Embedded contact indices Shape: n_batches x n_terms x max_term_len x n_hidden Returns ------- h_V : torch.Tensor TERM node embeddings h_E : torch.Tensor TERM edge embeddings """ h_V = self.W_v(V) h_E = self.W_e(E) # Encoder is unmasked self-attention mask_attend = gather_term_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) mask_attend = mask.unsqueeze(-1) * mask_attend for edge_layer, node_layer in zip(self.edge_encoder, self.node_encoder): h_EV_edges = cat_term_edge_endpoints(h_E, h_V, E_idx) if self.hparams['contact_idx']: h_EV_edges = cat_term_edge_endpoints(h_EV_edges, contact_idx, E_idx) h_E = edge_layer(h_E, h_EV_edges, E_idx, mask_E=mask_attend, mask_attend=mask_attend) if self.hparams['contact_idx']: h_EI = cat_term_edge_endpoints(h_E, contact_idx, E_idx) else: h_EI = h_E h_EV_nodes = cat_term_neighbors_nodes(h_V, h_EI, E_idx) h_V = node_layer(h_V, h_EV_nodes, mask_V=mask, mask_attend=mask_attend) h_E = self.W_out(h_E) h_E = merge_duplicate_term_edges(h_E, E_idx) return h_V, h_E