terminator.models.layers.term.graph.s2s.TERMGraphTransformerEncoder
- class terminator.models.layers.term.graph.s2s.TERMGraphTransformerEncoder(hparams)
Bases: Module
TERM Graph Transformer Encoder
Alternating node and edge update layers that update the representation of TERM graphs
- Variables:
W_v (nn.Linear) – Embedding layer for nodes
W_e (nn.Linear) – Embedding layer for edges
node_encoder (nn.ModuleList of TERMNodeTransformerLayer or TERMNodeMPNNLayer) – Update layers for nodes
edge_encoder (nn.ModuleList of TERMEdgeTransformerLayer or TERMEdgeMPNNLayer) – Update layers for edges
W_out (nn.Linear) – Output layer
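The alternation between `edge_encoder` and `node_encoder` can be pictured with the minimal sketch below. It is an illustration only, not the library's implementation: the helper name `alternate_updates`, the two-argument layer call signature, and the choice to update edges before nodes in each round are all assumptions, and the real layers also take edge indices and masks.

```python
def alternate_updates(h_V, h_E, node_layers, edge_layers):
    """Apply paired edge and node update layers in alternation.

    Hypothetical helper for illustration. Assumptions: each layer is a
    callable taking (h_V, h_E), and the edge update runs before the node
    update within a round. The actual TERMinator layers have richer
    signatures (edge indices, masks) that are omitted here.
    """
    for edge_layer, node_layer in zip(edge_layers, node_layers):
        # Edges are refreshed first, using the current node state.
        h_E = edge_layer(h_V, h_E)
        # Nodes are then refreshed using the freshly updated edges.
        h_V = node_layer(h_V, h_E)
    return h_V, h_E


# Toy stand-ins for real layers: plain functions on numbers.
toy_edge_layers = [lambda v, e: e * 2] * 2
toy_node_layers = [lambda v, e: v + e] * 2
h_V, h_E = alternate_updates(1, 1, toy_node_layers, toy_edge_layers)
```

The point of the sketch is only that each round consumes the output of the previous one, so node and edge representations are refined jointly rather than in two separate passes.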
- __init__(hparams)
- Parameters:
hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)
Methods
__init__(hparams) – Parameters: hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)
forward(V, E, E_idx, mask[, contact_idx]) – Refine TERM graph representations
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- forward(V, E, E_idx, mask, contact_idx=None)
Refine TERM graph representations
- Parameters:
V (torch.Tensor) – Node embeddings. Shape: n_batches x n_terms x max_term_len x n_hidden
E (torch.Tensor) – Edge embeddings in kNN dense form. Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden
E_idx (torch.LongTensor) – Edge indices. Shape: n_batches x n_terms x max_term_len x max_term_len
mask (torch.ByteTensor) – Mask for TERM residues. Shape: n_batches x n_terms x max_term_len
contact_idx (torch.Tensor) – Embedded contact indices. Shape: n_batches x n_terms x max_term_len x n_hidden
- Returns:
h_V (torch.Tensor) – TERM node embeddings
h_E (torch.Tensor) – TERM edge embeddings
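Because the inputs to forward share several dimensions, a small shape check before calling the encoder may help catch mismatches early. The sketch below is not part of TERMinator: `check_term_graph_inputs` is a hypothetical helper that only compares `.shape` attributes against the shapes documented above, so it should work with `torch.Tensor` objects as well as the lightweight stand-ins used in the demo.

```python
from types import SimpleNamespace


def check_term_graph_inputs(V, E, E_idx, mask, contact_idx=None):
    """Compare input shapes against the shapes documented for forward().

    Hypothetical helper, not part of the library. Accepts any objects
    exposing a `.shape` attribute and returns the dimensions inferred
    from V: (n_batches, n_terms, max_term_len, n_hidden).
    """
    n_batches, n_terms, max_term_len, n_hidden = tuple(V.shape)
    expected = {
        "E": (n_batches, n_terms, max_term_len, max_term_len, n_hidden),
        "E_idx": (n_batches, n_terms, max_term_len, max_term_len),
        "mask": (n_batches, n_terms, max_term_len),
    }
    actual = {"E": E, "E_idx": E_idx, "mask": mask}
    if contact_idx is not None:
        expected["contact_idx"] = (n_batches, n_terms, max_term_len, n_hidden)
        actual["contact_idx"] = contact_idx
    for name, shape in expected.items():
        got = tuple(actual[name].shape)
        if got != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {got}")
    return n_batches, n_terms, max_term_len, n_hidden


# Demo with stand-ins that only carry a shape; the sizes are arbitrary.
B, T, L, H = 2, 3, 4, 8
dims = check_term_graph_inputs(
    V=SimpleNamespace(shape=(B, T, L, H)),
    E=SimpleNamespace(shape=(B, T, L, L, H)),
    E_idx=SimpleNamespace(shape=(B, T, L, L)),
    mask=SimpleNamespace(shape=(B, T, L)),
)
```

With real data, the same call would take the tensors directly, for example `check_term_graph_inputs(V, E, E_idx, mask, contact_idx)`.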