terminator.models.layers.s2s_modules.EdgeEndpointAttention

class terminator.models.layers.s2s_modules.EdgeEndpointAttention(num_hidden, num_in, num_heads=4)

Bases: Module

__init__(num_hidden, num_in, num_heads=4)

Initializes internal Module state, shared by both nn.Module and ScriptModule.
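The constructor takes a hidden size (num_hidden), an input size (num_in) and a head count (num_heads, default 4). In standard multi-head attention the hidden channels are split evenly across the heads, so each head works on num_hidden // num_heads channels. The sketch below illustrates that split on a tensor shaped like h_E. It assumes, without confirmation from this page, that EdgeEndpointAttention follows this convention, and split_heads is a hypothetical helper, not part of the TERMinator API.

```python
import torch


def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Hypothetical helper: reshape [..., num_hidden] -> [..., num_heads, head_dim].

    Assumes num_hidden is divisible by num_heads, as in standard
    multi-head attention.
    """
    num_hidden = x.shape[-1]
    if num_hidden % num_heads != 0:
        raise ValueError(
            f"num_hidden={num_hidden} is not divisible by num_heads={num_heads}"
        )
    head_dim = num_hidden // num_heads
    # Keep all leading dimensions, split only the channel dimension.
    return x.reshape(*x.shape[:-1], num_heads, head_dim)


# Edge features shaped [N_batch, N_nodes, K, N_hidden], with illustrative sizes.
h_E = torch.randn(2, 5, 3, 16)
heads = split_heads(h_E, num_heads=4)
print(heads.shape)  # torch.Size([2, 5, 3, 4, 4])
```

Under this assumption, choosing num_hidden = 16 with the default num_heads = 4 gives four heads of four channels each; a num_hidden that is not a multiple of num_heads would be rejected.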

Methods

__init__(num_hidden, num_in[, num_heads])

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(h_E, h_EV, E_idx[, mask_attend])

Self-attention, graph-structured O(Nk)

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

_masked_softmax(attend_logits, mask_attend, dim=-1)

Numerically stable masked softmax
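A common way to build a numerically stable masked softmax is to replace masked-out logits with a very large negative finite value before the softmax, so they receive zero weight, and then multiply by the mask so that a row whose entries are all masked comes out as zeros rather than NaN or a uniform distribution. The sketch below shows that approach, assuming mask_attend holds 1 for valid entries and 0 for padding. It is an illustration of the technique, not necessarily the exact TERMinator implementation, and masked_softmax is a hypothetical stand-alone name.

```python
import torch
import torch.nn.functional as F


def masked_softmax(attend_logits: torch.Tensor,
                   mask_attend: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    """Softmax over `dim` that ignores positions where mask_attend == 0."""
    # Use the most negative finite value rather than -inf: a row that is
    # entirely masked then yields a uniform softmax instead of NaN.
    very_negative = torch.finfo(attend_logits.dtype).min
    logits = attend_logits.masked_fill(mask_attend == 0, very_negative)
    # F.softmax subtracts the per-slice maximum internally, so it stays
    # stable even for large logits.
    attend = F.softmax(logits, dim=dim)
    # Zero out masked positions (this also turns fully masked rows into zeros).
    return attend * mask_attend.to(attend.dtype)


logits = torch.tensor([[1.0, 2.0, 3.0], [5.0, 5.0, 5.0]])
mask = torch.tensor([[1, 1, 0], [0, 0, 0]])
out = masked_softmax(logits, mask)
# First row is approximately [0.2689, 0.7311, 0.0]; the fully masked
# second row is all zeros.
print(out)
```

Multiplying by the mask after the softmax is the design choice that matters here: without it, a fully masked row would still distribute its weight uniformly over padding.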

forward(h_E, h_EV, E_idx, mask_attend=None)

Self-attention, graph-structured O(Nk)

Args:

    h_E: Edge features [N_batch, N_nodes, K, N_hidden]
    h_EV: Edge + endpoint features [N_batch, N_nodes, K, N_hidden * 3]
    mask_attend: Mask for attention [N_batch, N_nodes, K]

Returns:

    h_E_update: Edge update
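The docstring gives the tensor shapes but not the attention pattern itself. As a hedged illustration of a graph-structured attention whose cost is linear in the number of edges, the sketch below draws queries from h_E and keys and values from h_EV, computes one scaled dot-product score per edge and head, normalizes those scores across the K neighbors of each node with a masked softmax, and scales each edge's value by its weight. Everything here is an assumption rather than the TERMinator code: the class name EdgeEndpointAttentionSketch is hypothetical, num_in is assumed to equal the last dimension of h_EV (N_hidden * 3 in the docstring), the query/key/value/output projections are a standard multi-head choice, and E_idx is accepted only to match the documented signature and is unused.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EdgeEndpointAttentionSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented signature; not the TERMinator code."""

    def __init__(self, num_hidden: int, num_in: int, num_heads: int = 4):
        super().__init__()
        assert num_hidden % num_heads == 0, "num_hidden must be divisible by num_heads"
        self.num_hidden = num_hidden
        self.num_heads = num_heads
        # Queries come from edge features; keys and values from edge + endpoint features.
        self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
        self.W_K = nn.Linear(num_in, num_hidden, bias=False)
        self.W_V = nn.Linear(num_in, num_hidden, bias=False)
        self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)

    def forward(self, h_E, h_EV, E_idx, mask_attend=None):
        # E_idx is accepted to match the documented signature but unused here.
        B, N, K, _ = h_E.shape
        H = self.num_heads
        d = self.num_hidden // H
        Q = self.W_Q(h_E).view(B, N, K, H, d)
        Kt = self.W_K(h_EV).view(B, N, K, H, d)
        V = self.W_V(h_EV).view(B, N, K, H, d)
        # One scaled dot-product score per edge and head: [B, N, K, H].
        logits = (Q * Kt).sum(-1) / d ** 0.5
        if mask_attend is not None:
            mask = mask_attend.unsqueeze(-1)  # [B, N, K, 1], broadcast over heads
            logits = logits.masked_fill(mask == 0, torch.finfo(logits.dtype).min)
            # Normalize over the K neighbors (dim=2), then zero masked edges.
            attend = F.softmax(logits, dim=2) * mask.to(logits.dtype)
        else:
            attend = F.softmax(logits, dim=2)
        # Weight each edge's value by its attention weight, then merge heads.
        h_E_update = (attend.unsqueeze(-1) * V).reshape(B, N, K, self.num_hidden)
        return self.W_O(h_E_update)


B, N, K, hidden = 2, 6, 4, 16
layer = EdgeEndpointAttentionSketch(num_hidden=hidden, num_in=3 * hidden)
h_E = torch.randn(B, N, K, hidden)
h_EV = torch.randn(B, N, K, 3 * hidden)
E_idx = torch.randint(0, N, (B, N, K))  # illustrative neighbor indices
mask_attend = torch.ones(B, N, K)
out = layer(h_E, h_EV, E_idx, mask_attend)
print(out.shape)  # torch.Size([2, 6, 4, 16])
```

In this sketch every edge computes a single score per head and the softmax only runs across the K neighbors of its node, so the work grows as O(N·K), in line with the O(Nk) label in the docstring. Because W_O has no bias, edges with mask_attend == 0 receive an all-zero update.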