terminator.models.layers.term.graph.s2s.TERMEdgeTransformerLayer
- class terminator.models.layers.term.graph.s2s.TERMEdgeTransformerLayer(num_hidden, num_in, num_heads=4, dropout=0.1)
Bases: Module
TERM Edge Transformer Layer
A TERM Edge Transformer Layer that updates edges via TERMEdgeEndpointAttention
- Variables:
attention (TERMEdgeEndpointAttention) – Transformer Attention mechanism
dense (PositionWiseFeedForward) – Transformer position-wise FFN
- __init__(num_hidden, num_in, num_heads=4, dropout=0.1)
- Parameters:
num_hidden (int) – Hidden dimension, and dimensionality of queries in TERMEdgeEndpointAttention
num_in (int) – Dimensionality of keys and values
num_heads (int, default=4) – Number of heads to use in TERMEdgeEndpointAttention
dropout (float, default=0.1) – Dropout rate
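The sketch below illustrates how num_hidden and num_heads usually relate in multi-head attention. It assumes the hidden dimension is split evenly across heads, which is the standard convention but is not stated on this page for TERMEdgeEndpointAttention. The helper per_head_dim is hypothetical and not part of terminator.

```python
def per_head_dim(num_hidden: int, num_heads: int = 4) -> int:
    """Return the width of each attention head.

    Assumption (not confirmed by this page): num_hidden is divided
    evenly across num_heads, as in standard multi-head attention.
    """
    if num_hidden <= 0 or num_heads <= 0:
        raise ValueError("num_hidden and num_heads must be positive")
    if num_hidden % num_heads != 0:
        raise ValueError(
            f"num_hidden={num_hidden} is not divisible by num_heads={num_heads}"
        )
    return num_hidden // num_heads


# With the default num_heads=4, a hidden size of 128 gives 32 features per head.
head_dim = per_head_dim(128)
print(head_dim)
```

If that assumption holds for this layer, choosing num_hidden as a multiple of num_heads avoids silently truncated head widths.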
Methods
__init__(num_hidden, num_in[, num_heads, ...])
forward(h_E, h_EV, E_idx[, mask_E, mask_attend]) – Apply one Transformer update on edges in a TERM graph
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- forward(h_E, h_EV, E_idx, mask_E=None, mask_attend=None)
Apply one Transformer update on edges in a TERM graph
- Parameters:
h_E (torch.Tensor) – Edge features in kNN dense form. Shape: n_batch x n_terms x n_nodes x k x n_hidden
h_EV (torch.Tensor) – ‘Neighbor’ edge features: all edges that share a ‘central residue’ with that edge, together with the node features of the two nodes that compose the edge. Shape: n_batch x n_terms x n_nodes x k x n_in
mask_E (torch.ByteTensor or None) – Mask for attention regarding TERM edges. Shape: n_batch x n_terms x n_nodes
mask_attend (torch.ByteTensor or None) – Mask for attention regarding ‘neighbor’ edges. Shape: n_batch x n_terms x n_nodes x k
- Returns:
h_E – Updated edge embeddings. Shape: n_batch x n_terms x n_nodes x k x n_hidden
- Return type:
torch.Tensor
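As a quick reference, the shapes documented for forward() can be tabulated with a small helper. This is a minimal sketch in plain Python. The function term_edge_layer_shapes is hypothetical and not part of terminator, and the example sizes are arbitrary. E_idx is left out because its shape is not given in the parameter list above.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def term_edge_layer_shapes(n_batch: int, n_terms: int, n_nodes: int,
                           k: int, n_hidden: int, n_in: int) -> Dict[str, Shape]:
    """Return the tensor shapes documented for one forward() call.

    Illustration only: this mirrors the parameter documentation and
    does not call into terminator.
    """
    edge_dims = (n_batch, n_terms, n_nodes, k)
    return {
        "h_E": edge_dims + (n_hidden,),        # edge features, kNN dense form
        "h_EV": edge_dims + (n_in,),           # 'neighbor' edge + node features
        "mask_E": (n_batch, n_terms, n_nodes),  # as documented for mask_E
        "mask_attend": edge_dims,              # one entry per 'neighbor' edge
        "output": edge_dims + (n_hidden,),     # updated h_E, same shape as input h_E
    }


shapes = term_edge_layer_shapes(n_batch=2, n_terms=5, n_nodes=12,
                                k=8, n_hidden=32, n_in=96)
for name, shape in shapes.items():
    print(f"{name:12s} {shape}")
```

The returned h_E has the same shape as the input h_E, so layers of this kind can be stacked without reshaping between them.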