terminator.models.layers.s2s_modules.EdgeEndpointAttention
class terminator.models.layers.s2s_modules.EdgeEndpointAttention(num_hidden, num_in, num_heads=4)
Bases: torch.nn.Module
__init__(num_hidden, num_in, num_heads=4)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
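The constructor arguments are not described on this page. The sketch below shows the shape contract they imply under two assumptions: num_hidden is split evenly across num_heads (the usual multi-head convention), and num_in is the feature width of h_EV (N_hidden * 3 according to the forward() docstring). It also assumes the returned edge update has the same shape as h_E. The helper `expected_shapes` is hypothetical and not part of the terminator package.

```python
def expected_shapes(num_hidden, num_in, num_heads=4, n_batch=1, n_nodes=1, k=1):
    """Hypothetical helper: tensor shapes implied by a constructor configuration."""
    # Assumption: hidden features are divided evenly among attention heads.
    if num_hidden % num_heads != 0:
        raise ValueError(
            f"num_hidden={num_hidden} is not divisible by num_heads={num_heads}"
        )
    return {
        "head_dim": num_hidden // num_heads,
        "h_E": (n_batch, n_nodes, k, num_hidden),
        # Assumption: num_in is the last dimension of h_EV (typically 3 * num_hidden).
        "h_EV": (n_batch, n_nodes, k, num_in),
        "mask_attend": (n_batch, n_nodes, k),
        # Assumption: the edge update keeps the shape of h_E.
        "h_E_update": (n_batch, n_nodes, k, num_hidden),
    }
```

For example, `expected_shapes(128, 384)` reports a head dimension of 32 and expects h_EV to have a last dimension of 384 (3 * 128).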
Methods
__init__(num_hidden, num_in[, num_heads]): Initializes internal Module state, shared by both nn.Module and ScriptModule.
forward(h_E, h_EV, E_idx[, mask_attend]): Self-attention, graph-structured O(Nk).
Attributes
T_destination: alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches: This allows better backward-compatibility (BC) support for load_state_dict().
forward(h_E, h_EV, E_idx, mask_attend=None)
Graph-structured self-attention with O(Nk) cost, where N is the number of nodes and k the number of edges (neighbors) per node, i.e. the K dimension below.
Args:
    h_E: Edge features [N_batch, N_nodes, K, N_hidden]
    h_EV: Edge + endpoint features [N_batch, N_nodes, K, N_hidden * 3]
    mask_attend: Mask for attention [N_batch, N_nodes, K]
Returns:
    h_E_update: Edge update
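The docstring names the inputs but not the computation. One plausible reading, sketched below in NumPy, is masked multi-head scaled dot-product attention in which each edge's query comes from h_E, keys and values come from h_EV, and the K edges of each node attend to one another while mask_attend excludes padded edges. This is an assumption, not the library's code: the functions `masked_softmax` and `edge_endpoint_attention_sketch` and the weight matrices W_Q, W_K, W_V, W_O are hypothetical, E_idx is ignored because its role is not documented here, and forming a K×K score matrix per node costs O(NK²) rather than the O(Nk) stated above.

```python
import numpy as np


def masked_softmax(logits, mask, axis=-1):
    """Softmax over `axis` that assigns exactly zero weight where `mask` is False."""
    # Push masked logits far below any real score before exponentiating.
    logits = np.where(mask, logits, -1e9)
    logits = logits - logits.max(axis=axis, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=axis, keepdims=True)
    # Zero masked entries so fully masked rows sum to 0 instead of 1.
    return weights * mask


def edge_endpoint_attention_sketch(h_E, h_EV, W_Q, W_K, W_V, W_O,
                                   num_heads=4, mask_attend=None):
    """Hypothetical edge-to-edge attention; NOT terminator's implementation.

    h_E:  [B, N, K, num_hidden]  edge features (queries)
    h_EV: [B, N, K, num_in]      edge + endpoint features (keys and values)
    mask_attend: [B, N, K] or None, nonzero for valid edges
    Returns an edge update of shape [B, N, K, num_hidden].
    """
    n_batch, n_nodes, k, num_hidden = h_E.shape
    if num_hidden % num_heads != 0:
        raise ValueError("num_hidden must be divisible by num_heads")
    d = num_hidden // num_heads

    def split_heads(x):
        # [B, N, K, num_hidden] -> [B, N, H, K, d]
        return x.reshape(n_batch, n_nodes, k, num_heads, d).transpose(0, 1, 3, 2, 4)

    Q = split_heads(h_E @ W_Q)
    K_ = split_heads(h_EV @ W_K)
    V = split_heads(h_EV @ W_V)

    # Scaled dot-product scores among the K edges of each node: [B, N, H, K, K]
    logits = Q @ K_.swapaxes(-1, -2) / np.sqrt(d)

    if mask_attend is None:
        mask = np.ones(logits.shape, dtype=bool)
    else:
        m = mask_attend.astype(bool)
        # A (query edge a, key edge b) pair counts only if both edges are valid.
        pair = m[:, :, :, None] & m[:, :, None, :]  # [B, N, K, K]
        mask = np.broadcast_to(pair[:, :, None], logits.shape)

    attend = masked_softmax(logits, mask)
    out = attend @ V  # [B, N, H, K, d]
    out = out.transpose(0, 1, 3, 2, 4).reshape(n_batch, n_nodes, k, num_hidden)
    return out @ W_O
```

In this sketch a masked edge receives an all-zero update, and because masked keys get exactly zero weight, changing a masked edge's h_EV features leaves the other edges' updates unchanged. Zeroing the weights after the softmax, rather than relying only on the large negative fill value, keeps fully padded rows at zero instead of a uniform average.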