terminator.models.layers.term.matches.attn.TERMMatchTransformerLayer

class terminator.models.layers.term.matches.attn.TERMMatchTransformerLayer(hparams)

Bases: Module

TERM Match Transformer Layer

A TERM Match Transformer Layer that updates match embeddings via TERMMatchAttention.

__init__(hparams)
Parameters:

hparams (dict) – Dictionary of model hyperparameters (see ~/scripts/models/train/default_hparams.json for more info)

Methods

__init__(hparams)

forward(src, target[, src_mask, ...])

Apply one Transformer update to TERM matches

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better BC support for load_state_dict().

forward(src, target, src_mask=None, mask_attend=None, checkpoint=False)

Apply one Transformer update to TERM matches

Parameters:
  • src (torch.Tensor) – TERM match features. Shape: n_batch x sum_term_len x n_matches x n_hidden

  • target (torch.Tensor) – Embedded structural features of the target structure, per TERM residue. Shape: n_batch x sum_term_len x n_matches x n_hidden

  • src_mask (torch.ByteTensor or None) – Attention mask over TERM residues. Shape: n_batch x sum_term_len

  • mask_attend (torch.ByteTensor or None) – Attention mask over matches. Shape: n_batch x sum_term_len (marked "TODO: check shape" in the source docstring)

  • checkpoint (bool, default=False) – Whether to use gradient checkpointing, which reduces memory usage at the cost of recomputing activations during the backward pass

Returns:

src – Updated match embeddings. Shape: n_batch x sum_term_len x n_matches x n_hidden

Return type:

torch.Tensor
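
Example

The shape conventions and the checkpoint flag documented above can be illustrated with a minimal sketch. The class StandInMatchLayer below is hypothetical: it only mirrors the forward signature and tensor shapes, and it is not the TERMinator implementation. Its use of nn.MultiheadAttention, the n_heads argument, and the way target is mixed into the keys are all assumptions made for illustration. Because the shape of mask_attend is marked as unverified, the stand-in accepts that argument but ignores it.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint as grad_checkpoint


class StandInMatchLayer(nn.Module):
    """Hypothetical stand-in with the documented forward signature.

    NOT the TERMinator layer: it only reproduces the input/output shapes.
    """

    def __init__(self, n_hidden, n_heads=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(n_hidden, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(n_hidden)

    def _update(self, src, target):
        n_batch, sum_term_len, n_matches, n_hidden = src.shape
        # Fold the batch and residue axes so attention runs across matches.
        q = src.reshape(-1, n_matches, n_hidden)
        # Assumption: target features are simply added to form keys/values.
        kv = (src + target).reshape(-1, n_matches, n_hidden)
        out, _ = self.attn(q, kv, kv, need_weights=False)
        out = self.norm(q + out)  # residual connection, then layer norm
        return out.reshape(n_batch, sum_term_len, n_matches, n_hidden)

    def forward(self, src, target, src_mask=None, mask_attend=None, checkpoint=False):
        # mask_attend is accepted for signature parity but unused in this sketch.
        if checkpoint:
            # Recompute activations in the backward pass instead of storing them.
            out = grad_checkpoint(self._update, src, target, use_reentrant=False)
        else:
            out = self._update(src, target)
        if src_mask is not None:
            # Zero out embeddings of padded TERM residues.
            out = out * src_mask[:, :, None, None].to(out.dtype)
        return out


torch.manual_seed(0)
n_batch, sum_term_len, n_matches, n_hidden = 2, 5, 3, 8
layer = StandInMatchLayer(n_hidden)

src = torch.randn(n_batch, sum_term_len, n_matches, n_hidden)
target = torch.randn(n_batch, sum_term_len, n_matches, n_hidden)
src_mask = torch.ones(n_batch, sum_term_len, dtype=torch.bool)
src_mask[1, 3:] = False  # treat the last two residues of sample 1 as padding

updated = layer(src, target, src_mask=src_mask, checkpoint=True)
print(updated.shape)
```

The output has the same shape as src, so layers with this interface can be stacked. Setting checkpoint=True should not change the result, only the memory and compute trade-off during training.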