terminator.models.TERMinator.TERMinator
- class terminator.models.TERMinator.TERMinator(hparams, device='cuda:0')
Bases: Module
TERMinator model for multichain proteins
- Variables:
dev (str) – Device representing where the model is held
hparams (dict) – Dictionary of parameter settings (see terminator/utils/model/default_hparams.py)
bot (CondenseTERM) – TERM information condenser network
top (PairEnergies (or appropriate variant thereof)) – GNN Potts Model Encoder network
- __init__(hparams, device='cuda:0')
Initializes TERMinator according to given parameters.
- Parameters:
hparams (dict) – Dictionary of parameter settings (see terminator/utils/model/default_hparams.py)
device (str) – Device to place model on
Methods
__init__(hparams[, device]) – Initializes TERMinator according to given parameters.
forward(data, max_seq_len) – Compute the Potts model parameters for the structure
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- _from_gvp_outputs(h_E, E_idx, seq_lens, max_seq_len)
Convert outputs of GVP models to Ingraham style outputs
- Parameters:
h_E (torch.Tensor) – Outputted Potts Model in Jing format
E_idx (torch.Tensor) – Edge index matrix in Ingraham format (kNN sparse)
seq_lens (np.ndarray (int)) – Sequence lengths of proteins in batch
max_seq_len (int) – Max sequence length of proteins in batch
- Returns:
etab (torch.Tensor) – Potts Model in Ingraham Format
E_idx (torch.LongTensor) – Edge index matrix in Ingraham format (kNN sparse)
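As a rough illustration of what this kind of conversion involves, the sketch below regroups a flat list of per-edge feature vectors into a padded n_batch x max_seq_len x k x dim layout. It uses plain Python lists in place of tensors. The helper name pad_edges_to_dense and the assumption that every residue contributes exactly k edges stored contiguously are hypothetical, chosen for illustration only; the actual method may order or pad edges differently.

```python
def pad_edges_to_dense(flat_edges, seq_lens, max_seq_len, k, pad_value=0.0):
    """Regroup flat per-edge features into n_batch x max_seq_len x k x dim.

    Assumption (not from the TERMinator source): edges are stored protein by
    protein, residue by residue, with exactly k edges per residue.
    """
    dim = len(flat_edges[0])
    dense, cursor = [], 0
    for n_res in seq_lens:
        protein = []
        for _ in range(n_res):
            # The next k edge vectors belong to the current residue.
            protein.append(flat_edges[cursor:cursor + k])
            cursor += k
        # Pad proteins shorter than max_seq_len with all-pad rows.
        for _ in range(max_seq_len - n_res):
            protein.append([[pad_value] * dim for _ in range(k)])
        dense.append(protein)
    return dense


# Two proteins of length 2 and 1, k = 2 neighbors, 1 feature per edge.
flat_edges = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
dense = pad_edges_to_dense(flat_edges, seq_lens=[2, 1], max_seq_len=2, k=2)
```

The second protein has a single residue, so its second row in dense is filled with pad_value.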
- _to_gvp_input(node_embeddings, edge_embeddings, data)
Convert Ingraham-style inputs to Jing-style inputs for use in GVP models
- Parameters:
node_embeddings (torch.Tensor or None) – Node embeddings at the structure level, output by the TERM Info Condenser. None if running in TERMless mode. Shape: n_batch x max_seq_len x tic_n_hidden
edge_embeddings (torch.Tensor or None) – Edge embeddings at the structure level, output by the TERM Info Condenser. None if running in TERMless mode. Shape: n_batch x max_seq_len x max_seq_len x tic_n_hidden
data (dict of torch.Tensor) – Overall input data dictionary. See forward for more info.
- Returns:
h_V (torch.Tensor) – Node embeddings in Jing format
edge_idex (torch.LongTensor) – Edge index matrix in Jing format (sparse form)
h_E (torch.Tensor) – Edge embeddings in Jing format
E_idx (torch.LongTensor) – Edge index matrix in Ingraham format (kNN form)
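The difference between the kNN form and the sparse form of an edge index can be sketched with plain lists. The example below flattens an n_res x k neighbor table into a 2 x n_edges edge list for a single protein. The helper name knn_to_edge_index and the edge direction (neighbor as source, center residue as target) are assumptions made for illustration; batching offsets and the direction used by the actual method are not shown here.

```python
def knn_to_edge_index(E_idx_single):
    """Flatten an n_res x k neighbor table into a 2 x n_edges edge list.

    Assumption: each edge points from the neighbor (source) to the residue
    whose neighbor list it appears in (target).
    """
    sources, targets = [], []
    for center, neighbors in enumerate(E_idx_single):
        for nbr in neighbors:
            sources.append(nbr)
            targets.append(center)
    return [sources, targets]


# Three residues, each listing its k = 2 nearest neighbors.
E_idx_single = [[1, 2], [0, 2], [1, 0]]
edge_index = knn_to_edge_index(E_idx_single)
```

The kNN form always has n_res x k entries, while the sparse form stores one column per edge and so can represent graphs where residues have differing numbers of neighbors.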
- forward(data, max_seq_len)
Compute the Potts model parameters for the structure
Runs the full TERMinator network for prediction.
- Parameters:
data (dict) – Contains the following keys:
- msas (torch.LongTensor)
Integer encoding of sequence matches. Shape: n_batch x n_term_res x n_matches
- features (torch.FloatTensor)
Featurization of match structural data. Shape: n_batch x n_term_res x n_matches x n_features (=9 by default)
- seq_lens (int np.ndarray)
1D array of batched sequence lengths. Shape: n_batch
- focuses (torch.LongTensor)
Indices for TERM residue matches. Shape: n_batch x n_term_res
- term_lens (int np.ndarray)
2D array of batched TERM lengths. Shape: n_batch x n_terms
- src_key_mask (torch.ByteTensor)
Padding mask for TERM residue positions. Shape: n_batch x n_term_res
- X (torch.FloatTensor)
Raw coordinates of protein backbones. Shape: n_batch x n_res x 4 x 3
- x_mask (torch.ByteTensor)
Mask for X. Shape: n_batch x n_res
- sequence (torch.LongTensor)
Integer encoding of ground truth native sequences. Shape: n_batch x n_res
- max_seq_len (int)
Max length of protein in the batch.
- ppoe (torch.FloatTensor)
Featurization of target protein structural data. Shape: n_batch x n_res x n_features (=9 by default)
- chain_idx (torch.LongTensor)
Integer indices that assign every residue to a chain. Shape: n_batch x n_res
- contact_idx (torch.LongTensor)
Integers representing contact indices across all TERM residues. Shape: n_batch x n_term_res
- gvp_data (list of torch_geometric.data.Data)
Vector and scalar featurizations of the backbone, as required by GVP
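To make the batched shapes above more concrete, here is a minimal sketch of how per-residue bookkeeping entries such as x_mask and chain_idx relate to sequence and chain lengths. It uses plain lists instead of tensors. The helper names and the convention that 1 marks a real residue and 0 marks padding are assumptions for illustration, not taken from the TERMinator data pipeline.

```python
def padding_mask(seq_lens, max_seq_len):
    """Build an n_batch x max_seq_len mask (assumed: 1 = residue, 0 = padding)."""
    return [[1 if i < n else 0 for i in range(max_seq_len)] for n in seq_lens]


def chain_index(chain_lens):
    """Assign every residue of one protein the index of its chain."""
    idx = []
    for chain_id, length in enumerate(chain_lens):
        idx.extend([chain_id] * length)
    return idx


seq_lens = [3, 5, 2]
max_seq_len = max(seq_lens)
x_mask = padding_mask(seq_lens, max_seq_len)

# A hypothetical two-chain protein: chain 0 has 2 residues, chain 1 has 3.
chain_idx_single = chain_index([2, 3])
```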
- Returns:
etab (torch.FloatTensor) – Dense kNN representation of the energy table, with E_idx denoting which energies correspond to which edge. Shape: n_batch x n_res x k (=30 by default) x hparams['energies_output_dim'] (=400 by default)
E_idx (torch.LongTensor) – Indices representing edges in the kNN graph. Given node res_idx, the set of edges centered around that node is given by E_idx[b_idx][res_idx], with the i-th closest node given by E_idx[b_idx][res_idx][i]. Shape: n_batch x n_res x k (=30 by default)
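The E_idx indexing convention can be illustrated with a toy example in plain Python. The sketch below builds a kNN index for residues placed on a line and reads back the neighbors of one residue. The helpers knn_indices and flat_pair_index are hypothetical, and several choices are assumptions rather than documented behavior: neighbors are ranked by distance with ties broken by index, a residue counts as its own nearest neighbor, and the 400-long last axis of etab is treated as a row-major flattening of a 20 x 20 table.

```python
def knn_indices(batch_coords, k):
    """Return E_idx-like nested lists of shape n_batch x n_res x k."""
    E_idx = []
    for xs in batch_coords:
        per_res = []
        for xi in xs:
            # Rank all residues by distance to the current one (ties by index).
            order = sorted(range(len(xs)), key=lambda j: (abs(xs[j] - xi), j))
            per_res.append(order[:k])
        E_idx.append(per_res)
    return E_idx


def flat_pair_index(aa_i, aa_j, n_aa=20):
    # Assumption: the last axis of etab flattens an n_aa x n_aa table row-major.
    return aa_i * n_aa + aa_j


# One batch element with five residues at 1D positions; k = 3 for brevity.
coords = [[0.0, 1.0, 2.5, 4.0, 7.0]]
E_idx = knn_indices(coords, k=3)

b_idx, res_idx = 0, 2
neighbors = E_idx[b_idx][res_idx]           # all edges centered on residue 2
second_closest = E_idx[b_idx][res_idx][1]   # neighbor in slot 1
```

Under the same assumptions, the energy for a given amino acid pair on the edge in slot i would be read as etab[b_idx][res_idx][i][flat_pair_index(aa_i, aa_j)].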