terminator.models.layers.s2s_modules.NeighborAttention

class terminator.models.layers.s2s_modules.NeighborAttention(num_hidden, num_in, num_heads=4)

Bases: Module

__init__(num_hidden, num_in, num_heads=4)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Methods

__init__(num_hidden, num_in[, num_heads])

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(h_V, h_E[, mask_attend])

Self-attention, graph-structured, O(Nk).

step(t, h_V, h_E, E_idx[, mask_attend])

Self-attention for a specific time step t

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

_masked_softmax(attend_logits, mask_attend, dim=-1)

Numerically stable masked softmax over dimension dim; positions where mask_attend is zero receive zero attention weight.
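A minimal sketch of a numerically stable masked softmax in PyTorch, assuming mask_attend holds 1 for valid neighbors and 0 for padding. The function name masked_softmax is illustrative and is not the module's private method. Masked logits are filled with the most negative finite value rather than -inf, then the result is multiplied by the mask.

```python
import torch
import torch.nn.functional as F


def masked_softmax(attend_logits: torch.Tensor,
                   mask_attend: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    """Softmax over `dim` that ignores positions where mask_attend == 0.

    Assumption: mask_attend is broadcastable to attend_logits and holds 0/1 values.
    """
    # Use the most negative finite value instead of -inf: a fully masked row
    # then produces finite values instead of NaN.
    neg = torch.finfo(attend_logits.dtype).min
    logits = attend_logits.masked_fill(mask_attend == 0, neg)
    # PyTorch's softmax is computed in a numerically stable (max-subtracted) way.
    attend = F.softmax(logits, dim=dim)
    # Zero masked slots explicitly; this matters when a whole row is masked.
    return attend * mask_attend
```

The design choice matters for padded neighbor lists: with -inf, a row whose neighbors are all masked would give NaN (exp of -inf minus -inf), while the finite fill followed by the mask multiplication returns an all-zero row.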

forward(h_V, h_E, mask_attend=None)

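Self-attention, graph-structured, O(Nk): each of the N nodes attends only to its K neighbors, so the cost scales with N·K rather than N² as in full self-attention.

Args:

h_V: Node features [N_batch, N_nodes, N_hidden]
h_E: Neighbor features [N_batch, N_nodes, K, N_hidden]
mask_attend: Mask for attention [N_batch, N_nodes, K] (optional, defaults to None)

Returns:

h_V: Node update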
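A minimal PyTorch sketch of graph-structured neighbor attention, assuming standard multi-head scaled dot-product attention where queries come from h_V and keys/values come from h_E. It also assumes h_E's last dimension equals num_in, as the constructor argument and the step signature suggest. The class name NeighborAttentionSketch and the projections W_Q, W_K, W_V, W_O are illustrative, not taken from the terminator source.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class NeighborAttentionSketch(nn.Module):
    """Multi-head attention from each node to its K neighbors (illustrative)."""

    def __init__(self, num_hidden: int, num_in: int, num_heads: int = 4):
        super().__init__()
        assert num_hidden % num_heads == 0, "num_hidden must split evenly across heads"
        self.num_hidden = num_hidden
        self.num_heads = num_heads
        # Hypothetical projections: queries from nodes, keys/values from neighbors.
        self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
        self.W_K = nn.Linear(num_in, num_hidden, bias=False)
        self.W_V = nn.Linear(num_in, num_hidden, bias=False)
        self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)

    def forward(self, h_V, h_E, mask_attend=None):
        # h_V: [B, N, H]; h_E: [B, N, K, num_in]; mask_attend: [B, N, K]
        B, N, K = h_E.shape[:3]
        nh = self.num_heads
        d = self.num_hidden // nh
        Q = self.W_Q(h_V).view(B, N, 1, nh, d)   # one query per node
        Kt = self.W_K(h_E).view(B, N, K, nh, d)  # one key per neighbor
        V = self.W_V(h_E).view(B, N, K, nh, d)   # one value per neighbor
        # Scaled dot product per head, rearranged to [B, N, heads, K].
        logits = (Q * Kt).sum(-1).transpose(-2, -1) / math.sqrt(d)
        if mask_attend is not None:
            mask = mask_attend.unsqueeze(2)  # [B, N, 1, K], broadcasts over heads
            logits = logits.masked_fill(mask == 0, torch.finfo(logits.dtype).min)
            attend = F.softmax(logits, dim=-1) * mask
        else:
            attend = F.softmax(logits, dim=-1)
        # Attention-weighted sum of neighbor values: [B, N, heads, d].
        out = torch.einsum("bnhk,bnkhd->bnhd", attend, V)
        return self.W_O(out.reshape(B, N, self.num_hidden))
```

Every tensor in the sketch carries an N_nodes × K axis and never an N_nodes × N_nodes one, which is where the O(Nk) cost comes from. With an all-ones mask it returns the same result as calling it without a mask.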

step(t, h_V, h_E, E_idx, mask_attend=None)

Self-attention for a specific time step t

Args:

h_V: Node features [N_batch, N_nodes, N_hidden]
h_E: Neighbor features [N_batch, N_nodes, K, N_in]
E_idx: Neighbor indices [N_batch, N_nodes, K]
mask_attend: Mask for attention [N_batch, N_nodes, K]

Returns:

h_V_t: Node update
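One way to read step, sketched under stated assumptions: it performs the same neighbor attention as forward but for the single position t only, so each step costs O(K) instead of O(Nk). Everything below is hypothetical: the shared attend function, the randomly initialized W_Q/W_K/W_V/W_O weights, and gather_neighbor_nodes. The helper only illustrates how E_idx[:, t] can select the rows of h_V belonging to t's neighbors; how the real step combines those rows with h_E is not specified in this documentation.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
NUM_HIDDEN, NUM_IN, NUM_HEADS = 16, 24, 4
# Hypothetical projection weights standing in for the layer's parameters.
W_Q = nn.Linear(NUM_HIDDEN, NUM_HIDDEN, bias=False)
W_K = nn.Linear(NUM_IN, NUM_HIDDEN, bias=False)
W_V = nn.Linear(NUM_IN, NUM_HIDDEN, bias=False)
W_O = nn.Linear(NUM_HIDDEN, NUM_HIDDEN, bias=False)


def attend(q_in, kv_in, mask):
    """Attention from each query row to its K neighbor rows.

    q_in: [..., H], kv_in: [..., K, num_in], mask: [..., K] (1 = attend).
    """
    d = NUM_HIDDEN // NUM_HEADS
    K = kv_in.shape[-2]
    lead = q_in.shape[:-1]
    Q = W_Q(q_in).view(*lead, 1, NUM_HEADS, d)
    Kt = W_K(kv_in).view(*lead, K, NUM_HEADS, d)
    V = W_V(kv_in).view(*lead, K, NUM_HEADS, d)
    logits = (Q * Kt).sum(-1) / math.sqrt(d)      # [..., K, heads]
    m = mask.unsqueeze(-1)                          # [..., K, 1]
    logits = logits.masked_fill(m == 0, torch.finfo(logits.dtype).min)
    weights = F.softmax(logits, dim=-2) * m         # normalize over the K neighbors
    out = (weights.unsqueeze(-1) * V).sum(-3)       # [..., heads, d]
    return W_O(out.reshape(*lead, NUM_HIDDEN))


def forward_all(h_V, h_E, mask):
    # Every node attends to its neighbors: O(N * K) work.
    return attend(h_V, h_E, mask)


def step_one(t, h_V, h_E, mask):
    # Only node t attends to its neighbors: O(K) work per step.
    return attend(h_V[:, t], h_E[:, t], mask[:, t])


def gather_neighbor_nodes(h_V, E_idx_t):
    """Hypothetical helper: h_V [B, N, H], E_idx_t [B, K] -> [B, K, H]."""
    idx = E_idx_t.unsqueeze(-1).expand(-1, -1, h_V.size(-1))
    return torch.gather(h_V, 1, idx)
```

Under these assumptions, step_one(t, ...) should agree with row t of forward_all(...) up to floating-point rounding, which is what makes per-step evaluation useful when positions are processed one at a time.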