terminator.models.layers.graph_features.MultiChainProteinFeatures

class terminator.models.layers.graph_features.MultiChainProteinFeatures(edge_features, node_features, num_positional_embeddings=16, num_rbf=16, top_k=30, features_type='full', augment_eps=0.0, dropout=0.1)

Bases: ProteinFeatures

Protein backbone featurization which accounts for differences between inter-chain and intra-chain interactions.

Variables:
  • embeddings (IndexDiffEncoding) – Module to generate differential positional embeddings for edges

  • dropout (nn.Dropout) – Dropout module

  • node_embeddings, edge_embeddings – Embedding layers for nodes and edges

  • norm_nodes, norm_edges – Normalization layers for node and edge features

__init__(edge_features, node_features, num_positional_embeddings=16, num_rbf=16, top_k=30, features_type='full', augment_eps=0.0, dropout=0.1)

Extract protein features

Methods

__init__(edge_features, node_features[, ...])

Extract protein features

forward(X, chain_idx, mask)

Featurize coordinates as an attributed graph

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

forward(X, chain_idx, mask)

Featurize coordinates as an attributed graph

Parameters:
  • X (torch.Tensor) – Backbone coordinates. Shape: n_batches x seq_len x 4 x 3

  • chain_idx (torch.LongTensor) – Chain index for each residue: every chain is assigned a unique integer, and every residue in that chain carries that integer. Shape: n_batches x seq_len

  • mask (torch.ByteTensor) – Mask for residues. Shape: n_batches x seq_len
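
As a rough illustration of the input layout described above, the following sketch builds the three tensors for a hypothetical two-chain complex. The chain lengths and random coordinates are made up for the example, and the `same_chain` indicator is only one way to picture the intra-chain versus inter-chain distinction; this page does not state how the class uses `chain_idx` internally.

```python
import torch

# Hypothetical complex: chain 0 has 5 residues, chain 1 has 3 residues.
chain_lengths = [5, 3]
seq_len = sum(chain_lengths)

# One unique integer per chain, repeated for every residue in that chain.
chain_idx = torch.cat([
    torch.full((length,), i, dtype=torch.long)
    for i, length in enumerate(chain_lengths)
]).unsqueeze(0)  # n_batches x seq_len

# Random stand-in coordinates: 4 backbone atoms per residue, 3 coordinates each.
X = torch.randn(1, seq_len, 4, 3)  # n_batches x seq_len x 4 x 3

# Every residue is real (no padding) in this toy batch.
mask = torch.ones(1, seq_len, dtype=torch.uint8)  # n_batches x seq_len

# Illustrative only: True where residues i and j belong to the same chain.
same_chain = chain_idx.unsqueeze(-1) == chain_idx.unsqueeze(-2)  # n_batches x seq_len x seq_len
```

With these tensors, `same_chain[0, 0, 4]` is True (both residues are in chain 0) and `same_chain[0, 4, 5]` is False (the pair crosses the chain boundary).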

Returns:

  • V (torch.Tensor) – Node embeddings. Shape: n_batches x seq_len x n_hidden

  • E (torch.Tensor) – Edge embeddings in kNN dense form. Shape: n_batches x seq_len x k x n_hidden

  • E_idx (torch.LongTensor) – Edge indices (the index of each of the k neighbors for every residue). Shape: n_batches x seq_len x k
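
To make "kNN dense form" concrete, here is a minimal sketch of how a neighbor index tensor and a gathered edge tensor can be produced with plain PyTorch. It is not the library's implementation. It assumes C-alpha sits at atom index 1 of the backbone (an N, CA, C, O ordering), and `pair_feats`, `n_hidden = 8`, and `k = 3` are hypothetical stand-ins chosen for the example.

```python
import torch

n_batches, seq_len, k, n_hidden = 2, 10, 3, 8
X = torch.randn(n_batches, seq_len, 4, 3)  # stand-in backbone coordinates

# Assumption: atom index 1 is C-alpha.
ca = X[:, :, 1, :]                    # n_batches x seq_len x 3
dist = torch.cdist(ca, ca)            # n_batches x seq_len x seq_len

# Indices of the k nearest residues for each residue
# (the residue itself is among the candidates, at distance ~0).
_, E_idx = torch.topk(dist, k, dim=-1, largest=False)  # n_batches x seq_len x k

# Hypothetical dense pairwise features, one vector per residue pair.
pair_feats = torch.randn(n_batches, seq_len, seq_len, n_hidden)

# Keep only the k selected neighbors per residue: this is the kNN dense form.
idx = E_idx.unsqueeze(-1).expand(-1, -1, -1, n_hidden)
E = torch.gather(pair_feats, 2, idx)  # n_batches x seq_len x k x n_hidden
```

In this sketch the index tensor has no feature axis: each entry `E_idx[b, i, j]` is a residue position, and `E[b, i, j]` is the feature vector of the edge from residue `i` to that neighbor.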