terminator.models.layers.condense.covariation_features
- terminator.models.layers.condense.covariation_features(matches, term_lens, rmsds, mask)
Compute weighted cross-covariance features from TERM matches.
- Parameters:
matches (torch.Tensor) – TERM matches in flat form, with TERMs concatenated side by side along the residue dimension. Shape: n_batch x sum_term_len x n_hidden
term_lens (list of (list of int)) – Length of each TERM, given as one list per batch element
rmsds (torch.Tensor) – RMSD of each TERM match. Shape: n_batch x sum_term_len
mask (torch.ByteTensor) – Mask over TERM residues. Shape: n_batch x sum_term_len
- Returns:
cov_mat – Weighted cross-covariance matrices. Shape: n_batch x n_terms x max_term_len x max_term_len x n_hidden x n_hidden
- Return type:
torch.Tensor
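The flat layout described for `matches` (TERMs concatenated along the residue dimension, with lengths given by `term_lens`) can be unpacked into a padded per-TERM tensor before any per-TERM statistics are computed. The sketch below is illustrative only: `split_flat_terms` is a hypothetical helper, not part of the library, and it assumes zero padding and that any trailing positions beyond `sum(term_lens[b])` are padding to be ignored.

```python
import torch


def split_flat_terms(flat, term_lens):
    """Hypothetical helper: unpack flat TERM features into a padded per-TERM tensor.

    flat: n_batch x sum_term_len x n_hidden
    term_lens: one list of TERM lengths per batch element
    returns: n_batch x max_n_terms x max_term_len x n_hidden (zero-padded)
    """
    n_batch, _, n_hidden = flat.shape
    max_n_terms = max(len(lens) for lens in term_lens)
    max_term_len = max(max(lens) for lens in term_lens)
    # ASSUMPTION: padding is filled with zeros
    out = flat.new_zeros(n_batch, max_n_terms, max_term_len, n_hidden)
    for b, lens in enumerate(term_lens):
        # ASSUMPTION: positions past sum(lens) are padding and are dropped
        chunks = torch.split(flat[b, : sum(lens)], lens, dim=0)
        for t, chunk in enumerate(chunks):
            # copy this TERM's residues into the front of its padded slot
            out[b, t, : chunk.shape[0]] = chunk
    return out
```

For example, with `term_lens = [[2, 3], [4]]` the first batch element yields two TERMs of lengths 2 and 3, while the second yields one TERM of length 4 plus an all-zero padding TERM.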
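The idea of a weighted cross-covariance over TERM matches can be sketched for a single TERM as follows. This is not the library's implementation: it assumes the matches of one TERM are stacked as n_matches x term_len x n_hidden with one RMSD per match, and it weights matches with a softmax over negative RMSD so that closer matches count more. The layout, the weighting scheme, and the name `weighted_cross_covariance` are all assumptions for illustration.

```python
import torch


def weighted_cross_covariance(term_matches, match_rmsds):
    """Hypothetical sketch of a weighted cross-covariance for one TERM.

    term_matches: n_matches x term_len x n_hidden
    match_rmsds: n_matches (ASSUMPTION: one RMSD per match)
    returns: term_len x term_len x n_hidden x n_hidden
    """
    # ASSUMPTION: weights = softmax(-rmsd); they are positive and sum to 1
    weights = torch.softmax(-match_rmsds, dim=0)
    # weighted mean of each (residue, feature) over the matches
    mean = torch.einsum("k,kia->ia", weights, term_matches)
    # center every match by subtracting the weighted mean (broadcast over k)
    centered = term_matches - mean
    # cov[i, j, a, b] = sum_k w_k * c[k, i, a] * c[k, j, b]
    return torch.einsum("k,kia,kjb->ijab", weights, centered, centered)
```

Each entry `cov[i, j, a, b]` is the weighted covariance between feature `a` at residue `i` and feature `b` at residue `j`, which is why the trailing dimensions of `cov_mat` are max_term_len x max_term_len x n_hidden x n_hidden. Because the weights sum to 1, no extra normalization is applied; with equal RMSDs the result reduces to the ordinary (population) covariance across matches.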