terminator.models.layers.term.graph.s2s.TERMNodeMPNNLayer
- class terminator.models.layers.term.graph.s2s.TERMNodeMPNNLayer(num_hidden, num_in, dropout=0.1, num_heads=None, scale=None)
Bases: torch.nn.Module
TERM Node MPNN Layer
A TERM Node MPNN layer that updates node embeddings by generating messages from each node's neighbors, aggregating them, and passing the update through a feed-forward network.
- Variables:
W1, W2, W3 – Layers for message computation
dense (PositionWiseFeedForward) – Transformer position-wise FFN
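The page lists the submodules but not how they are wired, so the following is a minimal sketch under stated assumptions, not the terminator implementation. It assumes a struct2seq-style MPNN: each central node is concatenated with every neighbor feature, W1, W2 and W3 form a ReLU MLP that produces one message per neighbor, messages are reduced over the neighbor axis (sum divided by scale, or mean), and two residual updates with layer normalization wrap the aggregation and the position-wise FFN. The class names TERMNodeMPNNSketch and PositionWiseFeedForwardSketch, the ReLU activations, nn.LayerNorm, and the 4x FFN width are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionWiseFeedForwardSketch(nn.Module):
    """Hypothetical stand-in for PositionWiseFeedForward: Linear -> ReLU -> Linear."""

    def __init__(self, num_hidden: int, num_ff: int):
        super().__init__()
        self.W_in = nn.Linear(num_hidden, num_ff)
        self.W_out = nn.Linear(num_ff, num_hidden)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.W_out(F.relu(self.W_in(h)))


class TERMNodeMPNNSketch(nn.Module):
    """Sketch of a node MPNN layer (assumed wiring, not the terminator code)."""

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=None):
        super().__init__()
        # num_heads is accepted for signature parity; this sketch does not use it
        self.scale = scale  # None -> mean over neighbors, else sum / scale
        self.dropout = nn.Dropout(dropout)
        # Assumption: layer normalization after each residual update
        self.norm = nn.ModuleList([nn.LayerNorm(num_hidden) for _ in range(2)])
        # W1, W2, W3: message MLP over [central node | neighbor feature]
        self.W1 = nn.Linear(num_hidden + num_in, num_hidden)
        self.W2 = nn.Linear(num_hidden, num_hidden)
        self.W3 = nn.Linear(num_hidden, num_hidden)
        self.dense = PositionWiseFeedForwardSketch(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        # h_V: (B, T, N, H); h_E: (B, T, N, K, C)
        k = h_E.size(-2)
        h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, -1, k, -1)
        h_VE = torch.cat([h_V_expand, h_E], dim=-1)  # (B, T, N, K, H + C)
        h_message = self.W3(F.relu(self.W2(F.relu(self.W1(h_VE)))))
        if mask_attend is not None:
            # Zero out messages coming from masked neighbors
            h_message = mask_attend.unsqueeze(-1).to(h_message.dtype) * h_message
        if self.scale is None:
            dh = h_message.mean(dim=-2)
        else:
            dh = h_message.sum(dim=-2) / self.scale
        h_V = self.norm[0](h_V + self.dropout(dh))
        h_V = self.norm[1](h_V + self.dropout(self.dense(h_V)))
        if mask_V is not None:
            # Zero out padded TERM residues
            h_V = mask_V.unsqueeze(-1).to(h_V.dtype) * h_V
        return h_V


# Usage: batch of 2, 3 TERMs, 5 residues per TERM, 4 neighbors each
torch.manual_seed(0)
layer = TERMNodeMPNNSketch(num_hidden=16, num_in=8, scale=30)
layer.eval()  # disable dropout for a deterministic pass
h_V = torch.randn(2, 3, 5, 16)
h_E = torch.randn(2, 3, 5, 4, 8)
mask_V = torch.ones(2, 3, 5, dtype=torch.uint8)
mask_V[:, :, -1] = 0  # treat the last residue of every TERM as padding
out = layer(h_V, h_E)
out_masked = layer(h_V, h_E, mask_V=mask_V)
```

The output keeps the shape of h_V, which is what lets such layers be stacked.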
- __init__(num_hidden, num_in, dropout=0.1, num_heads=None, scale=None)
- Parameters:
num_hidden (int) – Hidden dimension, and dimensionality of queries in TERMNeighborAttention
num_in (int) – Dimensionality of keys and values
num_heads (int or None, default=None) – Number of heads to use in TERMNeighborAttention
dropout (float, default=0.1) – Dropout rate
scale (int or None, default=None) – Scaling integer by which to divide the sum of the computed messages. If None, the mean of the messages is used instead.
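The effect of scale can be shown in isolation. This is a sketch assuming messages are reduced over the neighbor axis (dim -2) of an n_batch x n_terms x n_nodes x n_neighbors x n_hidden tensor; aggregate_messages is a hypothetical helper, not part of terminator.

```python
import torch


def aggregate_messages(h_message: torch.Tensor, scale=None) -> torch.Tensor:
    """Reduce messages over the neighbor axis (dim -2).

    scale=None -> mean of the messages; otherwise sum divided by scale.
    """
    if scale is None:
        return h_message.mean(dim=-2)
    return h_message.sum(dim=-2) / scale


# 4 neighbors, each sending a message of all ones (hidden size 2)
msgs = torch.ones(1, 1, 1, 4, 2)
summed = aggregate_messages(msgs, scale=2)  # every entry: 4 / 2 = 2.0
averaged = aggregate_messages(msgs)  # every entry: 4 / 4 = 1.0
```

With a fixed scale, the update grows with the number of neighbors that send nonzero messages; the mean instead divides by the full length of the neighbor axis, so the update magnitude stays roughly independent of neighbor count.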
Methods
__init__(num_hidden, num_in[, dropout, ...])
forward(h_V, h_E[, mask_V, mask_attend]) – Apply one MPNN update on nodes in a TERM graph
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- forward(h_V, h_E, mask_V=None, mask_attend=None)
Apply one MPNN update on nodes in a TERM graph
- Parameters:
h_V (torch.Tensor) – Central node features. Shape: n_batch x n_terms x n_nodes x n_hidden
h_E (torch.Tensor) – Neighbor features: the neighbor node's features concatenated onto the edge connecting the central node to that neighbor. Shape: n_batch x n_terms x n_nodes x n_neighbors x n_in
mask_V (torch.ByteTensor or None) – Mask for message passing over TERM residues. Shape: n_batch x n_terms x n_nodes
mask_attend (torch.ByteTensor or None) – Mask for message passing over neighbors. Shape: n_batch x n_terms x n_nodes x n_neighbors
- Returns:
h_V – Updated node embeddings. Shape: n_batch x n_terms x n_nodes x n_hidden
- Return type:
torch.Tensor
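Because forward() takes up to four tensors whose leading dimensions must agree, a shape check is a cheap way to catch wiring mistakes before calling the layer. check_term_mpnn_inputs is a hypothetical helper, not part of terminator; it encodes the shapes documented above and assumes that the neighbor dimension of mask_attend matches the neighbor dimension of the neighbor-feature argument h_E.

```python
import torch


def check_term_mpnn_inputs(h_V, h_E, mask_V=None, mask_attend=None):
    """Validate documented forward() shapes (hypothetical helper).

    Returns (n_batch, n_terms, n_nodes, n_neighbors).
    """
    if h_V.dim() != 4:
        raise ValueError(f"h_V must be n_batch x n_terms x n_nodes x n_hidden, got {tuple(h_V.shape)}")
    if h_E.dim() != 5:
        raise ValueError(f"h_E must be n_batch x n_terms x n_nodes x n_neighbors x n_in, got {tuple(h_E.shape)}")
    n_batch, n_terms, n_nodes, _ = h_V.shape
    if tuple(h_E.shape[:3]) != (n_batch, n_terms, n_nodes):
        raise ValueError("h_E must share n_batch, n_terms and n_nodes with h_V")
    n_neighbors = h_E.shape[3]
    if mask_V is not None and tuple(mask_V.shape) != (n_batch, n_terms, n_nodes):
        raise ValueError("mask_V must be n_batch x n_terms x n_nodes")
    if mask_attend is not None and tuple(mask_attend.shape) != (n_batch, n_terms, n_nodes, n_neighbors):
        raise ValueError("mask_attend must be n_batch x n_terms x n_nodes x n_neighbors")
    return n_batch, n_terms, n_nodes, n_neighbors


# Batch of 2, 3 TERMs, 5 residues per TERM, 4 neighbors per residue
h_V = torch.zeros(2, 3, 5, 16)
h_E = torch.zeros(2, 3, 5, 4, 8)
mask_V = torch.ones(2, 3, 5, dtype=torch.uint8)
mask_attend = torch.ones(2, 3, 5, 4, dtype=torch.uint8)
dims = check_term_mpnn_inputs(h_V, h_E, mask_V, mask_attend)  # (2, 3, 5, 4)
```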