terminator.models.layers.condense.CondenseTERM
- class terminator.models.layers.condense.CondenseTERM(hparams, device='cuda:0')
Bases: Module
TERM Information Condenser
Condenses TERM matches and aggregates them into a full structure embedding
- Variables:
embedding (ResidueFeatures) – Feature embedding module for TERM match residues
edge_features (EdgeFeatures) – Feature embedding module for TERM match residue interactions
matches (Conv1DResNet, TERMMatchTransformerEncoder, or None) – Matches condenser (reduces the matches to a single embedding per TERM residue)
W_ppoe (nn.Linear) – Linear layer for target structural features (e.g. featurized torsion angles, RMSD, environment values)
term_mpnn (TERMGraphTransformerEncoder) – TERM MPNN (refine TERM graph embeddings)
cie (ContactIndexEncoding, present when hparams['contact_idx']=True) – Sinusoidal encoder for contact indices
W_v, W_e – Modules to linearize TERM MPNN
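To illustrate what a sinusoidal encoder for contact indices computes, here is a minimal plain-Python sketch. It follows the standard Transformer positional-encoding formula; the function name, `n_dims`, and `max_period` are hypothetical, and the exact layout used by ContactIndexEncoding is not documented on this page.

```python
import math


def sinusoidal_encoding(index, n_dims=8, max_period=10000.0):
    """Encode one integer index as interleaved sin/cos values.

    Hypothetical helper for illustration only; not the library's API.
    """
    if n_dims % 2 != 0:
        raise ValueError("n_dims must be even")
    encoding = []
    for k in range(n_dims // 2):
        # Frequencies decay geometrically as k grows.
        freq = max_period ** (-2.0 * k / n_dims)
        encoding.append(math.sin(index * freq))
        encoding.append(math.cos(index * freq))
    return encoding


# Each contact index maps to a fixed-length vector with entries in [-1, 1].
vector = sinusoidal_encoding(5)
```

Nearby indices get similar vectors, so a downstream linear layer can read relative offsets from the encoding.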
- __init__(hparams, device='cuda:0')
- Parameters:
hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)
device (str, default='cuda:0') – What device to place the module on
Methods
__init__(hparams[, device])
forward(data, max_seq_len) – Convert input TERM data into a full structure representation
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- _agg_nodes(node_embeddings, batched_focuses, seq_lens, n_batches, max_seq_len)
Fuse together TERM match residues so that every residue has one embedding.
- Parameters:
node_embeddings (torch.Tensor) – TERM residue embeddings Shape: n_batches x n_terms x max_term_len x n_hidden
batched_focuses (torch.LongTensor) – Index of the full-structure residue that each TERM match residue corresponds to Shape: n_batches x n_terms x max_term_len
seq_lens (list of int) – Protein lengths in the batch
n_batches (int) – Number of batches
max_seq_len (int) – Maximum length of proteins in the batch
- Returns:
aggregate – Residue embeddings derived from TERM data Shape: n_batches x max_seq_len x n_hidden
- Return type:
torch.Tensor
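The aggregation described above can be sketched in plain Python for a single protein. This is an illustration, not the module's implementation: the real method operates on batched tensors, the helper name is hypothetical, and two details are assumptions of this sketch (padding positions are marked with a focus of -1, and overlapping embeddings are averaged rather than summed).

```python
def aggregate_term_nodes(node_embeddings, focuses, seq_len):
    """Average TERM residue embeddings per target residue (single protein).

    node_embeddings: n_terms x max_term_len x n_hidden nested lists
    focuses:         n_terms x max_term_len target-residue indices,
                     with -1 marking padding (an assumption of this sketch)
    """
    n_hidden = len(node_embeddings[0][0])
    sums = [[0.0] * n_hidden for _ in range(seq_len)]
    counts = [0] * seq_len
    for term_vectors, term_focuses in zip(node_embeddings, focuses):
        for vector, residue in zip(term_vectors, term_focuses):
            if residue < 0:  # skip padded TERM positions
                continue
            counts[residue] += 1
            for h, value in enumerate(vector):
                sums[residue][h] += value
    # Residues covered by no TERM keep a zero embedding.
    return [
        [s / counts[i] for s in sums[i]] if counts[i] else sums[i]
        for i in range(seq_len)
    ]


# Two TERMs over a 4-residue protein; residue 1 is covered by both.
node_embeddings = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [0.0, 0.0]],  # second position is padding
]
focuses = [[0, 1], [1, -1]]
aggregate = aggregate_term_nodes(node_embeddings, focuses, seq_len=4)
```

Here residue 1 receives the mean of its two TERM embeddings, while residues 2 and 3, which no TERM covers, stay at zero.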
- _edges(embeddings, features, X, term_lens, batched_focuses, batchify_src_key_mask)
Compute edge embeddings for TERMs
TODO: check shapes
- Parameters:
embeddings (torch.Tensor, conditionally used) – Featurized matches Shape: TODO
features (torch.Tensor) – TERM match residue features (e.g. sinusoidally embedded torsion angles, RMSD, environment value). RMSD should be at index 7. Shape: TODO
X (torch.LongTensor, conditionally used) – Raw TERM match residue identities Shape: n_batches x n_matches x sum_term_len
term_lens (list of (list of int)) – Length of TERMs per protein
batched_focuses (torch.LongTensor) – Sequence position indices for TERM residues, batched by TERM Shape: TODO
batchify_src_key_mask (torch.ByteTensor) – Mask for TERM residues, batched by TERM Shape: TODO
- Returns:
edge_features (torch.Tensor) – TERM edge features Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden
batch_rel_E_idx (torch.LongTensor) – Edge indices within a TERM Shape: n_batches x n_terms x max_term_len x max_term_len
batch_abs_E_idx (torch.LongTensor) – Edge indices relative to the target structure Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden
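The difference between the two index tensors can be illustrated with a small plain-Python sketch for one TERM. It assumes every residue in a TERM is connected to every other, and that absolute indices are obtained by looking up each local neighbor in the TERM's focuses; both are assumptions of this sketch, and the helper name is hypothetical.

```python
def term_edge_indices(term_focuses):
    """Build local and absolute edge indices for one fully connected TERM."""
    term_len = len(term_focuses)
    # rel_E_idx[i][j] = j: neighbor j in TERM-local numbering.
    rel_E_idx = [list(range(term_len)) for _ in range(term_len)]
    # abs_E_idx[i][j]: the same neighbor, numbered by its position
    # in the target structure.
    abs_E_idx = [[term_focuses[j] for j in row] for row in rel_E_idx]
    return rel_E_idx, abs_E_idx


# A TERM of three residues sitting at target positions 10, 11 and 42.
rel_idx, abs_idx = term_edge_indices([10, 11, 42])
```

Local indices address nodes inside the TERM graph, while absolute indices tell a later aggregation step which pair of target-structure residues each edge belongs to.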
- _matches(embeddings, ppoe, focuses, src_key_mask)
Extract singleton statistics from matches using MatchesCondensor
- Parameters:
embeddings (torch.Tensor) – Embedded match features Shape: TODO
ppoe (torch.Tensor) – Target structure \(\phi, \psi, \omega\), and environment value Shape: n_batch x seq_len x 4
focuses (torch.LongTensor) – Integer indices, aligned with embeddings, specifying which residue in the target structure each set of matches corresponds to Shape: TODO
- Returns:
condensed_matches – The condensed matches, such that each TERM residue has one vector associated with it
- Return type:
torch.Tensor
- _term_mpnn(batchify_terms, edge_features, batch_rel_E_idx, src_key_mask, term_lens=None, contact_idx=None)
Run TERM MPNN to refine graph embeddings
- Parameters:
batchify_terms (torch.Tensor) – TERM residue node features Shape: n_batches x n_terms x max_term_len x n_hidden
edge_features (torch.Tensor) – TERM residue interaction features Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden
batch_rel_E_idx (torch.LongTensor) – Edge indices local to each TERM graph Shape: n_batches x n_terms x max_term_len x max_term_len
src_key_mask (torch.ByteTensor) – Mask for TERM residues Shape: n_batches x sum_term_len
term_lens (list of (list of int)) – Length of TERMs per protein
contact_idx (torch.Tensor) – Contact indices per TERM residue Shape: n_batches x sum_term_len
- Returns:
node_embeddings (torch.Tensor) – Updated TERM residue embeddings Shape: n_batches x n_terms x max_term_len x n_hidden
edge_embeddings (torch.Tensor) – Updated TERM residue interaction embeddings Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden
- forward(data, max_seq_len)
Convert input TERM data into a full structure representation
- Parameters:
data (dict) – Input data dictionary. See ~/terminator/data/data.py for more information.
max_seq_len (int) – Length of the largest protein in the input data
- Returns:
agg_nodes (torch.Tensor) – Structure node embedding Shape: n_batch x max_seq_len x n_hidden
agg_edges (torch.Tensor) – Structure edge embeddings Shape: n_batch x max_seq_len x max_seq_len x n_hidden