terminator.models.layers.term.matches.attn.TERMMatchAttention

class terminator.models.layers.term.matches.attn.TERMMatchAttention(hparams)

Bases: Module

TERM Match Attention

A module which computes a node update using self-attention over all neighboring TERM residues and the edges connecting them.

Variables:
  • W_Q (nn.Linear) – Projection matrix for queries

  • W_K (nn.Linear) – Projection matrix for keys

  • W_V (nn.Linear) – Projection matrix for values

  • W_O (nn.Linear) – Output layer
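The four projections above follow the usual multi-head attention pattern. The sketch below is a hypothetical illustration in PyTorch, not TERMinator's implementation. It assumes that attention runs over the `n_matches` axis independently at each TERM residue, that the target features `h_T` are concatenated onto every match embedding before projection, and that `n_hidden` and `n_heads` are passed directly (the real module reads its sizes from `hparams`). Masking is left out for brevity.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MatchAttentionSketch(nn.Module):
    """Hypothetical multi-head self-attention over TERM matches (illustrative only)."""

    def __init__(self, n_hidden: int, n_heads: int = 4):
        super().__init__()
        assert n_hidden % n_heads == 0, "n_hidden must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_head = n_hidden // n_heads
        n_in = 2 * n_hidden  # assumption: h_T is concatenated onto each match embedding
        self.W_Q = nn.Linear(n_in, n_hidden, bias=False)
        self.W_K = nn.Linear(n_in, n_hidden, bias=False)
        self.W_V = nn.Linear(n_in, n_hidden, bias=False)
        self.W_O = nn.Linear(n_hidden, n_hidden, bias=False)

    def forward(self, h_V: torch.Tensor, h_T: torch.Tensor) -> torch.Tensor:
        # h_V: n_batch x sum_term_len x n_matches x n_hidden
        # h_T: n_batch x sum_term_len x n_hidden
        n_batch, term_len, n_matches, _ = h_V.shape
        # Broadcast the target features along the match axis and concatenate
        h_T = h_T.unsqueeze(2).expand(-1, -1, n_matches, -1)
        x = torch.cat([h_V, h_T], dim=-1)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # -> n_batch x sum_term_len x n_heads x n_matches x d_head
            return t.view(n_batch, term_len, n_matches, self.n_heads, self.d_head).transpose(2, 3)

        Q, K, V = split_heads(self.W_Q(x)), split_heads(self.W_K(x)), split_heads(self.W_V(x))
        # Scaled dot-product attention between the matches of each residue
        logits = Q @ K.transpose(-2, -1) / math.sqrt(self.d_head)
        attend = F.softmax(logits, dim=-1)
        # Merge heads back: n_batch x sum_term_len x n_matches x n_hidden
        out = (attend @ V).transpose(2, 3).reshape(n_batch, term_len, n_matches, -1)
        return self.W_O(out)


layer = MatchAttentionSketch(n_hidden=32, n_heads=4)
h_V = torch.randn(2, 10, 5, 32)
h_T = torch.randn(2, 10, 32)
print(layer(h_V, h_T).shape)  # torch.Size([2, 10, 5, 32])
```

The output keeps the shape of `h_V`, consistent with the documented return value of `forward`.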

__init__(hparams)
Parameters:

hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)

Methods

__init__(hparams)

forward(h_V, h_T[, mask_attend])

Self-attention update over residues in TERM matches

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

_masked_softmax(attend_logits, mask_attend, dim=-1)

Numerically stable masked softmax
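A common way to build a numerically stable masked softmax is sketched below; it is an assumed approach, not this method's actual source, and the standalone function name `masked_softmax` is hypothetical. Masked-out logits are replaced with the most negative finite value of the dtype rather than `-inf`, so a fully masked row does not produce NaNs, and the result is multiplied by the mask so such rows come out as all zeros.

```python
import torch
import torch.nn.functional as F


def masked_softmax(attend_logits: torch.Tensor, mask_attend: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax over `dim` that ignores positions where mask_attend == 0 (illustrative sketch)."""
    # Most negative finite value for the logits' dtype, used instead of -inf
    # so that a row with every position masked does not turn into NaNs
    neg_large = torch.finfo(attend_logits.dtype).min
    logits = attend_logits.masked_fill(mask_attend == 0, neg_large)
    attend = F.softmax(logits, dim=dim)
    # Zero out masked positions; fully masked rows become all zeros
    return attend * (mask_attend != 0).to(attend.dtype)


logits = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
mask = torch.tensor([[1, 1, 0], [0, 0, 0]])
# first row ~[0.269, 0.731, 0.0]; second (fully masked) row all zeros
print(masked_softmax(logits, mask))
```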

forward(h_V, h_T, mask_attend=None)

Self-attention update over residues in TERM matches

Parameters:
  • h_V (torch.Tensor) – Embedded TERM match residues. Shape: n_batch x sum_term_len x n_matches x n_hidden

  • h_T (torch.Tensor) – Embedded structural features of target residue. Shape: n_batch x sum_term_len x n_hidden

  • mask_attend (torch.ByteTensor or None) – Mask for attention. Shape: n_batch x sum_term_len (TODO: check shape)

Returns:

src_update – Update to the TERM match embeddings. Shape: n_batch x sum_term_len x n_matches x n_hidden

Return type:

torch.Tensor
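The documented shapes imply that each residue's target features in `h_T` have to be lined up with every match of that residue in `h_V`. The hypothetical helper below (`check_and_align` is not part of TERMinator) checks the documented shape contract and performs that alignment with a broadcasted view. The `mask_attend` check follows the docstring's `n_batch x sum_term_len` shape, which the source itself marks as unverified.

```python
import torch


def check_and_align(h_V: torch.Tensor, h_T: torch.Tensor, mask_attend=None) -> torch.Tensor:
    """Hypothetical helper: validate the documented forward() shapes and
    broadcast h_T along the match axis so it lines up with h_V."""
    if h_V.dim() != 4:
        raise ValueError("h_V must be n_batch x sum_term_len x n_matches x n_hidden")
    n_batch, term_len, n_matches, n_hidden = h_V.shape
    if h_T.shape != (n_batch, term_len, n_hidden):
        raise ValueError(f"h_T must be {(n_batch, term_len, n_hidden)}, got {tuple(h_T.shape)}")
    # Mask shape taken from the docstring, which flags it as unverified
    if mask_attend is not None and mask_attend.shape != (n_batch, term_len):
        raise ValueError(f"mask_attend must be {(n_batch, term_len)}, got {tuple(mask_attend.shape)}")
    # n_batch x sum_term_len x n_hidden -> n_batch x sum_term_len x n_matches x n_hidden (a view, no copy)
    return h_T.unsqueeze(2).expand(-1, -1, n_matches, -1)


h_V = torch.randn(2, 7, 3, 16)
h_T = torch.randn(2, 7, 16)
mask = torch.ones(2, 7, dtype=torch.bool)
print(check_and_align(h_V, h_T, mask).shape)  # torch.Size([2, 7, 3, 16])
```

Using `expand` rather than `repeat` avoids copying `h_T` once per match; the expanded tensor can then be concatenated with `h_V` or combined with it elementwise.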