terminator.models.layers.graph_features.MultiChainProteinFeatures
- class terminator.models.layers.graph_features.MultiChainProteinFeatures(edge_features, node_features, num_positional_embeddings=16, num_rbf=16, top_k=30, features_type='full', augment_eps=0.0, dropout=0.1)
Bases: ProteinFeatures
Protein backbone featurization which accounts for differences between inter-chain and intra-chain interactions.
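The inter-chain/intra-chain distinction can be read directly off chain_idx (described under forward) by comparing residue pairs. A minimal sketch, assuming PyTorch; same_chain_mask is a hypothetical helper, not part of terminator, and this is not a claim about how MultiChainProteinFeatures uses the distinction internally:

```python
import torch


def same_chain_mask(chain_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return a bool tensor of shape
    (n_batch, seq_len, seq_len) that is True where residues i and j
    belong to the same chain (intra-chain pairs)."""
    # Broadcast (B, L, 1) against (B, 1, L) to compare every residue pair.
    return chain_idx.unsqueeze(-1) == chain_idx.unsqueeze(-2)


# Toy complex: chain 0 has 3 residues, chain 1 has 2 residues.
chain_idx = torch.tensor([[0, 0, 0, 1, 1]])
intra = same_chain_mask(chain_idx)  # True for intra-chain pairs
inter = ~intra                      # True for inter-chain pairs
```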
- Variables:
embeddings (IndexDiffEncoding) – Module to generate differential positional embeddings for edges
dropout (nn.Dropout) – Dropout module
node_embeddings, edge_embeddings – Embedding layers for nodes and edges
norm_nodes, norm_edges – Normalization layers for node and edge features
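IndexDiffEncoding is described only as producing differential positional embeddings for edges. A common way to build such features is a Transformer-style sinusoidal encoding of the sequence offset j - i for each edge (i, j). The sketch below shows that technique under that assumption; sinusoidal_offset_encoding is hypothetical, the correspondence between its num_embeddings argument and the num_positional_embeddings constructor argument is assumed rather than confirmed, and it does not model any special handling of inter-chain offsets.

```python
import math

import torch


def sinusoidal_offset_encoding(E_idx: torch.Tensor, num_embeddings: int = 16) -> torch.Tensor:
    """Hypothetical helper: encode the signed sequence offset j - i of each kNN edge.

    E_idx: (n_batch, seq_len, k) neighbor indices.
    Returns: (n_batch, seq_len, k, num_embeddings) features.
    """
    seq_len = E_idx.shape[1]
    # Index i of each source residue, shaped (1, seq_len, 1) for broadcasting.
    ii = torch.arange(seq_len, dtype=torch.float32).view(1, -1, 1)
    d = (E_idx.float() - ii).unsqueeze(-1)  # (B, L, k, 1) signed offsets
    # Geometric ladder of frequencies, as in Transformer positional encodings.
    freq = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32)
        * -(math.log(10000.0) / num_embeddings)
    )
    angles = d * freq  # (B, L, k, num_embeddings // 2)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)


E_idx = torch.tensor([[[0, 1], [1, 0], [2, 1]]])  # (1, 3, 2) toy neighbor indices
pe = sinusoidal_offset_encoding(E_idx, num_embeddings=16)
print(pe.shape)  # torch.Size([1, 3, 2, 16])
```

A zero offset (a self-edge) encodes as all cosines equal to 1 and all sines equal to 0, so the features depend only on the relative offset, not on absolute positions.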
- __init__(edge_features, node_features, num_positional_embeddings=16, num_rbf=16, top_k=30, features_type='full', augment_eps=0.0, dropout=0.1)
Extract protein features
Methods
__init__(edge_features, node_features[, ...]) – Extract protein features
forward(X, chain_idx, mask) – Featurize coordinates as an attributed graph
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better backward-compatibility (BC) support for load_state_dict().
- forward(X, chain_idx, mask)
Featurize coordinates as an attributed graph
- Parameters:
X (torch.Tensor) – Backbone coordinates. Shape: n_batch x seq_len x 4 x 3
chain_idx (torch.LongTensor) – Per-residue chain indices: each chain is assigned a unique integer, and every residue in that chain is labeled with that integer. Shape: n_batch x seq_len
mask (torch.ByteTensor) – Mask for residues. Shape: n_batch x seq_len
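A minimal sketch of inputs with the documented shapes and dtypes, assuming PyTorch. The two-chain toy complex, the padding, the 1 = real / 0 = padding mask convention, and the N, CA, C, O atom order along the size-4 axis are illustrative assumptions; the chain index at the padded position is assumed not to matter because mask excludes it.

```python
import torch

# Toy batch: one complex, chain 0 with 3 residues, chain 1 with 2 residues,
# padded to seq_len = 6.
n_batch, seq_len = 1, 6

# Backbone coordinates: 4 atoms per residue (assumed N, CA, C, O), each (x, y, z).
X = torch.randn(n_batch, seq_len, 4, 3)

# One integer per chain, repeated for every residue of that chain.
# The value at the padded position (index 5) is assumed to be ignored via mask.
chain_idx = torch.tensor([[0, 0, 0, 1, 1, 1]], dtype=torch.long)

# Assumed convention: 1 for real residues, 0 for padding (torch.ByteTensor is uint8).
mask = torch.tensor([[1, 1, 1, 1, 1, 0]], dtype=torch.uint8)
```

These tensors match the shapes listed above for the X, chain_idx, and mask arguments of forward.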
- Returns:
V (torch.Tensor) – Node embeddings. Shape: n_batch x seq_len x n_hidden
E (torch.Tensor) – Edge embeddings in kNN dense form. Shape: n_batch x seq_len x k x n_hidden
E_idx (torch.LongTensor) – Edge indices, i.e. the indices of the k nearest neighbors of each residue. Shape: n_batch x seq_len x k
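The kNN dense form can be illustrated by selecting each residue's k nearest neighbors and gathering per-neighbor features: an index tensor holds one entry per neighbor (n_batch x seq_len x k), while the gathered edge tensor holds one feature vector per neighbor (n_batch x seq_len x k x C). A sketch assuming PyTorch, with neighbors ranked by CA-CA distance and CA at atom index 1; knn_edges and gather_edges are hypothetical helpers, and this is a generic kNN graph construction rather than the terminator implementation.

```python
import torch


def knn_edges(X: torch.Tensor, mask: torch.Tensor, top_k: int) -> torch.Tensor:
    """Hypothetical helper: indices (B, L, k) of the k nearest residues by CA distance."""
    ca = X[:, :, 1, :]  # assumption: atom index 1 is CA
    m = mask.float()
    mask_2d = m.unsqueeze(1) * m.unsqueeze(2)  # (B, L, L), 1 only if both residues are real
    D = torch.cdist(ca, ca)  # (B, L, L) pairwise Euclidean distances
    # Push masked pairs past every real distance so topk ranks them last.
    D = D + (1.0 - mask_2d) * (D.max() + 1.0)
    k = min(top_k, X.shape[1])
    _, E_idx = torch.topk(D, k, dim=-1, largest=False)
    return E_idx


def gather_edges(pair_feats: torch.Tensor, E_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: dense pairwise (B, L, L, C) -> kNN dense form (B, L, k, C)."""
    idx = E_idx.unsqueeze(-1).expand(-1, -1, -1, pair_feats.shape[-1])
    return torch.gather(pair_feats, 2, idx)


torch.manual_seed(0)
X = torch.randn(1, 6, 4, 3)
mask = torch.tensor([[1, 1, 1, 1, 1, 0]], dtype=torch.uint8)
E_idx = knn_edges(X, mask, top_k=3)
pair_feats = torch.randn(1, 6, 6, 8)  # hypothetical dense pairwise features
E = gather_edges(pair_feats, E_idx)
print(E_idx.shape)  # torch.Size([1, 6, 3])
print(E.shape)  # torch.Size([1, 6, 3, 8])
```

Because masked pairs are shifted past every real distance, a padded position never appears among a real residue's neighbors as long as there are at least k real residues.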