terminator.models.layers.energies.s2s.AblatedPairEnergies

class terminator.models.layers.energies.s2s.AblatedPairEnergies(hparams)

Bases: Module

Ablated GNN Potts Model Encoder

Variables:
  • dev (str) – Device on which the model is held

  • hparams (dict) – Dictionary of parameter settings (see terminator/utils/model/default_hparams.py)

  • features (MultiChainProteinFeatures) – Module that featurizes a protein backbone (including multimeric proteins)

  • W (nn.Linear) – Output layer that projects edge embeddings to the output dimensionality

__init__(hparams)

Graph labeling network

Methods

__init__(hparams)

Graph labeling network

forward(V_embed, E_embed, X, x_mask, chain_idx)

Create a kNN energy table (etab) from the TERM features, then project it to the output dimensionality.

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

forward(V_embed, E_embed, X, x_mask, chain_idx)

Create a kNN energy table (etab) from the TERM features, then project it to the output dimensionality.

Parameters:
  • V_embed (torch.Tensor) – TERM node embeddings. Shape: n_batch x n_res x n_hidden

  • E_embed (torch.Tensor) – TERM edge embeddings. Shape: n_batch x n_res x n_res x n_hidden

  • X (torch.Tensor) – Backbone coordinates. Shape: n_batch x n_res x 4 x 3

  • x_mask (torch.ByteTensor) – Mask for X. Shape: n_batch x n_res

  • chain_idx (torch.LongTensor) – Per-residue chain indices: each chain is assigned a unique integer, and every residue in that chain carries that integer. Shape: n_batch x n_res
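To illustrate the input conventions, here is a sketch that builds chain_idx and x_mask for a padded batch of multichain proteins, plus random tensors with the shapes listed above. The helpers make_chain_idx and make_batch are hypothetical. The following choices are also assumptions, not confirmed by this page: chains are numbered from 0, padding positions get chain index 0 and are distinguished only by x_mask, and the four backbone atoms are ordered N, CA, C, O.

```python
import torch


def make_chain_idx(chain_lens, n_res):
    """Hypothetical helper: per-residue chain indices and mask for one protein.

    chain_lens lists the length of each chain, e.g. [3, 2] for a dimer.
    Positions past the last chain are padding (assumed: chain index 0, mask 0).
    """
    chain_idx = torch.zeros(n_res, dtype=torch.long)
    x_mask = torch.zeros(n_res, dtype=torch.uint8)
    start = 0
    for chain_id, length in enumerate(chain_lens):
        chain_idx[start:start + length] = chain_id  # every residue in the chain gets its id
        x_mask[start:start + length] = 1            # mark real (non-padding) residues
        start += length
    return chain_idx, x_mask


def make_batch(chain_lens_list, n_hidden=16):
    """Hypothetical helper: random inputs with the documented shapes."""
    n_batch = len(chain_lens_list)
    n_res = max(sum(lens) for lens in chain_lens_list)  # pad to the longest protein
    pairs = [make_chain_idx(lens, n_res) for lens in chain_lens_list]
    return {
        "V_embed": torch.randn(n_batch, n_res, n_hidden),
        "E_embed": torch.randn(n_batch, n_res, n_res, n_hidden),
        "X": torch.randn(n_batch, n_res, 4, 3),  # assumed atom order: N, CA, C, O
        "x_mask": torch.stack([m for _, m in pairs]),
        "chain_idx": torch.stack([c for c, _ in pairs]),
    }
```

For example, make_batch([[3, 2], [4]]) pads both proteins to n_res = 5. The dimer gets chain_idx [0, 0, 0, 1, 1], and the monomer gets x_mask [1, 1, 1, 1, 0].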

Returns:

  • etab (torch.Tensor) – Energy table in kNN dense form. Shape: n_batch x n_res x k x n_hidden

  • E_idx (torch.LongTensor) – Edge indices for etab. Shape: n_batch x n_res x k
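The return values use a kNN dense layout. Each residue has k neighbor slots, and E_idx records which residue fills each slot. The sketch below reduces a dense pairwise tensor such as E_embed to that layout, then projects it with a linear layer like W. It is not this module's actual implementation. The name knn_etab_sketch is hypothetical, and the sketch makes two assumptions: neighbors are chosen by CA-CA distance, with X[:, :, 1] taken as CA, and each residue counts as its own nearest neighbor.

```python
import torch
import torch.nn as nn


def knn_etab_sketch(E_embed, X, x_mask, k, W):
    """Hypothetical sketch: dense-to-kNN gather followed by a projection.

    E_embed: n_batch x n_res x n_res x n_hidden
    X:       n_batch x n_res x 4 x 3 (assumes X[:, :, 1] holds CA coordinates)
    x_mask:  n_batch x n_res
    """
    ca = X[:, :, 1, :]                                       # n_batch x n_res x 3
    dist = (ca.unsqueeze(2) - ca.unsqueeze(1)).norm(dim=-1)  # n_batch x n_res x n_res
    valid = x_mask.bool()
    pair_ok = valid.unsqueeze(2) & valid.unsqueeze(1)        # both residues are real
    dist = dist.masked_fill(~pair_ok, float("inf"))          # never pick padding as a neighbor
    # k smallest distances per residue, in ascending order;
    # for real residues the residue itself (distance 0) comes first
    _, E_idx = torch.topk(dist, k, dim=-1, largest=False)    # n_batch x n_res x k
    gather_idx = E_idx.unsqueeze(-1).expand(-1, -1, -1, E_embed.size(-1))
    E_knn = torch.gather(E_embed, 2, gather_idx)             # n_batch x n_res x k x n_hidden
    etab = W(E_knn)                                          # project the last dimension
    return etab, E_idx
```

In this sketch the last dimension of etab is W's output size, for example nn.Linear(n_hidden, n_out). Passing nn.Identity() instead leaves the gathered edge embeddings unchanged, which makes the gather easy to check: etab[b, i, j] equals E_embed[b, i, E_idx[b, i, j]].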