Source code for terminator.models.layers.term.matches.attn
""" TERM Match Attention
This file includes modules which perform Attention to summarize the
information in TERM matches.
"""
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from terminator.models.layers.s2s_modules import (Normalize, PositionWiseFeedForward)
# pylint: disable=no-member
[docs]class TERMMatchAttention(nn.Module):
""" TERM Match Attention
A module with computes a node update using self-attention over
all neighboring TERM residues and the edges connecting them.
Attributes
----------
W_Q : nn.Linear
Projection matrix for querys
W_K : nn.Linear
Projection matrix for keys
W_V : nn.Linear
Projection matrix for values
W_O : nn.Linear
Output layer
"""
[docs] def __init__(self, hparams):
"""
Args
----
hparams : dict
Dictionary of model hparams (see :code:`~/scripts/models/train/default_hparams.json` for more info)
"""
super().__init__()
self.hparams = hparams
hdim = hparams['term_hidden_dim']
# Self-attention layers: {queries, keys, values, output}
self.W_Q = nn.Linear(hdim, hdim, bias=False)
self.W_K = nn.Linear(hdim * 2, hdim, bias=False)
self.W_V = nn.Linear(hdim * 2, hdim, bias=False)
self.W_O = nn.Linear(hdim, hdim, bias=False)
[docs] def _masked_softmax(self, attend_logits, mask_attend, dim=-1):
""" Numerically stable masked softmax """
negative_inf = np.finfo(np.float32).min
mask_attn_dev = mask_attend.device
attend_logits = torch.where(mask_attend > 0, attend_logits, torch.tensor(negative_inf).to(mask_attn_dev))
attend = F.softmax(attend_logits, dim)
attend = mask_attend.float() * attend
return attend
[docs] def forward(self, h_V, h_T, mask_attend=None):
""" Self-attention update over residues in TERM matches
Args
----
h_V : torch.Tensor
TERM match residues
Shape: n_batch x sum_term_len x n_matches x n_hidden
h_T : torch.Tensor
Embedded structural features of target residue
Shape: n_batch x sum_term_len x n_hidden
mask_attend : torch.ByteTensor or None
Mask for attention
Shape: n_batch x sum_term_len # TODO: check shape
Returns
-------
src_update : torch.Tensor
TERM matches embedding update
Shape: n_batch x sum_term_len x n_matches x n_hidden
"""
n_batches, sum_term_len, n_matches = h_V.shape[:3]
# append h_T onto h_V to form h_VT
h_T_expand = h_T.unsqueeze(-2).expand(h_V.shape)
h_VT = torch.cat([h_V, h_T_expand], dim=-1)
query = h_V
key = h_VT
value = h_VT
n_heads = self.hparams['matches_num_heads']
num_hidden = self.hparams['term_hidden_dim']
assert num_hidden % n_heads == 0
d = num_hidden // n_heads
Q = self.W_Q(query).view([n_batches, sum_term_len, n_matches, n_heads, d]).transpose(2, 3)
K = self.W_K(key).view([n_batches, sum_term_len, n_matches, n_heads, d]).transpose(2, 3)
V = self.W_V(value).view([n_batches, sum_term_len, n_matches, n_heads, d]).transpose(2, 3)
attend_logits = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d)
if mask_attend is not None:
# we need to reshape the src key mask for residue-residue attention
# expand to num_heads
mask = mask_attend.unsqueeze(2).expand(-1, -1, n_heads, -1).unsqueeze(-1).float()
mask_t = mask.transpose(-2, -1)
# perform outer product
mask = mask @ mask_t
mask = mask.bool()
# Masked softmax
attend = self._masked_softmax(attend_logits, mask)
else:
attend = F.softmax(attend_logits, -1)
src_update = torch.matmul(attend, V).transpose(2, 3).contiguous()
src_update = src_update.view([n_batches, sum_term_len, n_matches, num_hidden])
src_update = self.W_O(src_update)
return src_update
[docs]class TERMMatchTransformerLayer(nn.Module):
""" TERM Match Transformer Layer
A TERM Match Transformer Layer that updates match embeddings via TERMMatchATtention
Attributes
----------
attention: TERMMatchAttention
Transformer Attention mechanism
dense: PositionWiseFeedForward
Transformer position-wise FFN
"""
[docs] def __init__(self, hparams):
"""
Args
----
hparams : dict
Dictionary of model hparams (see :code:`~/scripts/models/train/default_hparams.json` for more info)
"""
super().__init__()
self.hparams = hparams
self.dropout = nn.Dropout(hparams['transformer_dropout'])
hdim = hparams['term_hidden_dim']
self.norm = nn.ModuleList([Normalize(hdim) for _ in range(2)])
self.attention = TERMMatchAttention(hparams=self.hparams)
self.dense = PositionWiseFeedForward(hdim, hdim * 4)
[docs] def forward(self, src, target, src_mask=None, mask_attend=None, checkpoint=False):
""" Apply one Transformer update to TERM matches
Args
----
src: torch.Tensor
TERM Match features
Shape: n_batch x sum_term_len x n_matches x n_hidden
target: torch.Tensor
Embedded structural features per TERM residue of target structure
Shape: n_batch x sum_term_len x n_matches x n_hidden
src_mask : torch.ByteTensor or None
Mask for attention regarding TERM residues
Shape : n_batch x sum_term_len
mask_attend: torch.ByteTensor or None
Mask for attention regarding matches
Shape: n_batch x sum_term_len # TODO: check shape
checkpoint : bool, default=False
Whether to use gradient checkpointing to reduce memory usage
Returns
-------
src: torch.Tensor
Updated match embeddings
Shape: n_batch x sum_term_len x n_matches x n_hidden
"""
# Self-attention
if checkpoint:
dsrc = torch.utils.checkpoint.checkpoint(self.attention, src, target, mask_attend)
else:
dsrc = self.attention(src, target, mask_attend=mask_attend)
src = self.norm[0](src + self.dropout(dsrc))
# Position-wise feedforward
if checkpoint:
dsrc = torch.utils.checkpoint.checkpoint(self.dense, src)
else:
dsrc = self.dense(src)
src = self.norm[1](src + self.dropout(dsrc))
if src_mask is not None:
src_mask = src_mask.unsqueeze(-1)
src = src_mask * src
return src
[docs]class TERMMatchTransformerEncoder(nn.Module):
""" TERM Match Transformer Encoder
A Transformer which uses a pool token to summarize the contents of TERM matches
Attributes
----------
W_v : nn.Linear
Embedding layer for matches
W_t : nn.Linear
Embedding layer for target structure information
W_pool: nn.Linear
Embedding layer for pool token
encoder_layers : nn.ModuleList of TERMMatchTransformerLayer
Transformer layers for matches
W_out : nn.Linear
Output layer
pool_token_init : nn.Parameter
The embedding for the pool token used to gather information,
reminiscent of [CLS] tokens in BERT
"""
[docs] def __init__(self, hparams):
"""
Args
----
hparams : dict
Dictionary of model hparams (see :code:`~/scripts/models/train/default_hparams.json` for more info)
"""
super().__init__()
self.hparams = hparams
# Hyperparameters
hidden_dim = hparams['term_hidden_dim']
self.hidden_dim = hidden_dim
num_encoder_layers = hparams['matches_layers']
# Embedding layers
self.W_v = nn.Linear(hidden_dim, hidden_dim, bias=True)
self.W_t = nn.Linear(hidden_dim, hidden_dim, bias=True)
self.W_pool = nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
layer = TERMMatchTransformerLayer
# Encoder layers
self.encoder_layers = nn.ModuleList([layer(hparams) for _ in range(num_encoder_layers)])
self.W_out = nn.Linear(hidden_dim, hidden_dim, bias=True)
# lets try [CLS]-style pooling
pool_token_init = torch.zeros(1, hidden_dim)
torch.nn.init.xavier_uniform_(pool_token_init)
self.pool_token = nn.Parameter(pool_token_init, requires_grad=True)
[docs] def forward(self, V, target, mask):
""" Summarize TERM matches
Args
----
V : torch.Tensor
TERM Match embedding
Shape: n_batches x sum_term_len x n_matches x n_hidden
target : torch.Tensor
Embedded structural information of target per TERM residue
Shape: n_batches x sum_term_len x n_hidden
mask : torch.ByteTensor
Mask for TERM resides
Shape: n_batches x sum_term_len
Returns
-------
torch.Tensor
Summarized TERM matches
Shape: n_batches x sum_term_len x n_hidden
"""
n_batches, sum_term_len = V.shape[:2]
# embed each copy of the pool token with some information about the target ppoe
pool = self.pool_token.view([1, 1, self.hidden_dim]).expand(n_batches, sum_term_len, -1)
pool = torch.cat([pool, target], dim=-1)
pool = self.W_pool(pool)
pool = pool.unsqueeze(-2)
V = torch.cat([pool, V], dim=-2)
h_V = self.W_v(V)
h_T = self.W_t(target)
# Encoder is unmasked self-attention
for _, layer in enumerate(self.encoder_layers):
h_V = layer(h_V, h_T, mask.unsqueeze(-1).float(), checkpoint=self.hparams['gradient_checkpointing'])
h_V = self.W_out(h_V)
return h_V[:, :, 0, :]