terminator.models.layers.term.matches.attn.TERMMatchAttention
- class terminator.models.layers.term.matches.attn.TERMMatchAttention(hparams)
Bases: Module
TERM Match Attention
A module which computes a node update using self-attention over all neighboring TERM residues and the edges connecting them.
- Variables:
W_Q (nn.Linear) – Projection matrix for queries
W_K (nn.Linear) – Projection matrix for keys
W_V (nn.Linear) – Projection matrix for values
W_O (nn.Linear) – Output layer
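The attention described above can be sketched in PyTorch roughly as follows. This is a minimal, hypothetical re-creation for illustration, not the TERMinator implementation. The class name `MatchAttentionSketch`, the hparams keys `hidden_dim` and `num_heads`, concatenating each match with its target features `h_T` before projecting, running attention over the `n_matches` axis, and using `mask_attend` to zero out the update at padded positions are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MatchAttentionSketch(nn.Module):
    """Hypothetical multi-head self-attention over TERM matches.

    Not the TERMinator code. Assumed: hparams keys "hidden_dim" and
    "num_heads"; each match is concatenated with its target features h_T
    before projection; attention runs over the n_matches axis.
    """

    def __init__(self, hparams):
        super().__init__()
        d = hparams["hidden_dim"]              # hypothetical key
        self.num_heads = hparams["num_heads"]  # hypothetical key
        if d % self.num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")
        self.W_Q = nn.Linear(2 * d, d)  # projection for queries
        self.W_K = nn.Linear(2 * d, d)  # projection for keys
        self.W_V = nn.Linear(2 * d, d)  # projection for values
        self.W_O = nn.Linear(d, d)      # output layer

    def forward(self, h_V, h_T, mask_attend=None):
        # h_V: n_batch x sum_term_len x n_matches x n_hidden
        # h_T: n_batch x sum_term_len x n_hidden
        n_batch, term_len, n_matches, d = h_V.shape
        heads = self.num_heads
        d_head = d // heads
        # Broadcast the target features onto every match, then concatenate.
        h_T_exp = h_T.unsqueeze(2).expand(-1, -1, n_matches, -1)
        h_in = torch.cat([h_V, h_T_exp], dim=-1)

        def split_heads(x):
            # -> n_batch x sum_term_len x heads x n_matches x d_head
            return x.view(n_batch, term_len, n_matches, heads, d_head).transpose(2, 3)

        Q = split_heads(self.W_Q(h_in))
        K = split_heads(self.W_K(h_in))
        V = split_heads(self.W_V(h_in))
        # Scaled dot-product scores between every pair of matches.
        scores = Q @ K.transpose(-2, -1) / d_head ** 0.5
        attn = F.softmax(scores, dim=-1)
        out = (attn @ V).transpose(2, 3).reshape(n_batch, term_len, n_matches, d)
        update = self.W_O(out)
        if mask_attend is not None:
            # Assumed semantics: 1 = real TERM residue, 0 = padding (no update).
            update = update * mask_attend.to(update.dtype)[:, :, None, None]
        return update


# Example: 2 sequences, 5 TERM residues, 3 matches each, 8 hidden features.
layer = MatchAttentionSketch({"hidden_dim": 8, "num_heads": 2})
h_V = torch.randn(2, 5, 3, 8)
h_T = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.uint8)
src_update = layer(h_V, h_T, mask)
print(src_update.shape)  # torch.Size([2, 5, 3, 8])
```

The output keeps the n_batch x sum_term_len x n_matches x n_hidden layout of `h_V`, so the result can be added back to the match embeddings as a residual update; whether the real layer does exactly that is not stated here.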
- __init__(hparams)
- Parameters:
hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)
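Because the constructor takes a plain dict, one way to start from the defaults file and change a few values is sketched below. The helpers `load_hparams` and `with_overrides` are hypothetical and not part of TERMinator, and rejecting unknown keys is an illustrative policy rather than documented behavior. The path is passed in by the caller, since the `~/` in the path above may refer to the repository root rather than the home directory.

```python
import copy
import json


def load_hparams(path):
    """Hypothetical helper: read a JSON hparams file into a dict."""
    with open(path) as f:
        return json.load(f)


def with_overrides(hparams, **overrides):
    """Hypothetical helper: return a copy of hparams with some keys replaced.

    Unknown keys raise KeyError so that typos are not silently ignored
    (an illustrative policy, not TERMinator behavior).
    """
    unknown = set(overrides) - set(hparams)
    if unknown:
        raise KeyError(f"unknown hparams: {sorted(unknown)}")
    updated = copy.deepcopy(hparams)  # leave the loaded defaults untouched
    updated.update(overrides)
    return updated


# Example with an in-memory dict standing in for the loaded defaults.
defaults = {"hidden_dim": 32, "num_heads": 4}  # hypothetical keys
custom = with_overrides(defaults, num_heads=8)
```

In practice `with_overrides(load_hparams(path), ...)` would be called with keys that actually appear in default_hparams.json; the keys above are placeholders.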
Methods
__init__(hparams)
forward(h_V, h_T[, mask_attend]) – Self-attention update over residues in TERM matches
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- forward(h_V, h_T, mask_attend=None)
Self-attention update over residues in TERM matches
- Parameters:
h_V (torch.Tensor) – TERM match residues. Shape: n_batch x sum_term_len x n_matches x n_hidden
h_T (torch.Tensor) – Embedded structural features of the target residues. Shape: n_batch x sum_term_len x n_hidden
mask_attend (torch.ByteTensor or None) – Mask for attention. Shape: n_batch x sum_term_len (not yet verified)
- Returns:
src_update – TERM match embedding update. Shape: n_batch x sum_term_len x n_matches x n_hidden
- Return type:
torch.Tensor
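The input contract of forward() can be checked before calling it. The sketch below is hypothetical: `check_forward_shapes` asserts the documented shapes, and `make_attend_mask` builds a `torch.uint8` (ByteTensor-style) mask of shape n_batch x sum_term_len, assuming that 1 marks a real TERM residue and 0 marks padding. Because the source marks the mask shape as unverified, treat that part of the check as provisional.

```python
import torch


def make_attend_mask(term_lens, sum_term_len):
    """Hypothetical helper: padding mask of shape n_batch x sum_term_len.

    Assumes 1 marks a real TERM residue and 0 marks padding.
    """
    lens = torch.as_tensor(term_lens)
    positions = torch.arange(sum_term_len)
    # (1 x sum_term_len) < (n_batch x 1) broadcasts to n_batch x sum_term_len
    return (positions[None, :] < lens[:, None]).to(torch.uint8)


def check_forward_shapes(h_V, h_T, mask_attend=None):
    """Assert the input shapes documented for forward(); return the sizes."""
    n_batch, sum_term_len, n_matches, n_hidden = h_V.shape
    assert h_T.shape == (n_batch, sum_term_len, n_hidden), "h_T shape mismatch"
    if mask_attend is not None:
        # Provisional: the documented mask shape is itself marked unverified.
        assert mask_attend.shape == (n_batch, sum_term_len), "mask shape mismatch"
    return n_batch, sum_term_len, n_matches, n_hidden


# Example: batch of 2, TERM residues padded to length 4, 3 matches, 8 features.
h_V = torch.randn(2, 4, 3, 8)
h_T = torch.randn(2, 4, 8)
mask = make_attend_mask([4, 2], sum_term_len=4)
print(check_forward_shapes(h_V, h_T, mask))  # (2, 4, 3, 8)
```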