terminator.models.layers.term.graph.s2s.TERMNodeTransformerLayer
- class terminator.models.layers.term.graph.s2s.TERMNodeTransformerLayer(num_hidden, num_in, num_heads=4, dropout=0.1)
Bases: Module
TERM Node Transformer Layer
A TERM Node Transformer Layer that updates nodes via TERMNeighborAttention
- Variables:
attention (TERMNeighborAttention) – Transformer Attention mechanism
dense (PositionWiseFeedForward) – Transformer position-wise FFN
- __init__(num_hidden, num_in, num_heads=4, dropout=0.1)
- Parameters:
num_hidden (int) – Hidden dimension, and dimensionality of queries in TERMNeighborAttention
num_in (int) – Dimensionality of keys and values
num_heads (int, default=4) – Number of heads to use in TERMNeighborAttention
dropout (float, default=0.1) – Dropout rate
Methods
__init__(num_hidden, num_in[, num_heads, ...])
forward(h_V, h_VE[, mask_V, mask_attend]) – Apply one Transformer update on nodes in a TERM graph
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- forward(h_V, h_VE, mask_V=None, mask_attend=None)
Apply one Transformer update on nodes in a TERM graph
- Parameters:
h_V (torch.Tensor) – Central node features. Shape: n_batch x n_terms x n_nodes x n_hidden
h_VE (torch.Tensor) – Neighbor features, which include the node vector concatenated onto the edge connecting the central node to the neighbor node. Shape: n_batch x n_terms x n_nodes x n_neighbors x n_in
mask_V (torch.ByteTensor or None) – Mask for attention regarding TERM residues. Shape: n_batch x n_terms x n_nodes
mask_attend (torch.ByteTensor or None) – Mask for attention regarding neighbors. Shape: n_batch x n_terms x n_nodes x k
- Returns:
h_V – Updated node embeddings. Shape: n_batch x n_terms x n_nodes x n_hidden
- Return type:
torch.Tensor
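The shape contract documented for forward can be checked before calling the layer. The following is a minimal sketch in plain Python, not part of terminator: the helper names expected_forward_shapes and check_forward_inputs are hypothetical, and it assumes that k in the mask_attend shape equals n_neighbors in the h_VE shape, which the documentation above does not state explicitly.

```python
from typing import Dict, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def expected_forward_shapes(n_batch: int, n_terms: int, n_nodes: int,
                            n_neighbors: int, num_hidden: int,
                            num_in: int) -> Dict[str, Shape]:
    """Return the documented shapes for forward() inputs and output.

    Hypothetical helper. Assumes k (mask_attend) == n_neighbors (h_VE).
    """
    node = (n_batch, n_terms, n_nodes)
    return {
        "h_V": node + (num_hidden,),
        "h_VE": node + (n_neighbors, num_in),
        "mask_V": node,
        "mask_attend": node + (n_neighbors,),
        "output": node + (num_hidden,),
    }


def check_forward_inputs(h_V_shape: Sequence[int],
                         h_VE_shape: Sequence[int],
                         mask_V_shape: Optional[Sequence[int]] = None,
                         mask_attend_shape: Optional[Sequence[int]] = None
                         ) -> Shape:
    """Validate input shapes and return the expected output shape."""
    h_V_shape, h_VE_shape = tuple(h_V_shape), tuple(h_VE_shape)
    if len(h_V_shape) != 4:
        raise ValueError("h_V must have 4 dimensions")
    if len(h_VE_shape) != 5:
        raise ValueError("h_VE must have 5 dimensions")
    # Batch, TERM, and node axes must agree between nodes and neighbors.
    if h_VE_shape[:3] != h_V_shape[:3]:
        raise ValueError("h_V and h_VE disagree on the leading 3 axes")
    if mask_V_shape is not None and tuple(mask_V_shape) != h_V_shape[:3]:
        raise ValueError("mask_V must match n_batch x n_terms x n_nodes")
    # Assumption: the last mask_attend axis (k) equals n_neighbors.
    if (mask_attend_shape is not None
            and tuple(mask_attend_shape) != h_VE_shape[:4]):
        raise ValueError("mask_attend must match the first 4 axes of h_VE")
    # The update returns node embeddings with the same shape as h_V.
    return h_V_shape
```

With PyTorch tensors, the shapes could be passed as tuple(h_V.shape), tuple(h_VE.shape), and so on.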