terminator.models.layers.term.graph.s2s.TERMGraphTransformerEncoder

class terminator.models.layers.term.graph.s2s.TERMGraphTransformerEncoder(hparams)

Bases: Module

TERM Graph Transformer Encoder

Alternates node and edge update layers to refine the representation of TERM graphs.

Variables:
  • W_v (nn.Linear) – Embedding layer for nodes

  • W_e (nn.Linear) – Embedding layer for edges

  • node_encoder (nn.ModuleList of TERMNodeTransformerLayer or TERMNodeMPNNLayer) – Update layers for nodes

  • edge_encoder (nn.ModuleList of TERMEdgeTransformerLayer or TERMEdgeMPNNLayer) – Update layers for edges

  • W_out (nn.Linear) – Output layer
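The alternating update scheme can be sketched as a toy model. This is a minimal, hypothetical illustration, not the TERMinator implementation: `ToyTERMGraphEncoder`, `ToyEdgeLayer`, `ToyNodeLayer` and `gather_nodes` are illustrative names, small MLPs stand in for the real Transformer/MPNN layers, and the edge-then-node order within each round, the residual updates, and applying `W_out` to node embeddings only are all assumptions.

```python
import torch
import torch.nn as nn


def gather_nodes(h_V, E_idx):
    """Gather neighbor features within each TERM (hypothetical helper).

    h_V:   n_batches x n_terms x max_term_len x H
    E_idx: n_batches x n_terms x max_term_len x K
    returns n_batches x n_terms x max_term_len x K x H
    """
    B, T, L, H = h_V.shape
    K = E_idx.shape[-1]
    idx = E_idx.reshape(B, T, L * K, 1).expand(-1, -1, -1, H)
    return torch.gather(h_V, 2, idx).reshape(B, T, L, K, H)


class ToyEdgeLayer(nn.Module):
    """Assumed stand-in for TERMEdgeTransformerLayer / TERMEdgeMPNNLayer."""

    def __init__(self, hidden):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(3 * hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))

    def forward(self, h_E, h_V, E_idx, mask_attend):
        h_i = h_V.unsqueeze(3).expand_as(h_E)  # features of the source residue
        h_j = gather_nodes(h_V, E_idx)         # features of each neighbor
        update = self.mlp(torch.cat([h_i, h_j, h_E], dim=-1))
        return h_E + update * mask_attend.unsqueeze(-1)  # residual edge update


class ToyNodeLayer(nn.Module):
    """Assumed stand-in for TERMNodeTransformerLayer / TERMNodeMPNNLayer."""

    def __init__(self, hidden):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))

    def forward(self, h_V, h_E, E_idx, mask, mask_attend):
        h_j = gather_nodes(h_V, E_idx)
        msgs = self.mlp(torch.cat([h_j, h_E], dim=-1)) * mask_attend.unsqueeze(-1)
        n_valid = mask_attend.sum(-1, keepdim=True).clamp(min=1)  # avoid dividing by zero on padding
        h_V = h_V + msgs.sum(dim=3) / n_valid  # mean message over valid neighbors
        return h_V * mask.unsqueeze(-1)        # zero out padded residues


class ToyTERMGraphEncoder(nn.Module):
    def __init__(self, node_in, edge_in, hidden, out, num_layers):
        super().__init__()
        self.W_v = nn.Linear(node_in, hidden)
        self.W_e = nn.Linear(edge_in, hidden)
        self.node_encoder = nn.ModuleList([ToyNodeLayer(hidden) for _ in range(num_layers)])
        self.edge_encoder = nn.ModuleList([ToyEdgeLayer(hidden) for _ in range(num_layers)])
        self.W_out = nn.Linear(hidden, out)

    def forward(self, V, E, E_idx, mask):
        mask = mask.float()
        # an edge is valid only if both of its endpoints are real residues
        mask_attend = mask.unsqueeze(-1) * gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        h_V, h_E = self.W_v(V), self.W_e(E)
        for edge_layer, node_layer in zip(self.edge_encoder, self.node_encoder):
            h_E = edge_layer(h_E, h_V, E_idx, mask_attend)        # edges see current nodes
            h_V = node_layer(h_V, h_E, E_idx, mask, mask_attend)  # nodes see updated edges
        return self.W_out(h_V) * mask.unsqueeze(-1), h_E


B, T, L, H = 2, 3, 4, 8
V = torch.randn(B, T, L, H)
E = torch.randn(B, T, L, L, H)
E_idx = torch.arange(L).expand(B, T, L, L)  # assumed dense layout: slot j -> residue j
mask = torch.ones(B, T, L)
mask[:, :, -1] = 0  # pretend the last slot of every TERM is padding
h_V, h_E = ToyTERMGraphEncoder(H, H, hidden=16, out=5, num_layers=2)(V, E, E_idx, mask)
print(h_V.shape, h_E.shape)  # torch.Size([2, 3, 4, 5]) torch.Size([2, 3, 4, 4, 16])
```

Pairing one edge layer with one node layer per round lets each node update see edges that already reflect the latest node state, which is one reasonable reading of "alternating" updates.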

__init__(hparams)
Parameters:

hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)

Methods

__init__(hparams)

forward(V, E, E_idx, mask[, contact_idx])

Refine TERM graph representations

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

forward(V, E, E_idx, mask, contact_idx=None)

Refine TERM graph representations

Parameters:
  • V (torch.Tensor) – Node embeddings. Shape: n_batches x n_terms x max_term_len x n_hidden

  • E (torch.Tensor) – Edge embeddings in kNN dense form. Shape: n_batches x n_terms x max_term_len x max_term_len x n_hidden

  • E_idx (torch.LongTensor) – Edge indices. Shape: n_batches x n_terms x max_term_len x max_term_len

  • mask (torch.ByteTensor) – Mask for TERM residues. Shape: n_batches x n_terms x max_term_len

  • contact_idx (torch.Tensor, optional) – Embedded contact indices. Shape: n_batches x n_terms x max_term_len x n_hidden. Defaults to None.
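The shape conventions above can be illustrated by building dummy inputs. This sketch assumes a dense layout in which neighbor slot `j` of residue `i` points at residue `j` of the same TERM (so `E_idx[..., i, j] == j`); `make_dummy_term_inputs` and `edge_mask` are hypothetical helpers, not part of TERMinator. Padding slots, including whole padding TERMs, are marked with 0 in `mask`.

```python
import torch


def make_dummy_term_inputs(term_lens, n_hidden):
    """term_lens[b][t] is the length of TERM t in batch element b (hypothetical helper)."""
    n_batches = len(term_lens)
    n_terms = max(len(terms) for terms in term_lens)
    max_term_len = max(length for terms in term_lens for length in terms)

    # 1 for real residues, 0 for padding (ByteTensor-style mask)
    mask = torch.zeros(n_batches, n_terms, max_term_len, dtype=torch.uint8)
    for b, terms in enumerate(term_lens):
        for t, length in enumerate(terms):
            mask[b, t, :length] = 1

    V = torch.randn(n_batches, n_terms, max_term_len, n_hidden)
    E = torch.randn(n_batches, n_terms, max_term_len, max_term_len, n_hidden)
    # assumed dense layout: every residue lists every slot of its TERM as a neighbor
    E_idx = torch.arange(max_term_len).expand(n_batches, n_terms, max_term_len, max_term_len).clone()
    return V, E, E_idx, mask


def edge_mask(mask, E_idx):
    """1 where both endpoints of an edge are real residues (hypothetical helper)."""
    m = mask.float()
    B, T, L = m.shape
    # look up the mask value of each neighbor through E_idx
    neighbor = torch.gather(m, 2, E_idx.reshape(B, T, -1)).reshape(E_idx.shape)
    return m.unsqueeze(-1) * neighbor


V, E, E_idx, mask = make_dummy_term_inputs([[3, 2], [4]], n_hidden=8)
mask_attend = edge_mask(mask, E_idx)
print(V.shape, E.shape, E_idx.shape, mask.shape)
```

Under this assumed layout `mask_attend[b, t]` is just the outer product of `mask[b, t]` with itself; the `gather` form also works if `E_idx` holds a sparser neighbor list.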

Returns:

  • h_V (torch.Tensor) – Refined TERM node embeddings

  • h_E (torch.Tensor) – Refined TERM edge embeddings