terminator.models.layers.term.matches.attn.TERMMatchTransformerLayer
- class terminator.models.layers.term.matches.attn.TERMMatchTransformerLayer(hparams)
Bases: Module
TERM Match Transformer Layer
A TERM Match Transformer Layer that updates match embeddings via TERMMatchAttention.
- Variables:
attention (TERMMatchAttention) – Transformer Attention mechanism
dense (PositionWiseFeedForward) – Transformer position-wise FFN
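The two variables above suggest the usual Transformer pattern: an attention sublayer followed by a position-wise FFN, each wrapped in a residual connection. The sketch below is a hypothetical stand-in, not this class's implementation. It assumes post-norm residual blocks, uses `torch.nn.MultiheadAttention` over the match axis in place of `TERMMatchAttention`, and assumes queries come from `target` while keys and values come from `src`. The class name `SketchMatchLayer` and its constructor arguments are illustrative only.

```python
import torch
from torch import nn


class SketchMatchLayer(nn.Module):
    """Hypothetical stand-in for TERMMatchTransformerLayer (not the library's code).

    Assumes post-norm residual blocks and nn.MultiheadAttention over the
    match axis in place of TERMMatchAttention.
    """

    def __init__(self, n_hidden: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            n_hidden, n_heads, dropout=dropout, batch_first=True
        )
        # Position-wise FFN: applied to every match embedding independently.
        self.dense = nn.Sequential(
            nn.Linear(n_hidden, 4 * n_hidden),
            nn.ReLU(),
            nn.Linear(4 * n_hidden, n_hidden),
        )
        self.norm1 = nn.LayerNorm(n_hidden)
        self.norm2 = nn.LayerNorm(n_hidden)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        n_batch, sum_term_len, n_matches, n_hidden = src.shape
        # Fold TERM residues into the batch so attention runs over matches only.
        q = target.reshape(-1, n_matches, n_hidden)
        kv = src.reshape(-1, n_matches, n_hidden)
        attn_out, _ = self.attention(q, kv, kv, need_weights=False)
        h = self.norm1(kv + self.dropout(attn_out))           # residual + norm around attention
        h = self.norm2(h + self.dropout(self.dense(h)))       # residual + norm around FFN
        return h.reshape(n_batch, sum_term_len, n_matches, n_hidden)


# Usage: shapes follow the documented n_batch x sum_term_len x n_matches x n_hidden layout.
layer = SketchMatchLayer(n_hidden=32)
src = torch.randn(2, 5, 3, 32)
target = torch.randn(2, 5, 3, 32)
out = layer(src, target)
print(out.shape)  # torch.Size([2, 5, 3, 32])
```

Folding the TERM-residue axis into the batch is one simple way to keep attention confined to the matches of a single residue; the real layer may mix information differently.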
- __init__(hparams)
- Parameters:
hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)
Methods
__init__(hparams) – Build the layer from a dictionary of model hparams.
forward(src, target[, src_mask, ...]) – Apply one Transformer update to TERM matches.
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- forward(src, target, src_mask=None, mask_attend=None, checkpoint=False)
Apply one Transformer update to TERM matches
- Parameters:
src (torch.Tensor) – TERM match features. Shape: n_batch x sum_term_len x n_matches x n_hidden
target (torch.Tensor) – Embedded structural features per TERM residue of the target structure. Shape: n_batch x sum_term_len x n_matches x n_hidden
src_mask (torch.ByteTensor or None) – Attention mask over TERM residues. Shape: n_batch x sum_term_len
mask_attend (torch.ByteTensor or None) – Attention mask over matches. Shape: n_batch x sum_term_len (TODO: check shape)
checkpoint (bool, default=False) – Whether to use gradient checkpointing to reduce memory usage.
- Returns:
src – Updated match embeddings Shape: n_batch x sum_term_len x n_matches x n_hidden
- Return type:
torch.Tensor
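For callers, two practical questions are how to build a residue-level mask like `src_mask` and what a `checkpoint` flag usually does. The sketch below shows a common pattern and is an assumption, not what this layer does internally. It assumes each batch element's TERM residues are left-aligned and padded at the end, and that padded residues are zeroed by multiplying with the mask. The helpers `make_src_mask` and `run_layer` and the class `ToyLayer` are hypothetical. The mask is returned as a float tensor so it can be multiplied directly; call `.byte()` on it to match the documented `torch.ByteTensor` dtype. Checkpointing uses the real `torch.utils.checkpoint.checkpoint` function, which recomputes activations during the backward pass instead of storing them.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def make_src_mask(term_lens, sum_term_len):
    """Hypothetical helper: (n_batch, sum_term_len) mask, 1.0 for real residues, 0.0 for padding.

    Assumes residues are left-aligned and padded at the end.
    """
    positions = torch.arange(sum_term_len).unsqueeze(0)  # (1, sum_term_len)
    lengths = torch.as_tensor(term_lens).unsqueeze(1)    # (n_batch, 1)
    return (positions < lengths).float()


def run_layer(layer, src, target, src_mask, use_checkpoint=False):
    """Hypothetical wrapper: apply `layer`, optionally with gradient checkpointing, then zero padding."""
    if use_checkpoint:
        # Activations inside `layer` are recomputed during backward rather than kept in memory.
        out = checkpoint(layer, src, target, use_reentrant=False)
    else:
        out = layer(src, target)
    # Broadcast (n_batch, sum_term_len) -> (n_batch, sum_term_len, 1, 1) over matches and features.
    return out * src_mask[:, :, None, None]


class ToyLayer(nn.Module):
    """Toy stand-in with the same (src, target) call signature; not the real layer."""

    def __init__(self, n_hidden):
        super().__init__()
        self.proj = nn.Linear(n_hidden, n_hidden)

    def forward(self, src, target):
        return self.proj(src + target)


# Usage: batch of 2 with 5 and 3 real TERM residues, 3 matches each, 8 hidden dims.
src = torch.randn(2, 5, 3, 8, requires_grad=True)
target = torch.randn(2, 5, 3, 8)
mask = make_src_mask([5, 3], sum_term_len=5)
out = run_layer(ToyLayer(8), src, target, mask, use_checkpoint=True)
out.sum().backward()  # gradients flow through the checkpointed call
print(out.shape)  # torch.Size([2, 5, 3, 8])
```

Checkpointing trades extra compute in the backward pass for lower peak memory, which matters here because match tensors grow with both sum_term_len and n_matches.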