"""TERMinator models"""
import torch
import torch_geometric.data
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from .layers.condense import CondenseTERM
from .layers.energies.gvp import GVPPairEnergies
from .layers.energies.s2s import (AblatedPairEnergies, PairEnergies)
from .layers.utils import gather_edges, pad_sequence_12
# pylint: disable=no-member, not-callable
class TERMinator(nn.Module):
    """TERMinator model for multichain proteins

    Attributes
    ----------
    dev: str
        Device representing where the model is held
    hparams: dict
        Dictionary of parameter settings (see :code:`terminator/utils/model/default_hparams.py`)
    bot: CondenseTERM
        TERM information condenser network. Only created when
        :code:`hparams['use_terms']` is truthy.
    top: PairEnergies (or appropriate variant thereof)
        GNN Potts Model Encoder network
    """

    def __init__(self, hparams, device='cuda:0'):
        """
        Initializes TERMinator according to given parameters.

        Args
        ----
        hparams : dict
            Dictionary of parameter settings (see :code:`terminator/utils/model/default_hparams.py`).
            This dict is modified in place: :code:`'energies_input_dim'` is written into it
            before the energies network is constructed from the same dict.
        device : str
            Device to place model on
        """
        super().__init__()
        self.dev = device
        self.hparams = hparams
        # self.hparams is the same object as hparams, so this write is visible to the
        # energies constructors below (presumably they read 'energies_input_dim' - TODO confirm).
        if self.hparams["use_terms"]:
            self.hparams['energies_input_dim'] = self.hparams['term_hidden_dim']
            self.bot = CondenseTERM(hparams, device=self.dev)
        else:
            self.hparams['energies_input_dim'] = 0

        if hparams['struct2seq_linear']:
            self.top = AblatedPairEnergies(hparams).to(self.dev)
        elif hparams['energies_gvp']:
            self.top = GVPPairEnergies(hparams).to(self.dev)
        else:
            self.top = PairEnergies(hparams).to(self.dev)

        if self.hparams['use_terms']:
            print(f'TERM information condenser hidden dimensionality is {self.bot.hparams["term_hidden_dim"]}')
        print(f'GNN Potts Model Encoder hidden dimensionality is {self.top.hparams["energies_hidden_dim"]}')

        # Initialization: Xavier-uniform for every parameter with more than one dimension;
        # 1-D parameters (e.g. biases) keep the initialization their layers gave them.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _from_gvp_outputs(self, h_E, E_idx, seq_lens, max_seq_len):
        """ Convert outputs of GVP models to Ingraham style outputs

        Args
        ----
        h_E : torch.Tensor
            Outputted Potts Model in Jing format: one row per edge, with
            :code:`hparams['k_neighbors']` consecutive rows per residue and
            :code:`hparams['energies_output_dim']` values per row.
        E_idx : torch.Tensor
            Edge index matrix in Ingraham format (kNN sparse)
        seq_lens : sequence of int (np.ndarray, list, or 1-D tensor)
            Sequence lengths of proteins in batch
        max_seq_len : int
            Max sequence length of proteins in batch

        Returns
        -------
        etab : torch.Tensor
            Potts Model in Ingraham Format, zero-padded along the residue dimension
            (dim 1) up to :code:`max_seq_len`.
        E_idx : torch.LongTensor
            Edge index matrix in Ingraham format (kNN sparse), zero-padded along
            dim 1 by the same amount as :code:`etab`.
        """
        k = self.hparams['k_neighbors']
        # regroup the flat per-edge rows into (total_residues, k, energies_output_dim);
        # reshape (unlike view) also accepts non-contiguous inputs
        h_E = h_E.reshape(h_E.shape[0] // k, k, self.hparams['energies_output_dim'])
        # split the concatenated residues back into one chunk per protein
        split_h_E = torch.split(h_E, [int(length) for length in seq_lens])
        # assumes pad_sequence_12 stacks the chunks along dim 0 and pads dim 1 - TODO confirm
        etab = pad_sequence_12(split_h_E)

        # pad the difference if using DataParallel
        padding_diff = max_seq_len - etab.shape[1]
        if padding_diff > 0:
            # new_zeros keeps each padding block on the same device/dtype as the tensor it extends
            etab_padding = etab.new_zeros(etab.shape[0], padding_diff, *etab.shape[2:])
            etab = torch.cat([etab, etab_padding], dim=1)
            idx_padding = E_idx.new_zeros(E_idx.shape[0], padding_diff, *E_idx.shape[2:])
            E_idx = torch.cat([E_idx, idx_padding], dim=1)
        return etab, E_idx

    def forward(self, data, max_seq_len):
        """Compute the Potts model parameters for the structure

        Runs the full TERMinator network for prediction.

        Args
        ----
        data : dict
            Contains the following keys:

            msas : torch.LongTensor
                Integer encoding of sequence matches.
                Shape: n_batch x n_term_res x n_matches
            features : torch.FloatTensor
                Featurization of match structural data.
                Shape: n_batch x n_term_res x n_matches x n_features(=9 by default)
            seq_lens : int np.ndarray
                1D Array of batched sequence lengths.
                Shape: n_batch
            focuses : torch.LongTensor
                Indices for TERM residues matches.
                Shape: n_batch x n_term_res
            term_lens : int np.ndarray
                2D Array of batched TERM lengths.
                Shape: n_batch x n_terms
            src_key_mask : torch.ByteTensor
                Mask for TERM residue positions padding.
                Shape: n_batch x n_term_res
            X : torch.FloatTensor
                Raw coordinates of protein backbones.
                Shape: n_batch x n_res x 4 x 3
            x_mask : torch.ByteTensor
                Mask for X.
                Shape: n_batch x n_res
            sequence : torch.LongTensor
                Integer encoding of ground truth native sequences.
                Shape: n_batch x n_res
            ppoe : torch.FloatTensor
                Featurization of target protein structural data.
                Shape: n_batch x n_res x n_features(=9 by default)
            chain_idx : torch.LongTensor
                Integer indices that assign every residue to a chain.
                Shape: n_batch x n_res
            contact_idx : torch.LongTensor
                Integers representing contact indices across all TERM residues.
                Shape: n_batch x n_term_res
            gvp_data : list of torch_geometric.data.Data
                Vector and scalar featurizations of the backbone, as required by GVP
        max_seq_len : int
            Max length of protein in the batch.

        Returns
        -------
        etab : torch.FloatTensor
            Dense kNN representation of the energy table, with :code:`E_idx`
            denoting which energies correspond to which edge.
            Shape: n_batch x n_res x k(=30 by default) x :code:`hparams['energies_output_dim']` (=400 by default)
        E_idx : torch.LongTensor
            Indices representing edges in the kNN graph.
            Given node `res_idx`, the set of edges centered around that node are
            given by :code:`E_idx[b_idx][res_idx]`, with the `i`-th closest node given by
            :code:`E_idx[b_idx][res_idx][i]`.
            Shape: n_batch x n_res x k(=30 by default)

            When :code:`hparams['k_cutoff']` is set, k in both shapes is reduced to it.

        Raises
        ------
        ValueError
            If :code:`hparams['k_cutoff']` is set but not in the range (0, k).
        """
        if self.hparams['use_terms']:
            node_embeddings, edge_embeddings = self.bot(data, max_seq_len)
        else:
            node_embeddings, edge_embeddings = None, None

        if self.hparams['energies_gvp']:
            # NOTE(review): `_to_gvp_input` is not defined in the visible source of this
            # class; the GVP path needs it to be provided (restored here or by a subclass).
            h_V, edge_index, h_E, E_idx = self._to_gvp_input(node_embeddings, edge_embeddings, data)
            h_E, _ = self.top(h_V, edge_index, h_E)
            etab, E_idx = self._from_gvp_outputs(h_E, E_idx, data['seq_lens'], max_seq_len)
        else:
            etab, E_idx = self.top(node_embeddings, edge_embeddings, data['X'], data['x_mask'], data['chain_idx'])

        k_cutoff = self.hparams['k_cutoff']
        if k_cutoff:
            k = E_idx.shape[-1]
            # explicit check (not assert) so it still runs under `python -O`
            if not 0 < k_cutoff < k:
                raise ValueError(f"k_cutoff={k_cutoff} must be positive and less than k={k}")
            # neighbors are ordered closest-first, so a prefix slice keeps the k_cutoff closest
            etab = etab[..., :k_cutoff, :]
            E_idx = E_idx[..., :k_cutoff]
        return etab, E_idx