terminator.models.TERMinator.TERMinator

class terminator.models.TERMinator.TERMinator(hparams, device='cuda:0')

Bases: Module

TERMinator model for multichain proteins

Variables:
  • dev (str) – Device representing where the model is held

  • hparams (dict) – Dictionary of parameter settings (see terminator/utils/model/default_hparams.py)

  • bot (CondenseTERM) – TERM information condenser network

  • top (PairEnergies (or appropriate variant thereof)) – GNN Potts Model Encoder network

__init__(hparams, device='cuda:0')

Initializes TERMinator according to given parameters.

Parameters:
  • hparams (dict) – Dictionary of parameter settings (see terminator/utils/model/default_hparams.py)

  • device (str) – Device to place model on

Methods

__init__(hparams[, device]) – Initializes TERMinator according to given parameters.

forward(data, max_seq_len) – Compute the Potts model parameters for the structure

Attributes

T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches – This allows better backward-compatibility (BC) support for load_state_dict().

_from_gvp_outputs(h_E, E_idx, seq_lens, max_seq_len)

Convert outputs of GVP models to Ingraham-style outputs

Parameters:
  • h_E (torch.Tensor) – Output Potts Model in Jing format

  • E_idx (torch.Tensor) – Edge index matrix in Ingraham format (kNN sparse)

  • seq_lens (np.ndarray (int)) – Sequence lengths of proteins in the batch

  • max_seq_len (int) – Maximum sequence length of proteins in the batch

Returns:

  • etab (torch.Tensor) – Potts Model in Ingraham Format

  • E_idx (torch.LongTensor) – Edge index matrix in Ingraham format (kNN sparse)
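As a rough illustration of the Jing-to-Ingraham direction, the sketch below packs a flat per-edge output array into the padded n_batch x max_seq_len x k layout used by etab. It is not the `_from_gvp_outputs` implementation: the helper name `pack_flat_edges`, the use of NumPy in place of torch, and the assumption that each protein's edges are stored contiguously, node-major, with exactly k edges per node are all hypothetical.

```python
import numpy as np


def pack_flat_edges(h_E, seq_lens, max_seq_len, k):
    """Pack flat per-edge outputs into a zero-padded (n_batch, max_seq_len, k, out_dim) array.

    Hypothetical layout assumption: the edges of protein 0 come first, then protein 1,
    and so on; within a protein, node 0's k edges come first, then node 1's, etc.
    """
    out_dim = h_E.shape[-1]
    n_batch = len(seq_lens)
    etab = np.zeros((n_batch, max_seq_len, k, out_dim), dtype=h_E.dtype)
    offset = 0
    for b, L in enumerate(seq_lens):
        n_edges = L * k                                   # k outgoing edges per real residue
        etab[b, :L] = h_E[offset:offset + n_edges].reshape(L, k, out_dim)
        offset += n_edges                                 # padded residues stay zero
    return etab
```

Rows beyond each protein's length are left as zeros, which is one common way to pad; the actual padding value used by TERMinator is not stated here.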

_to_gvp_input(node_embeddings, edge_embeddings, data)

Convert Ingraham-style inputs to Jing-style inputs for use in GVP models

Parameters:
  • node_embeddings (torch.Tensor or None) – Node embeddings at the structure level, output by the TERM Info Condenser. None if running in TERMless mode. Shape: n_batch x max_seq_len x tic_n_hidden

  • edge_embeddings (torch.Tensor or None) – Edge embeddings at the structure level, output by the TERM Info Condenser. None if running in TERMless mode. Shape: n_batch x max_seq_len x max_seq_len x tic_n_hidden

  • data (dict of torch.Tensor) – Overall input data dictionary. See forward for more info.

Returns:

  • h_V (torch.Tensor) – Node embeddings in Jing format

  • edge_index (torch.LongTensor) – Edge index matrix in Jing format (sparse form)

  • h_E (torch.Tensor) – Edge embeddings in Jing format

  • E_idx (torch.LongTensor) – Edge index matrix in Ingraham format (kNN form)
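The relationship between the kNN-form E_idx and a sparse edge list can be sketched as follows. This is a minimal illustration, not the `_to_gvp_input` code: the helper name `knn_to_edge_index` is hypothetical, NumPy stands in for torch, and it assumes k does not exceed any protein's length and that row 0 holds the center node and row 1 its neighbor (the direction convention TERMinator actually uses is an assumption). Node indices are offset per protein so the batch forms one disjoint graph, as in a batched torch_geometric graph.

```python
import numpy as np


def knn_to_edge_index(E_idx, seq_lens):
    """Flatten a padded (n_batch, max_seq_len, k) kNN index array into a (2, n_edges) edge list.

    Row 0: center (source) node; row 1: neighbor node. Both use batch-global indices,
    so protein b's residues occupy the range [sum(seq_lens[:b]), sum(seq_lens[:b + 1])).
    """
    srcs, dsts = [], []
    offset = 0
    for b, L in enumerate(seq_lens):
        nbrs = E_idx[b, :L]                                # (L, k): skip padded residues
        k = nbrs.shape[1]
        srcs.append(np.repeat(np.arange(L), k) + offset)   # each center repeated k times
        dsts.append(nbrs.reshape(-1) + offset)             # its k neighbors, shifted
        offset += L
    return np.stack([np.concatenate(srcs), np.concatenate(dsts)])
```

For example, two proteins of lengths 3 and 2 with k = 2 yield 10 edges, and the second protein's residues are renumbered 3 and 4.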

forward(data, max_seq_len)

Compute the Potts model parameters for the structure

Runs the full TERMinator network for prediction.

Parameters:

  • data (dict) – Contains the following keys:

      • msas (torch.LongTensor) – Integer encoding of sequence matches. Shape: n_batch x n_term_res x n_matches

      • features (torch.FloatTensor) – Featurization of match structural data. Shape: n_batch x n_term_res x n_matches x n_features (=9 by default)

      • seq_lens (int np.ndarray) – 1D array of batched sequence lengths. Shape: n_batch

      • focuses (torch.LongTensor) – Indices for TERM residue matches. Shape: n_batch x n_term_res

      • term_lens (int np.ndarray) – 2D array of batched TERM lengths. Shape: n_batch x n_terms

      • src_key_mask (torch.ByteTensor) – Mask for TERM residue position padding. Shape: n_batch x n_term_res

      • X (torch.FloatTensor) – Raw coordinates of protein backbones. Shape: n_batch x n_res x 4 x 3

      • x_mask (torch.ByteTensor) – Mask for X. Shape: n_batch x n_res

      • sequence (torch.LongTensor) – Integer encoding of ground-truth native sequences. Shape: n_batch x n_res

      • max_seq_len (int) – Maximum protein length in the batch.

      • ppoe (torch.FloatTensor) – Featurization of target protein structural data. Shape: n_batch x n_res x n_features (=9 by default)

      • chain_idx (torch.LongTensor) – Integer indices that assign every residue to a chain. Shape: n_batch x n_res

      • contact_idx (torch.LongTensor) – Integers representing contact indices across all TERM residues. Shape: n_batch x n_term_res

      • gvp_data (list of torch_geometric.data.Data) – Vector and scalar featurizations of the backbone, as required by GVP
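Several entries of data are tied together by the length arrays: masks such as x_mask mark real versus padded residues, and term_lens partitions the n_term_res axis of msas and features into individual TERMs. The sketch below shows one way these relationships could be computed. The helper names `lengths_to_mask` and `split_terms` are hypothetical, NumPy stands in for torch, and the conventions assumed (mask value 1 for real residues, TERM residues concatenated TERM by TERM, non-positive padding entries in term_lens) are illustrative rather than confirmed by this page.

```python
import numpy as np


def lengths_to_mask(seq_lens, max_seq_len):
    """Return an (n_batch, max_seq_len) uint8 mask: 1 for real residues, 0 for padding."""
    positions = np.arange(max_seq_len)[None, :]       # shape (1, max_seq_len)
    lengths = np.asarray(seq_lens)[:, None]           # shape (n_batch, 1)
    return (positions < lengths).astype(np.uint8)     # broadcasts to (n_batch, max_seq_len)


def split_terms(term_res, term_lens_row):
    """Split one protein's TERM-residue axis (axis 0) into per-TERM chunks.

    Assumes TERM residues are concatenated TERM by TERM and that padding
    entries in term_lens_row are non-positive (hypothetical convention).
    """
    lens = [int(n) for n in term_lens_row if n > 0]
    boundaries = np.cumsum(lens)[:-1]                  # split points between consecutive TERMs
    return np.split(term_res[: sum(lens)], boundaries)
```

Because `split_terms` splits along axis 0, the same call works on a single protein's slice of msas (n_term_res x n_matches) or features.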

Returns:

  • etab (torch.FloatTensor) – Dense kNN representation of the energy table, with E_idx denoting which energies correspond to which edge. Shape: n_batch x n_res x k (=30 by default) x hparams['energies_output_dim'] (=400 by default)

  • E_idx (torch.LongTensor) – Indices representing edges in the kNN graph. Given node res_idx, the set of edges centered around that node are given by E_idx[b_idx][res_idx], with the i-th closest node given by E_idx[b_idx][res_idx][i]. Shape: n_batch x n_res x k (=30 by default)
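To read a specific pair energy out of etab, one first finds which of residue i's k slots holds neighbor j in E_idx, then takes that slot's energies. The sketch below assumes, without confirmation from this page, that the default 400-dimensional last axis is a 20 x 20 amino-acid table flattened in row-major order with rows indexed by residue i's amino acid. The helper name `pair_energy_matrix` is hypothetical and NumPy stands in for torch.

```python
import numpy as np


def pair_energy_matrix(etab, E_idx, b, i, j, n_aa=20):
    """Return the n_aa x n_aa energy table for the edge from residue i to residue j.

    etab:  (n_batch, n_res, k, n_aa * n_aa) energies
    E_idx: (n_batch, n_res, k) neighbor indices
    Raises KeyError if j is not among the k nearest neighbors of i.
    """
    slots = np.nonzero(E_idx[b, i] == j)[0]          # which of i's k slots points at j
    if slots.size == 0:
        raise KeyError(f"residue {j} is not among the k nearest neighbors of residue {i}")
    return etab[b, i, slots[0]].reshape(n_aa, n_aa)  # assumed row-major 20 x 20 layout
```

Since the kNN graph is not necessarily symmetric, j may appear among i's neighbors without i appearing among j's, so the lookup is done from the center residue's side.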