terminator.models.layers.condense.CondenseTERM

class terminator.models.layers.condense.CondenseTERM(hparams, device='cuda:0')[source]

Bases: Module

TERM Information Condenser

Condenses TERM matches and aggregates them into a full-structure embedding.

Variables:
  • embedding (ResidueFeatures) – Feature embedding module for TERM match residues

  • edge_features (EdgeFeatures) – Feature embedding module for TERM match residue interactions

  • matches (Conv1DResNet, TERMMatchTransformerEncoder, or None) – Matches condenser (reduces the matches to a single embedding per TERM residue)

  • W_ppoe (nn.Linear) – Linear layer for target structural features (e.g. featurized torsion angles, RMSD, environment values)

  • term_mpnn (TERMGraphTransformerEncoder) – TERM MPNN (refine TERM graph embeddings)

  • cie (ContactIndexEncoding, present when hparams['contact_idx']=True) – Sinusoidal encoder for contact indices

  • W_v, W_e – Modules to linearize TERM MPNN
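
The cie entry above refers to a sinusoidal encoding of contact indices. The following is a minimal sketch of that general technique, assuming the common transformer-style sine/cosine scheme. The function name sinusoidal_encode and the frequency base of 10000 are illustrative assumptions, not the signature or constants of ContactIndexEncoding.

```python
import math

import torch


def sinusoidal_encode(contact_idx, n_hidden):
    """Hypothetical helper: map integer indices to sine/cosine features.

    contact_idx: n_batches x sum_term_len (integer tensor)
    returns:     n_batches x sum_term_len x n_hidden (n_hidden must be even)
    """
    # One frequency per sine/cosine pair, geometrically spaced (assumed base 10000)
    steps = torch.arange(0, n_hidden, 2, dtype=torch.float32, device=contact_idx.device)
    freqs = torch.exp(steps * (-math.log(10000.0) / n_hidden))
    # Broadcast indices against frequencies: ... x (n_hidden / 2)
    angles = contact_idx.unsqueeze(-1).to(torch.float32) * freqs
    # Sine features first, cosine features second
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```

An index of 0 maps to zeros in the sine half and ones in the cosine half, which gives a quick sanity check on the layout.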

__init__(hparams, device='cuda:0')[source]
Parameters:
  • hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)

  • device (str, default='cuda:0') – What device to place the module on

Methods

__init__(hparams[, device])

forward(data, max_seq_len)

Convert input TERM data into a full structure representation

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

_agg_nodes(node_embeddings, batched_focuses, seq_lens, n_batches, max_seq_len)[source]

Fuse together TERM match residues so that every residue has one embedding.

Parameters:
  • node_embeddings (torch.Tensor) – TERM residue embeddings. Shape: n_batches x n_terms x max_term_len x n_hidden

  • batched_focuses (torch.LongTensor) – Index of the full-structure residue that each TERM match residue corresponds to. Shape: n_batches x n_terms x max_term_len

  • seq_lens (list of int) – Protein lengths in the batch

  • n_batches (int) – Number of batches

  • max_seq_len (int) – Maximum length of proteins in the batch

Returns:

aggregate – Residue embeddings derived from TERM data. Shape: n_batches x max_seq_len x n_hidden

Return type:

torch.Tensor
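
The fusion step described above can be sketched as a scatter-and-average over the focus indices. This is a minimal illustration, not the module's actual implementation: it assumes that residues covered by several TERMs are combined by a plain mean, and that padded slots are marked with a focus index of -1 (a hypothetical convention). The helper name aggregate_nodes is also hypothetical, and it omits the seq_lens and n_batches arguments.

```python
import torch


def aggregate_nodes(node_embeddings, batched_focuses, max_seq_len):
    """Hypothetical sketch: average TERM residue embeddings per structure residue.

    node_embeddings: n_batches x n_terms x max_term_len x n_hidden
    batched_focuses: n_batches x n_terms x max_term_len (assumed -1 for padding)
    returns:         n_batches x max_seq_len x n_hidden
    """
    n_batches, _, _, n_hidden = node_embeddings.shape
    flat_emb = node_embeddings.reshape(n_batches, -1, n_hidden)
    flat_idx = batched_focuses.reshape(n_batches, -1)

    # Send padded slots to an extra dummy row that is dropped at the end
    valid = flat_idx >= 0
    safe_idx = torch.where(valid, flat_idx, torch.full_like(flat_idx, max_seq_len))
    safe_idx = safe_idx.unsqueeze(-1)

    sums = node_embeddings.new_zeros(n_batches, max_seq_len + 1, n_hidden)
    counts = node_embeddings.new_zeros(n_batches, max_seq_len + 1, 1)
    # Accumulate embeddings and how many TERM residues hit each position
    sums.scatter_add_(1, safe_idx.expand(-1, -1, n_hidden), flat_emb)
    counts.scatter_add_(1, safe_idx, torch.ones_like(flat_emb[..., :1]))

    # Positions with no TERM coverage keep a zero embedding
    return sums[:, :max_seq_len] / counts[:, :max_seq_len].clamp(min=1)
```

Averaging rather than summing keeps the scale of the output independent of how many TERMs cover a residue; whether the real method normalizes this way should be checked against the source.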

_edges(embeddings, features, X, term_lens, batched_focuses, batchify_src_key_mask)[source]

Compute edge embeddings for TERMs

TODO: check shapes

Parameters:
  • embeddings (torch.Tensor, conditionally used) – Featurized matches. Shape: TODO

  • features (torch.Tensor) – TERM match residue features (e.g. sinusoidally embedded torsion angles, RMSD, environment value). RMSD should be at index 7. Shape: TODO

  • X (torch.LongTensor, conditionally used) – Raw TERM match residue identities. Shape: n_batches x n_matches x sum_term_len

  • term_lens (list of (list of int)) – Length of TERMs per protein

  • batched_focuses (torch.LongTensor) – Sequence position indices for TERM residues, batched by TERM. Shape: TODO

  • batchify_src_key_mask (torch.ByteTensor) – Mask for TERM residues, batched by TERM. Shape: TODO

Returns:

  • edge_features (torch.Tensor) – TERM edge features. Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden

  • batch_rel_E_idx (torch.LongTensor) – Edge indices within a TERM. Shape: n_batches x n_terms x max_term_len x max_term_len

  • batch_abs_E_idx (torch.LongTensor) – Edge indices relative to the target structure. Shape: n_batches x n_terms x max_term_len x max_term_len
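
To illustrate the difference between the two index tensors, here is a minimal sketch that assumes every TERM is treated as a fully connected graph, so each residue has an edge to every position in its TERM. The helper name term_edge_indices is hypothetical, and the real method may select neighbors differently.

```python
import torch


def term_edge_indices(batched_focuses):
    """Hypothetical sketch: relative and absolute edge indices per TERM.

    batched_focuses: n_batches x n_terms x max_term_len
    returns two tensors, each n_batches x n_terms x max_term_len x max_term_len
    """
    n_batches, n_terms, max_term_len = batched_focuses.shape

    # Relative index: neighbor j is simply position j inside the TERM
    rel = torch.arange(max_term_len, device=batched_focuses.device)
    rel_E_idx = rel.view(1, 1, 1, max_term_len).expand(
        n_batches, n_terms, max_term_len, max_term_len
    )

    # Absolute index: neighbor j is looked up in the target-structure numbering
    abs_E_idx = batched_focuses.unsqueeze(-2).expand(-1, -1, max_term_len, -1)
    return rel_E_idx, abs_E_idx
```

For a TERM whose residues sit at structure positions 4, 7 and 9, every row of the relative tensor is 0, 1, 2, while every row of the absolute tensor is 4, 7, 9.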

_matches(embeddings, ppoe, focuses, src_key_mask)[source]

Extract singleton statistics from matches using MatchesCondensor

Parameters:
  • embeddings (torch.Tensor) – Embedded match features. Shape: TODO

  • ppoe (torch.Tensor) – Target structure \(\phi, \psi, \omega\), and environment value. Shape: n_batch x seq_len x 4

  • focuses (torch.LongTensor) – Integer indices, aligned with embeddings, that specify which residue in the target structure each set of matches corresponds to. Shape: TODO

Returns:

condensed_matches – The condensed matches, such that each TERM residue has one vector associated with it.

Return type:

torch.Tensor

_term_mpnn(batchify_terms, edge_features, batch_rel_E_idx, src_key_mask, term_lens=None, contact_idx=None)[source]

Run TERM MPNN to refine graph embeddings

Parameters:
  • batchify_terms (torch.Tensor) – TERM residue node features. Shape: n_batches x n_terms x max_term_len x n_hidden

  • edge_features (torch.Tensor) – TERM residue interaction features. Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden

  • batch_rel_E_idx (torch.LongTensor) – Edge indices local to each TERM graph. Shape: n_batches x n_terms x max_term_len x max_term_len

  • src_key_mask (torch.ByteTensor) – Mask for TERM residues. Shape: n_batches x sum_term_len

  • term_lens (list of (list of int)) – Length of TERMs per protein

  • contact_idx (torch.Tensor) – Contact indices per TERM residue. Shape: n_batches x sum_term_len

Returns:

  • node_embeddings (torch.Tensor) – Updated TERM residue embeddings. Shape: n_batches x n_terms x max_term_len x n_hidden

  • edge_embeddings (torch.Tensor) – Updated TERM residue interaction embeddings. Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden
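
The parameters above mix a flat layout (sum_term_len) with a per-TERM layout (n_terms x max_term_len). The sketch below shows one way to convert a single protein's flat residue features into the padded per-TERM layout using term_lens. The helper name batchify_one and the choice of zero padding are assumptions for illustration, not the module's own batchify routine.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def batchify_one(flat_features, term_lens):
    """Hypothetical sketch: regroup flat TERM residue features by TERM.

    flat_features: sum_term_len x n_hidden (one protein)
    term_lens:     list of int, TERM lengths summing to sum_term_len
    returns:       n_terms x max_term_len x n_hidden, zero-padded
    """
    # Cut the flat sequence into one chunk per TERM
    chunks = list(torch.split(flat_features, term_lens, dim=0))
    # Pad every chunk to the longest TERM and stack along a new TERM axis
    return pad_sequence(chunks, batch_first=True)
```

With term_lens of [2, 3], a flat tensor of five residues becomes two TERMs of padded length 3, with the last slot of the first TERM filled by zeros.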

forward(data, max_seq_len)[source]

Convert input TERM data into a full structure representation

Parameters:
  • data (dict) – Input data dictionary. See ~/terminator/data/data.py for more information.

  • max_seq_len (int) – Length of the largest protein in the input data

Returns:

  • agg_nodes (torch.Tensor) – Structure node embeddings. Shape: n_batch x max_seq_len x n_hidden

  • agg_edges (torch.Tensor) – Structure edge embeddings. Shape: n_batch x max_seq_len x max_seq_len x n_hidden