terminator.models.layers.graph_features.IndexDiffEncoding

class terminator.models.layers.graph_features.IndexDiffEncoding(num_embeddings)

Bases: Module

Module to generate directional differential positional encodings for the edges of a multichain protein graph.

Similar to ProteinFeatures, but zeros out the features of interchain edges (edges connecting residues on different chains).

__init__(num_embeddings)

Initializes internal Module state, shared by both nn.Module and ScriptModule. num_embeddings is the size of the last dimension of the encodings returned by forward().

Methods

__init__(num_embeddings)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(E_idx, chain_idx)

Generate directional differential positional encodings for edges

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

forward(E_idx, chain_idx)

Generate directional differential positional encodings for edges

Parameters:
  • E_idx (torch.LongTensor) – Protein kNN edge indices. Shape: n_batches x seq_len x k

  • chain_idx (torch.LongTensor) – Per-residue chain labels: each chain is assigned a unique integer, and every residue in that chain is labeled with that integer. Shape: n_batches x seq_len
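The chain_idx convention can be produced from per-chain residue counts. The sketch below is a hypothetical helper (chain_idx_from_lengths is not part of TERMinator) and assumes each chain occupies one contiguous block of the sequence, with chains numbered 0, 1, 2, ... in order.

```python
from typing import List


def chain_idx_from_lengths(chain_lengths: List[int]) -> List[int]:
    """Hypothetical helper: label every residue with the integer of its chain.

    Assumes chains are contiguous and appear in the order given.
    """
    # Chain c contributes `length` copies of the integer c.
    return [c for c, length in enumerate(chain_lengths) for _ in range(length)]


# Two chains: 3 residues on chain 0 followed by 2 residues on chain 1.
chain_idx = chain_idx_from_lengths([3, 2])
print(chain_idx)  # [0, 0, 0, 1, 1]
```

For a batch of equal-length sequences, the rows can be stacked with torch.tensor(rows, dtype=torch.long) to obtain an n_batches x seq_len LongTensor; how padded positions should be labeled is not specified here.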

Returns:

E – Directional differential positional encodings for edges. Shape: n_batches x seq_len x k x num_embeddings

Return type:

torch.Tensor