terminator.models.layers.term.graph.s2s.TERMNeighborAttention
- class terminator.models.layers.term.graph.s2s.TERMNeighborAttention(num_hidden, num_in, num_heads=4)
Bases: Module
TERM Neighbor Attention
A module which computes a node update using self-attention over all neighboring TERM residues and the edges connecting them.
- Variables:
W_Q (nn.Linear) – Projection matrix for queries
W_K (nn.Linear) – Projection matrix for keys
W_V (nn.Linear) – Projection matrix for values
W_O (nn.Linear) – Output layer
- __init__(num_hidden, num_in, num_heads=4)
- Parameters:
num_hidden (int) – Hidden dimension, and dimensionality of queries
num_in (int) – Dimensionality of keys and values
num_heads (int, default=4) – Number of heads to use in Attention
Methods
__init__(num_hidden, num_in[, num_heads]) – Initialize the module
forward(h_V, h_EV[, mask_attend]) – Self-attention update over nodes of a TERM graph
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- _masked_softmax(attend_logits, mask_attend, dim=-1)
Numerically stable masked softmax
- Parameters:
attend_logits (torch.Tensor) – Attention logits
mask_attend (torch.ByteTensor) – Mask on Attention logits
dim (int, default=-1) – Dimension to perform softmax along
- Returns:
attend – Softmaxed attend_logits
- Return type:
torch.Tensor
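One way to read "numerically stable masked softmax" is sketched below. This is an illustration, not the library's source: the function name `masked_softmax_sketch` is hypothetical, and the approach (fill masked logits with the most negative finite value, apply softmax, then zero the masked entries) is an assumption about how such a helper is commonly written.

```python
import torch
import torch.nn.functional as F


def masked_softmax_sketch(attend_logits, mask_attend, dim=-1):
    """Hypothetical stand-in for a masked softmax; not the library's code."""
    mask = mask_attend.to(torch.bool)
    # Use the most negative finite value rather than -inf so that a row with
    # every position masked does not produce NaN after the softmax.
    neg = torch.finfo(attend_logits.dtype).min
    filled = attend_logits.masked_fill(~mask, neg)
    attend = F.softmax(filled, dim=dim)
    # Zero the masked positions explicitly; this also clears fully masked rows,
    # which would otherwise come out uniform.
    return attend * mask.to(attend.dtype)


logits = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
mask = torch.tensor([[1, 1, 0], [0, 0, 0]], dtype=torch.uint8)
attend = masked_softmax_sketch(logits, mask)
```

In this sketch the first row distributes all of its weight over the two unmasked entries, and the fully masked second row comes out as all zeros instead of NaN.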
- forward(h_V, h_EV, mask_attend=None)
Self-attention update over nodes of a TERM graph
- Parameters:
h_V (torch.Tensor) – Central node features. Shape: n_batch x n_terms x n_nodes x n_hidden
h_EV (torch.Tensor) – Neighbor features: each neighbor's node vector concatenated onto the features of the edge connecting the central node to that neighbor. Shape: n_batch x n_terms x n_nodes x n_neighbors x n_in
mask_attend (torch.ByteTensor or None) – Attention mask over neighbors. Shape: n_batch x n_terms x n_nodes x k
- Returns:
h_V_update – Node embedding update. Shape: n_batch x n_terms x n_nodes x n_hidden
- Return type:
torch.Tensor
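The shapes documented for forward can be traced with a minimal multi-head neighbor attention sketch. Everything here is an assumption made for illustration: the class name `NeighborAttentionSketch` is hypothetical, the scaled dot-product form and bias-free projections are a common choice rather than the confirmed implementation, and k in the mask shape is taken to equal n_neighbors.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class NeighborAttentionSketch(nn.Module):
    """Hypothetical illustration of attention over neighbors; not the library's code."""

    def __init__(self, num_hidden, num_in, num_heads=4):
        super().__init__()
        assert num_hidden % num_heads == 0, "num_hidden must divide evenly into heads"
        self.num_hidden = num_hidden
        self.num_heads = num_heads
        self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)  # queries from central nodes
        self.W_K = nn.Linear(num_in, num_hidden, bias=False)      # keys from neighbor features
        self.W_V = nn.Linear(num_in, num_hidden, bias=False)      # values from neighbor features
        self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)  # output layer

    def forward(self, h_V, h_EV, mask_attend=None):
        n_batch, n_terms, n_nodes, n_neighbors = h_EV.shape[:4]
        d = self.num_hidden // self.num_heads

        # Split projections into heads; queries get a singleton neighbor axis.
        Q = self.W_Q(h_V).reshape(n_batch, n_terms, n_nodes, 1, self.num_heads, d)
        K = self.W_K(h_EV).reshape(n_batch, n_terms, n_nodes, n_neighbors, self.num_heads, d)
        V = self.W_V(h_EV).reshape(n_batch, n_terms, n_nodes, n_neighbors, self.num_heads, d)

        # Scaled dot product per head: (batch, terms, nodes, neighbors, heads).
        logits = (Q * K).sum(dim=-1) / math.sqrt(d)

        if mask_attend is not None:
            mask = mask_attend.to(torch.bool).unsqueeze(-1)  # broadcast over heads
            logits = logits.masked_fill(~mask, torch.finfo(logits.dtype).min)
            attend = F.softmax(logits, dim=3) * mask.to(logits.dtype)
        else:
            attend = F.softmax(logits, dim=3)

        # Weighted sum of values over the neighbor axis, then merge the heads.
        h = (attend.unsqueeze(-1) * V).sum(dim=3)
        return self.W_O(h.reshape(n_batch, n_terms, n_nodes, self.num_hidden))


torch.manual_seed(0)
layer = NeighborAttentionSketch(num_hidden=8, num_in=12, num_heads=4)
h_V = torch.randn(2, 3, 5, 8)       # n_batch x n_terms x n_nodes x n_hidden
h_EV = torch.randn(2, 3, 5, 4, 12)  # n_batch x n_terms x n_nodes x n_neighbors x n_in
mask = torch.ones(2, 3, 5, 4, dtype=torch.uint8)
mask[..., -1] = 0                   # hide the last neighbor of every node
h_V_update = layer(h_V, h_EV, mask)
```

The returned tensor has the same shape as h_V, and in this sketch changing the features of a masked neighbor leaves the output unchanged.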