terminator.models.layers.term.matches.attn.TERMMatchTransformerEncoder

class terminator.models.layers.term.matches.attn.TERMMatchTransformerEncoder(hparams)

Bases: Module

TERM Match Transformer Encoder

A Transformer that uses a pool token to summarize the contents of TERM matches.

Variables:
  • W_v (nn.Linear) – Embedding layer for matches

  • W_t (nn.Linear) – Embedding layer for target structure information

  • W_pool (nn.Linear) – Embedding layer for pool token

  • encoder_layers (nn.ModuleList of TERMMatchTransformerLayer) – Transformer layers for matches

  • W_out (nn.Linear) – Output layer

  • pool_token_init (nn.Parameter) – The embedding for the pool token used to gather information, reminiscent of [CLS] tokens in BERT
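The pool-token idea can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not TERMinator's implementation: it uses PyTorch's built-in nn.TransformerEncoder in place of TERMMatchTransformerLayer and omits W_v, W_t, and W_pool. The class name PoolTokenSummarizer and its defaults (n_heads, n_layers) are hypothetical. A learnable token is prepended to each set of embeddings, and the encoder output at that position is read out as the summary, as with BERT's [CLS] token.

```python
import torch
from torch import nn


class PoolTokenSummarizer(nn.Module):
    """Hypothetical stand-in, NOT the TERMinator implementation.

    Summarizes a set of embeddings with a learnable pool token, using
    PyTorch's built-in nn.TransformerEncoder.
    """

    def __init__(self, n_hidden: int, n_heads: int = 4, n_layers: int = 2):
        super().__init__()
        # Learnable pool token, analogous to BERT's [CLS] token
        self.pool_token_init = nn.Parameter(torch.randn(1, 1, n_hidden) * 0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=n_hidden,
            nhead=n_heads,
            dim_feedforward=2 * n_hidden,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.W_out = nn.Linear(n_hidden, n_hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: batch x n_items x n_hidden
        pool = self.pool_token_init.expand(x.shape[0], -1, -1)
        # Prepend the pool token so it can attend to every item
        h = self.encoder(torch.cat([pool, x], dim=1))
        # The output at the pool-token position summarizes the whole set
        return self.W_out(h[:, 0])


torch.manual_seed(0)
model = PoolTokenSummarizer(n_hidden=16).eval()
x = torch.randn(3, 5, 16)  # 3 sets of 5 items each
summary = model(x)         # shape: 3 x 16
```

Because the sketch adds no positional encoding, its summary is invariant to the order of the items it pools over.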

__init__(hparams)
Parameters:

hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)

Methods

  • __init__(hparams) – Initialize the encoder from a dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info).

  • forward(V, target, mask) – Summarize TERM matches.

Attributes

  • T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor]) (inherited from torch.nn.Module)

  • dump_patches – Allows better backward-compatibility (BC) support for load_state_dict() (inherited from torch.nn.Module)

forward(V, target, mask)

Summarize TERM matches

Parameters:
  • V (torch.Tensor) – TERM match embeddings. Shape: n_batches x sum_term_len x n_matches x n_hidden

  • target (torch.Tensor) – Embedded structural information of the target per TERM residue. Shape: n_batches x sum_term_len x n_hidden

  • mask (torch.ByteTensor) – Mask for TERM residues. Shape: n_batches x sum_term_len
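To make the documented shapes concrete, the sketch below builds dummy inputs and a padding mask. The sizes, the helper zero_padded_residues, and the convention that 1 marks a real TERM residue and 0 marks padding are assumptions for illustration; how forward itself consumes mask is not documented here.

```python
import torch

# Hypothetical sizes, chosen only for illustration
n_batches, sum_term_len, n_matches, n_hidden = 2, 7, 4, 16

V = torch.randn(n_batches, sum_term_len, n_matches, n_hidden)  # TERM match embeddings
target = torch.randn(n_batches, sum_term_len, n_hidden)        # per-residue target features

# Assumed convention: 1 = real TERM residue, 0 = padding.
# The second structure has only 5 real TERM residues.
lengths = torch.tensor([7, 5])
mask = (torch.arange(sum_term_len)[None, :] < lengths[:, None]).to(torch.uint8)


def zero_padded_residues(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero per-residue features at padded positions."""
    # Reshape the n_batches x sum_term_len mask so it broadcasts over trailing dims
    m = mask.to(x.dtype).reshape(*mask.shape, *([1] * (x.dim() - mask.dim())))
    return x * m


V_masked = zero_padded_residues(V, mask)
target_masked = zero_padded_residues(target, mask)
```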

Returns:

Summarized TERM matches. Shape: n_batches x sum_term_len x n_hidden

Return type:

torch.Tensor
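The input and output shapes imply that each TERM residue's n_matches embeddings are pooled into a single n_hidden vector. A minimal sketch of that bookkeeping follows, assuming each residue's matches can be treated as an independent sequence. The helper summarize_per_residue is hypothetical, and mean pooling stands in for the learned pool-token transformer in the usage lines; any module mapping (N, n_matches, n_hidden) to (N, n_hidden) could be passed instead.

```python
from typing import Callable

import torch


def summarize_per_residue(
    V: torch.Tensor,
    mask: torch.Tensor,
    summarize: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Hypothetical helper: n_batches x sum_term_len x n_matches x n_hidden
    -> n_batches x sum_term_len x n_hidden."""
    n_batches, sum_term_len, n_matches, n_hidden = V.shape
    # Treat the matches of every TERM residue as one independent sequence
    flat = V.reshape(n_batches * sum_term_len, n_matches, n_hidden)
    summary = summarize(flat)  # -> (n_batches * sum_term_len) x n_hidden
    out = summary.reshape(n_batches, sum_term_len, n_hidden)
    # Zero out padded TERM residues (assumed: mask == 1 for real residues)
    return out * mask.unsqueeze(-1).to(out.dtype)


# Usage with mean pooling as a stand-in summarizer (not the real encoder)
V = torch.randn(2, 6, 3, 8)
mask = torch.tensor([[1] * 6, [1] * 4 + [0] * 2], dtype=torch.uint8)
out = summarize_per_residue(V, mask, lambda x: x.mean(dim=1))
```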