terminator.models.layers.term.graph.s2s.TERMEdgeEndpointAttention

class terminator.models.layers.term.graph.s2s.TERMEdgeEndpointAttention(num_hidden, num_in, num_heads=4)

Bases: Module

TERM Edge Endpoint Attention

A module that computes an edge update using self-attention over all edges that share a ‘home residue’ with the edge being updated, as well as over the nodes that form those edges.
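To make "edges that share a home residue" concrete, the sketch below groups edges by the row they occupy in a kNN neighbor-index table. It assumes, without confirmation from the docstring, that in kNN dense form the edge stored at position [i, j] runs from residue i to its j-th neighbor, so residue i is its home residue. The index table and both helper functions are hypothetical.

```python
# Hypothetical kNN indices for a single TERM with 3 residues and k = 2.
# E_idx[i][j] is the j-th neighbor of residue i; the edge stored at [i, j]
# is ASSUMED to have residue i as its home residue.
E_idx = [
    [0, 1],  # residue 0: itself and residue 1
    [1, 2],  # residue 1: itself and residue 2
    [2, 0],  # residue 2: itself and residue 0
]


def edges_sharing_home(E_idx, i):
    """All (home, neighbor) edges stored in row i, i.e. sharing home residue i."""
    return [(i, j) for j in E_idx[i]]


def attention_pairs(E_idx, i):
    """Every (query edge, key edge) pair scored at home residue i: k * k pairs."""
    edges = edges_sharing_home(E_idx, i)
    return [(q, kv) for q in edges for kv in edges]


for i in range(len(E_idx)):
    print(i, edges_sharing_home(E_idx, i))
```

Under this assumption, attention is confined to a k x k block per home residue, which keeps the cost linear in the number of residues rather than quadratic in the number of edges.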

Variables:
  • W_Q (nn.Linear) – Projection matrix for queries

  • W_K (nn.Linear) – Projection matrix for keys

  • W_V (nn.Linear) – Projection matrix for values

  • W_O (nn.Linear) – Output layer

__init__(num_hidden, num_in, num_heads=4)
Parameters:
  • num_hidden (int) – Hidden dimension, and dimensionality of queries

  • num_in (int) – Dimensionality of keys and values

  • num_heads (int, default=4) – Number of attention heads
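The sketch below shows how the three constructor arguments plausibly fix the projection shapes. Only the input widths follow from the parameter descriptions (queries from num_hidden, keys and values from num_in). The output widths, bias=False, and the requirement that num_hidden divide evenly by num_heads are assumptions, as are all concrete sizes.

```python
import torch
from torch import nn

# Hypothetical sizes for illustration only.
num_hidden, num_in, num_heads = 64, 96, 4
# Assumed: the hidden dimension is split evenly across heads.
assert num_hidden % num_heads == 0
d_head = num_hidden // num_heads  # width of each head's queries/keys/values

# Assumed layer shapes (bias=False is also an assumption): queries are
# projected from edge features (num_hidden), keys and values from the
# neighbor features (num_in); all three are mapped into num_hidden so that
# query-key dot products are defined.
W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
W_K = nn.Linear(num_in, num_hidden, bias=False)
W_V = nn.Linear(num_in, num_hidden, bias=False)
W_O = nn.Linear(num_hidden, num_hidden, bias=False)

n_batch, n_terms, n_nodes, k = 2, 3, 5, 8
h_E = torch.randn(n_batch, n_terms, n_nodes, k, num_hidden)
h_EV = torch.randn(n_batch, n_terms, n_nodes, k, num_in)

# Split the last dimension into (num_heads, d_head).
Q = W_Q(h_E).view(n_batch, n_terms, n_nodes, k, num_heads, d_head)
K = W_K(h_EV).view(n_batch, n_terms, n_nodes, k, num_heads, d_head)
V = W_V(h_EV).view(n_batch, n_terms, n_nodes, k, num_heads, d_head)
print(Q.shape, K.shape, V.shape)
```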

Methods

__init__(num_hidden, num_in[, num_heads])

Initialize the module; parameters are documented under __init__ above.

forward(h_E, h_EV, E_idx[, mask_attend])

Self-attention update over edges in a TERM graph

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

_masked_softmax(attend_logits, mask_attend, dim=-1)

Numerically stable masked softmax

Parameters:
  • attend_logits (torch.Tensor) – Attention logits

  • mask_attend (torch.ByteTensor) – Mask on Attention logits

  • dim (int, default=-1) – Dimension to perform softmax along

Returns:

attend – Softmax of attend_logits, with masked entries excluded

Return type:

torch.Tensor
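A minimal sketch of a numerically stable masked softmax with the same argument order as _masked_softmax. It is not the TERMinator source: the choice of the most negative finite float (rather than -inf) for masked logits, and zeroing masked outputs afterwards, are assumptions made so that a fully masked row produces zeros instead of NaN.

```python
import torch
import torch.nn.functional as F


def masked_softmax(attend_logits, mask_attend, dim=-1):
    """Hypothetical masked softmax; mask_attend is 1 (keep) or 0 (mask)."""
    # Use the most negative finite float instead of -inf so that a fully
    # masked row does not produce NaN inside the softmax.
    neg = torch.finfo(attend_logits.dtype).min
    logits = attend_logits.masked_fill(~mask_attend.bool(), neg)
    # F.softmax subtracts the row maximum internally, which keeps exp() stable.
    attend = F.softmax(logits, dim=dim)
    # Zero masked positions; a fully masked row becomes all zeros.
    return attend * mask_attend.to(attend.dtype)


logits = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
mask = torch.tensor([[1, 1, 0], [0, 0, 0]], dtype=torch.uint8)
print(masked_softmax(logits, mask))
```

In the first row, the masked logit 3.0 receives no weight, so the remaining two entries renormalize to 1/(1+e) and e/(1+e).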

forward(h_E, h_EV, E_idx, mask_attend=None)

Self-attention update over edges in a TERM graph

Parameters:
  • h_E (torch.Tensor) – Edge features in kNN dense form. Shape: n_batch x n_terms x n_nodes x k x n_hidden

  • h_EV (torch.Tensor) – ‘Neighbor’ edge features: for each edge, the features of all edges that share its ‘central residue’, together with the node features of both nodes that form the edge. Shape: n_batch x n_terms x n_nodes x k x n_in

  • mask_attend (torch.ByteTensor or None) – Mask for attention over neighbors. Shape: n_batch x n_terms x n_nodes x k

Returns:

h_E_update – Update for edge embeddings. Shape: n_batch x n_terms x n_nodes x k x n_hidden

Return type:

torch.Tensor
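The following is a hypothetical, self-contained re-implementation meant only to illustrate the documented tensor shapes. It assumes standard scaled dot-product multi-head attention among the k edges that share a home residue, a pairwise (outer-product) attention mask, and bias-free projections; none of this is confirmed by the docstring. The documented forward also takes E_idx, which this sketch does not use, because the docstring does not describe its role.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class EdgeEndpointAttentionSketch(nn.Module):
    """Hypothetical illustration; not the TERMinator implementation."""

    def __init__(self, num_hidden, num_in, num_heads=4):
        super().__init__()
        assert num_hidden % num_heads == 0  # assumed: even split across heads
        self.num_hidden, self.num_heads = num_hidden, num_heads
        self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
        self.W_K = nn.Linear(num_in, num_hidden, bias=False)
        self.W_V = nn.Linear(num_in, num_hidden, bias=False)
        self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)

    def forward(self, h_E, h_EV, mask_attend=None):
        # h_E: B x T x N x K x num_hidden, h_EV: B x T x N x K x num_in
        B, T, N, K = h_E.shape[:4]
        H, d = self.num_heads, self.num_hidden // self.num_heads
        # Split heads and move them ahead of the edge axis: B x T x N x H x K x d
        Q = self.W_Q(h_E).view(B, T, N, K, H, d).transpose(3, 4)
        Kt = self.W_K(h_EV).view(B, T, N, K, H, d).transpose(3, 4)
        V = self.W_V(h_EV).view(B, T, N, K, H, d).transpose(3, 4)
        # Each of the K edges at a home residue scores all K edges there: ... x K x K
        logits = Q @ Kt.transpose(-2, -1) / math.sqrt(d)
        if mask_attend is not None:
            m = mask_attend.to(logits.dtype).unsqueeze(3)  # B x T x N x 1 x K
            pair = m.unsqueeze(-1) * m.unsqueeze(-2)       # B x T x N x 1 x K x K
            logits = logits.masked_fill(pair == 0, torch.finfo(logits.dtype).min)
            attend = F.softmax(logits, dim=-1) * pair      # masked entries -> 0
        else:
            attend = F.softmax(logits, dim=-1)
        # Weighted sum of values, then merge heads: B x T x N x K x num_hidden
        out = (attend @ V).transpose(3, 4).reshape(B, T, N, K, self.num_hidden)
        return self.W_O(out)


layer = EdgeEndpointAttentionSketch(num_hidden=32, num_in=48, num_heads=4)
h_E = torch.randn(2, 3, 5, 6, 32)
h_EV = torch.randn(2, 3, 5, 6, 48)
mask = torch.ones(2, 3, 5, 6, dtype=torch.uint8)
mask[..., -1] = 0  # pretend the last neighbor slot is padding
out = layer(h_E, h_EV, mask)
print(out.shape)  # torch.Size([2, 3, 5, 6, 32])
```

Because the mask is applied to both queries and keys in this sketch, padded edge slots receive an exactly zero update, and unpadded edges never attend to them.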