terminator.models.layers.term.matches.attn.TERMMatchTransformerEncoder
- class terminator.models.layers.term.matches.attn.TERMMatchTransformerEncoder(hparams)
Bases: Module
TERM Match Transformer Encoder
A Transformer that uses a pool token to summarize the contents of TERM matches.
- Variables:
W_v (nn.Linear) – Embedding layer for matches
W_t (nn.Linear) – Embedding layer for target structure information
W_pool (nn.Linear) – Embedding layer for pool token
encoder_layers (nn.ModuleList of TERMMatchTransformerLayer) – Transformer layers for matches
W_out (nn.Linear) – Output layer
pool_token_init (nn.Parameter) – The embedding for the pool token used to gather information, reminiscent of [CLS] tokens in BERT
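The core idea behind a pool token is that a learned vector acts as a query and gathers a weighted summary of the sequence it attends over, much like a [CLS] token. The sketch below shows that idea for the matches of a single TERM residue, using one step of scaled dot-product attention. It is a simplification under stated assumptions: `pool_token_summary` is a hypothetical helper, and the real encoder stacks full Transformer layers (in which the pool token also attends to itself) rather than this single attention step.

```python
import torch


def pool_token_summary(matches: torch.Tensor, pool_token: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: summarize one residue's matches with a pool-token query.

    matches:    n_matches x n_hidden
    pool_token: n_hidden (learned in a real model, like pool_token_init)
    """
    n_hidden = matches.shape[-1]
    # One attention score per match, scaled as in scaled dot-product attention.
    scores = matches @ pool_token / n_hidden ** 0.5
    # Softmax over the matches so the weights sum to 1.
    weights = torch.softmax(scores, dim=0)
    # Weighted average of the match embeddings, shape: n_hidden.
    return weights @ matches


matches = torch.randn(4, 8)  # 4 matches, 8 hidden features
pool_token = torch.randn(8)
summary = pool_token_summary(matches, pool_token)
print(summary.shape)  # torch.Size([8])
```

Because the output is a convex combination of the match embeddings, the summary has the same size no matter how many matches a residue has.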
- __init__(hparams)
- Parameters:
hparams (dict) – Dictionary of model hparams (see ~/scripts/models/train/default_hparams.json for more info)
Methods
__init__(hparams) – Initialize the encoder from a dictionary of model hparams.
forward(V, target, mask) – Summarize TERM matches.
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC (backward compatibility) support for load_state_dict().
- forward(V, target, mask)
Summarize the TERM matches at each TERM residue into a single embedding.
- Parameters:
V (torch.Tensor) – TERM match embeddings. Shape: n_batches x sum_term_len x n_matches x n_hidden
target (torch.Tensor) – Embedded structural information of the target per TERM residue. Shape: n_batches x sum_term_len x n_hidden
mask (torch.ByteTensor) – Mask for TERM residues. Shape: n_batches x sum_term_len
- Returns:
Summarized TERM matches. Shape: n_batches x sum_term_len x n_hidden
- Return type:
torch.Tensor
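To make the shape contract of forward concrete, here is a minimal sketch of a pool-token encoder built from the components listed under Variables. It is not the TERMinator implementation. `PoolTokenMatchEncoder` is a hypothetical class; it takes explicit sizes instead of an hparams dict; `torch.nn.TransformerEncoderLayer` stands in for `TERMMatchTransformerLayer`; and conditioning the pool token on the target by addition, as well as zeroing masked residues in the output, are assumptions. The sketch treats each TERM residue as its own sequence of matches, prepends the pool token, and reads the summary back from the pool token position.

```python
import torch
from torch import nn


class PoolTokenMatchEncoder(nn.Module):
    """Hypothetical sketch of a pool-token encoder over TERM matches."""

    def __init__(self, n_hidden: int = 32, n_heads: int = 4, n_layers: int = 2):
        super().__init__()
        self.W_v = nn.Linear(n_hidden, n_hidden)     # embeds match features
        self.W_t = nn.Linear(n_hidden, n_hidden)     # embeds target structure features
        self.W_pool = nn.Linear(n_hidden, n_hidden)  # embeds the pool token
        self.pool_token_init = nn.Parameter(torch.zeros(1, 1, n_hidden))
        # Stand-in for TERMMatchTransformerLayer (assumption).
        self.encoder_layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(n_hidden, n_heads, batch_first=True)
             for _ in range(n_layers)]
        )
        self.W_out = nn.Linear(n_hidden, n_hidden)

    def forward(self, V, target, mask):
        n_batches, sum_term_len, n_matches, n_hidden = V.shape
        # Flatten so each TERM residue's matches form one sequence:
        # (n_batches * sum_term_len) x n_matches x n_hidden
        matches = self.W_v(V).reshape(-1, n_matches, n_hidden)
        # Assumption: condition the pool token on the target by addition.
        pool = self.W_pool(self.pool_token_init) + self.W_t(target).reshape(-1, 1, n_hidden)
        # Prepend the pool token: (n_batches * sum_term_len) x (1 + n_matches) x n_hidden
        h = torch.cat([pool, matches], dim=1)
        for layer in self.encoder_layers:
            h = layer(h)
        # Read out the pool token position as the per-residue summary.
        out = self.W_out(h[:, 0]).reshape(n_batches, sum_term_len, n_hidden)
        # Assumption: zero out padded TERM residues.
        return out * mask.unsqueeze(-1).to(out.dtype)


torch.manual_seed(0)
model = PoolTokenMatchEncoder(n_hidden=32).eval()
V = torch.randn(2, 5, 3, 32)          # n_batches x sum_term_len x n_matches x n_hidden
target = torch.randn(2, 5, 32)        # n_batches x sum_term_len x n_hidden
mask = torch.ones(2, 5, dtype=torch.uint8)
mask[1, 3:] = 0                       # last two residues of the second batch are padding
out = model(V, target, mask)
print(out.shape)  # torch.Size([2, 5, 32])
```

Flattening the batch and residue dimensions lets a standard sequence layer attend over matches independently per residue, which is why the output drops the n_matches dimension.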