terminator.models.layers.energies.s2s.PairEnergies

class terminator.models.layers.energies.s2s.PairEnergies(hparams)

Bases: Module

GNN Potts Model Encoder

Variables:
  • dev (str) – Device on which the model is held

  • hparams (dict) – Dictionary of parameter settings (see terminator/utils/model/default_hparams.py)

  • features (MultiChainProteinFeatures) – Module that featurizes a protein backbone (including multimeric proteins)

  • W_v (nn.Linear) – Embedding layer for incoming TERM node embeddings

  • W_e (nn.Linear) – Embedding layer for incoming TERM edge embeddings

  • edge_encoder (nn.ModuleList of EdgeTransformerLayer or EdgeMPNNLayer) – Edge graph update layers

  • node_encoder (nn.ModuleList of NodeTransformerLayer or NodeMPNNLayer) – Node graph update layers

  • W_out (nn.Linear) – Output layer that projects edge embeddings to proper output dimensionality

  • W_proj (nn.Linear, optional) – Output layer that projects node embeddings to proper output dimensionality. Enabled when hparams["node_self_sub"]=True.

__init__(hparams)

Graph labeling network

Methods

__init__(hparams)

Graph labeling network

forward(V_embed, E_embed, X, x_mask, chain_idx)

Create a kNN energy table (etab) from backbone and TERM features, then project it to the proper output dimensionality.

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

forward(V_embed, E_embed, X, x_mask, chain_idx)

Create a kNN energy table (etab) from backbone and TERM features, then project it to the proper output dimensionality.
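The docstring does not say how the kNN graph behind etab is built. As a minimal sketch only, assuming the four backbone atoms in X are ordered (N, CA, C, O) and that neighbors are ranked by CA-CA distance, a k-nearest-neighbor edge index with the documented E_idx shape could be computed as below. The helper name knn_edge_index and its k argument are hypothetical and are not part of the terminator API.

```python
import torch


def knn_edge_index(X: torch.Tensor, x_mask: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: build a kNN edge index of shape n_batch x n_res x k.

    Assumption: X is n_batch x n_res x 4 x 3 with atoms ordered (N, CA, C, O),
    and neighbors are ranked by CA-CA Euclidean distance.
    """
    ca = X[:, :, 1, :]  # assumed CA slot -> n_batch x n_res x 3
    # Pairwise CA-CA distances: n_batch x n_res x n_res
    dist = (ca[:, :, None, :] - ca[:, None, :, :]).norm(dim=-1)
    # Pairs involving a masked (padding) residue get infinite distance,
    # so they sort after every real neighbor
    pair_mask = x_mask[:, :, None].bool() & x_mask[:, None, :].bool()
    dist = dist.masked_fill(~pair_mask, float("inf"))
    k = min(k, X.shape[1])  # cannot request more neighbors than residues
    # k smallest distances per residue, sorted ascending; for unmasked
    # residues the residue itself comes first (distance 0)
    _, E_idx = torch.topk(dist, k, dim=-1, largest=False)
    return E_idx


torch.manual_seed(0)
X = torch.randn(2, 10, 4, 3)
x_mask = torch.ones(2, 10, dtype=torch.uint8)
x_mask[1, 8:] = 0  # last two residues of structure 1 are padding
E_idx = knn_edge_index(X, x_mask, k=4)
print(E_idx.shape)  # torch.Size([2, 10, 4])
```

Pushing masked pairs to infinity, rather than dropping them, keeps every row at exactly k entries so E_idx stays a dense n_batch x n_res x k tensor.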

Parameters:
  • V_embed (torch.Tensor or None) – TERM node embeddings. None is only accepted if hparams['energies_input_dim']=0. Shape: n_batch x n_res x n_hidden

  • E_embed (torch.Tensor or None) – TERM edge embeddings. None is only accepted if hparams['energies_input_dim']=0. Shape: n_batch x n_res x n_res x n_hidden

  • X (torch.Tensor) – Backbone coordinates. Shape: n_batch x n_res x 4 x 3

  • x_mask (torch.ByteTensor) – Mask for X. Shape: n_batch x n_res

  • chain_idx (torch.LongTensor) – Per-residue chain indices: each chain is assigned a unique integer, and every residue in that chain carries that integer. Shape: n_batch x n_res
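To make the parameter shapes concrete, the following sketch builds chain_idx and x_mask for a padded batch of multichain structures, plus random placeholder tensors with the documented shapes for V_embed, E_embed and X. It assumes chains are numbered from 0 in order of appearance and that padding positions get chain index 0 (they are excluded by x_mask anyway); neither convention is stated above. The helper build_chain_idx and the width n_hidden = 32 are hypothetical.

```python
import torch


def build_chain_idx(chain_lengths_per_structure):
    """Hypothetical helper: return (chain_idx, x_mask), each n_batch x n_res.

    chain_lengths_per_structure is a list of per-structure chain lengths,
    e.g. [[3, 2], [4]]. Chains are numbered 0, 1, ... (an assumption);
    padding positions keep chain index 0 and mask value 0.
    """
    n_batch = len(chain_lengths_per_structure)
    n_res = max(sum(lengths) for lengths in chain_lengths_per_structure)
    chain_idx = torch.zeros(n_batch, n_res, dtype=torch.long)
    x_mask = torch.zeros(n_batch, n_res, dtype=torch.uint8)
    for b, lengths in enumerate(chain_lengths_per_structure):
        start = 0
        for c, length in enumerate(lengths):
            chain_idx[b, start:start + length] = c  # every residue of chain c gets c
            start += length
        x_mask[b, :start] = 1  # real residues only
    return chain_idx, x_mask


chain_idx, x_mask = build_chain_idx([[3, 2], [4]])
# chain_idx -> [[0, 0, 0, 1, 1], [0, 0, 0, 0, 0]]
# x_mask    -> [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]

n_batch, n_res = chain_idx.shape
n_hidden = 32  # arbitrary width for illustration
V_embed = torch.randn(n_batch, n_res, n_hidden)         # n_batch x n_res x n_hidden
E_embed = torch.randn(n_batch, n_res, n_res, n_hidden)  # n_batch x n_res x n_res x n_hidden
X = torch.randn(n_batch, n_res, 4, 3)                   # n_batch x n_res x 4 x 3
```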

Returns:

  • etab (torch.Tensor) – Energy table in kNN-dense form. Shape: n_batch x n_res x k x n_hidden

  • E_idx (torch.LongTensor) – Edge index for etab. Shape: n_batch x n_res x k
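In kNN-dense form, etab stores only k entries per residue, and E_idx records which residues those entries refer to. Assuming E_idx[b, i, j] is the residue index of the j-th neighbor of residue i (the usual kNN-dense convention, not spelled out above), a minimal sketch that expands etab into a full n_res x n_res table, leaving non-neighbor pairs at zero, looks like this. The helper knn_to_full is hypothetical, and the last dimension d is treated generically without interpreting its contents.

```python
import torch


def knn_to_full(etab: torch.Tensor, E_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: scatter a kNN-dense table (n_batch x n_res x k x d)
    into a full n_batch x n_res x n_res x d table; non-neighbor pairs stay zero.

    Assumes the k indices in each row of E_idx are distinct.
    """
    n_batch, n_res, k, d = etab.shape
    full = etab.new_zeros(n_batch, n_res, n_res, d)
    # Broadcast neighbor indices over the feature dimension: n_batch x n_res x k x d
    index = E_idx.unsqueeze(-1).expand(-1, -1, -1, d)
    # full[b, i, E_idx[b, i, j], :] = etab[b, i, j, :]
    full.scatter_(2, index, etab)
    return full


torch.manual_seed(0)
n_batch, n_res, k, d = 2, 6, 3, 5
# Random but valid kNN index: k distinct residue indices per row
E_idx = torch.rand(n_batch, n_res, n_res).argsort(dim=-1)[..., :k]
etab = torch.randn(n_batch, n_res, k, d)
full = knn_to_full(etab, E_idx)
print(full.shape)  # torch.Size([2, 6, 6, 5])
```

The inverse operation is torch.gather along dimension 2 with the same expanded index, which recovers etab from full. The kNN-dense layout needs O(n_res * k) memory instead of O(n_res^2), which is why expanding to the full table is usually reserved for inspection or small inputs.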