terminator.models.layers.term.graph.s2s.TERMNodeMPNNLayer

class terminator.models.layers.term.graph.s2s.TERMNodeMPNNLayer(num_hidden, num_in, dropout=0.1, num_heads=None, scale=None)

Bases: Module

TERM Node MPNN Layer

A TERM node MPNN layer that updates node embeddings by generating a message for each neighbor, aggregating those messages, and feeding the updated nodes through a position-wise feedforward network.
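The update described above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not TERMinator's implementation. The following details are all assumptions made for the example: the weight names F1 and F2, the ReLU activations, pairing each neighbor vector with its central node vector before the message MLP, the 4x inner FFN width, and plain residual additions without dropout or normalization.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def mpnn_node_update(h_V, h_E, params, scale=None):
    """One MPNN-style node update (illustrative sketch, not TERMinator's code).

    h_V: (n_batch, n_terms, n_nodes, n_hidden) central node features
    h_E: (n_batch, n_terms, n_nodes, n_neighbors, n_in) neighbor features
    params: hypothetical weights W1, W2, W3 (messages) and F1, F2 (FFN)
    """
    # Assumption: each neighbor vector is paired with its central node vector.
    h_V_expand = np.broadcast_to(h_V[..., None, :], h_E.shape[:-1] + h_V.shape[-1:])
    h_EV = np.concatenate([h_V_expand, h_E], axis=-1)
    # 1. Generate one message per neighbor with a three-layer MLP.
    msg = relu(h_EV @ params["W1"])
    msg = relu(msg @ params["W2"])
    msg = msg @ params["W3"]
    # 2. Aggregate over the neighbor axis: mean, or sum divided by `scale`.
    dh = msg.mean(axis=-2) if scale is None else msg.sum(axis=-2) / scale
    h_V = h_V + dh  # residual update (dropout and normalization omitted)
    # 3. Feed the updated nodes through a position-wise feedforward network.
    return h_V + relu(h_V @ params["F1"]) @ params["F2"]


rng = np.random.default_rng(0)
n_batch, n_terms, n_nodes, n_neighbors, n_hidden, n_in = 2, 3, 5, 4, 8, 6
params = {
    "W1": rng.normal(scale=0.1, size=(n_hidden + n_in, n_hidden)),
    "W2": rng.normal(scale=0.1, size=(n_hidden, n_hidden)),
    "W3": rng.normal(scale=0.1, size=(n_hidden, n_hidden)),
    "F1": rng.normal(scale=0.1, size=(n_hidden, 4 * n_hidden)),
    "F2": rng.normal(scale=0.1, size=(4 * n_hidden, n_hidden)),
}
h_V = rng.normal(size=(n_batch, n_terms, n_nodes, n_hidden))
h_E = rng.normal(size=(n_batch, n_terms, n_nodes, n_neighbors, n_in))
out = mpnn_node_update(h_V, h_E, params)
print(out.shape)  # (2, 3, 5, 8)
```

The output has the same shape as h_V, which is what lets several such layers be stacked.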

Variables:
  • W1, W2, W3 – Layers for message computation

  • dense (PositionWiseFeedForward) – Transformer position-wise FFN
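"Position-wise" means the same small network is applied independently to every node vector. A minimal NumPy sketch of that property follows. The class name PositionWiseFFN, the ReLU activation, and the inner width are hypothetical choices for illustration, not details read from TERMinator's PositionWiseFeedForward.

```python
import numpy as np


class PositionWiseFFN:
    """Hypothetical stand-in for a Transformer position-wise FFN.

    The inner width and the ReLU activation are assumptions for
    illustration, not values read from TERMinator's PositionWiseFeedForward.
    """

    def __init__(self, n_hidden, n_ff, rng):
        self.W_in = rng.normal(scale=0.1, size=(n_hidden, n_ff))
        self.b_in = np.zeros(n_ff)
        self.W_out = rng.normal(scale=0.1, size=(n_ff, n_hidden))
        self.b_out = np.zeros(n_hidden)

    def __call__(self, x):
        # The matrix products act only on the last axis, so every
        # (batch, term, node) position is transformed independently
        # with the same shared weights.
        hidden = np.maximum(x @ self.W_in + self.b_in, 0.0)
        return hidden @ self.W_out + self.b_out


rng = np.random.default_rng(0)
ffn = PositionWiseFFN(n_hidden=8, n_ff=32, rng=rng)
h_V = rng.normal(size=(2, 3, 5, 8))  # n_batch x n_terms x n_nodes x n_hidden
out = ffn(h_V)
# Transforming the whole tensor equals transforming one node vector alone.
print(np.allclose(out[1, 2, 4], ffn(h_V[1, 2, 4])))  # True
```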

__init__(num_hidden, num_in, dropout=0.1, num_heads=None, scale=None)
Parameters:
  • num_hidden (int) – Hidden dimension (n_hidden), and dimensionality of queries in TERMNeighborAttention

  • num_in (int) – Dimensionality of the neighbor features (n_in); the keys and values in TERMNeighborAttention

  • dropout (float, default=0.1) – Dropout rate

  • num_heads (int or None, default=None) – Number of heads to use in TERMNeighborAttention

  • scale (int or None, default=None) – Integer by which to divide the sum of the computed messages. If None, the mean of the messages is used instead.
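The two aggregation modes selected by scale can be illustrated with a small NumPy sketch. The helper name aggregate_messages is hypothetical; it assumes messages are laid out as (..., n_neighbors, n_hidden) and reduced over the neighbor axis.

```python
import numpy as np


def aggregate_messages(messages, scale=None):
    """Reduce per-neighbor messages over the neighbor axis (-2).

    scale=None -> mean over neighbors; otherwise -> sum / scale.
    """
    if scale is None:
        return messages.mean(axis=-2)
    return messages.sum(axis=-2) / scale


# One node with 3 neighbors and 4 hidden channels.
messages = np.arange(12.0).reshape(3, 4)
print(aggregate_messages(messages))            # [4. 5. 6. 7.]
print(aggregate_messages(messages, scale=30))  # [0.4 0.5 0.6 0.7]
```

When scale equals the number of neighbors, the two modes coincide. A fixed scale lets the magnitude of the update grow with the number of neighbors that contribute, whereas the mean keeps it independent of that count.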

Methods

__init__(num_hidden, num_in[, dropout, ...])

forward(h_V, h_E[, mask_V, mask_attend])

Apply one MPNN update on nodes in a TERM graph

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

forward(h_V, h_E, mask_V=None, mask_attend=None)

Apply one MPNN update on nodes in a TERM graph

Parameters:
  • h_V (torch.Tensor) – Central node features. Shape: n_batch x n_terms x n_nodes x n_hidden

  • h_E (torch.Tensor) – Neighbor features: for each central node, the neighbor node vector concatenated onto the edge connecting the central node to that neighbor. Shape: n_batch x n_terms x n_nodes x n_neighbors x n_in

  • mask_V (torch.ByteTensor or None) – Mask for message-passing regarding TERM residues. Shape: n_batch x n_terms x n_nodes

  • mask_attend (torch.ByteTensor or None) – Mask for message-passing regarding neighbors. Shape: n_batch x n_terms x n_nodes x n_neighbors
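The following NumPy sketch shows how masks with the documented shapes broadcast against the messages and node embeddings. It assumes the masks are applied multiplicatively (1 = real, 0 = padding), which is a common convention rather than something stated on this page. The helper name masked_node_update is hypothetical, and uint8 arrays stand in for torch.ByteTensor.

```python
import numpy as np


def masked_node_update(h_V, messages, mask_V=None, mask_attend=None):
    """Hypothetical helper showing how the documented mask shapes broadcast.

    h_V:         (n_batch, n_terms, n_nodes, n_hidden)
    messages:    (n_batch, n_terms, n_nodes, n_neighbors, n_hidden)
    mask_attend: (n_batch, n_terms, n_nodes, n_neighbors), 1 = real neighbor
    mask_V:      (n_batch, n_terms, n_nodes), 1 = real TERM residue
    """
    if mask_attend is not None:
        # Zero the messages from padded neighbors before aggregation.
        messages = messages * mask_attend[..., None]
    h_V = h_V + messages.mean(axis=-2)
    if mask_V is not None:
        # Zero the embeddings of padded residues.
        h_V = h_V * mask_V[..., None]
    return h_V


rng = np.random.default_rng(0)
h_V = rng.normal(size=(1, 2, 3, 4))
messages = rng.normal(size=(1, 2, 3, 5, 4))
mask_V = np.ones((1, 2, 3), dtype=np.uint8)
mask_V[..., -1] = 0  # last residue of each TERM is padding
mask_attend = np.ones((1, 2, 3, 5), dtype=np.uint8)
mask_attend[..., -1] = 0  # last neighbor slot is padding
out = masked_node_update(h_V, messages, mask_V, mask_attend)
print(np.abs(out[:, :, -1]).max())  # 0.0
```

In this sketch the mean still divides by the full neighbor count, padded slots included. Whether TERMinator normalizes the same way is not stated here.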

Returns:

h_V – Updated node embeddings. Shape: n_batch x n_terms x n_nodes x n_hidden

Return type:

torch.Tensor