terminator.models.layers.condense.covariation_features

terminator.models.layers.condense.covariation_features(matches, term_lens, rmsds, mask)

Compute weighted cross-covariance features from TERM matches.

Parameters:
  • matches (torch.Tensor) – TERM matches in flat form (TERMs are concatenated side by side). Shape: n_batch x sum_term_len x n_hidden

  • term_lens (list of (list of int)) – Length of each TERM

  • rmsds (torch.Tensor) – RMSD per TERM match. Shape: n_batch x sum_term_len

  • mask (torch.ByteTensor) – Mask for TERM residues. Shape: n_batch x sum_term_len

Returns:

cov_mat – Weighted cross-covariance matrices. Shape: n_batch x n_terms x max_term_len x max_term_len x n_hidden x n_hidden

Return type:

torch.Tensor
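
Example:

The shapes above relate a flat axis of length sum_term_len to a per-TERM layout of n_terms x max_term_len. The sketch below illustrates only that bookkeeping for a single batch element, using plain Python lists in place of tensors. The helper name split_flat_terms and the use of None as padding are assumptions made for illustration and are not part of the terminator API; the library's actual batching and covariance computation may differ.

```python
def split_flat_terms(flat, term_lens, pad=None):
    """Split one example's flat per-residue list into padded per-TERM rows.

    Hypothetical helper for illustration only; not part of terminator.
    """
    if sum(term_lens) != len(flat):
        raise ValueError("term_lens must sum to the flat length (sum_term_len)")
    max_len = max(term_lens, default=0)
    rows, start = [], 0
    for length in term_lens:
        # Take this TERM's residues from the flat axis ...
        chunk = list(flat[start:start + length])
        # ... and pad the row so every TERM has max_term_len slots.
        rows.append(chunk + [pad] * (max_len - length))
        start += length
    return rows


# Toy input: three TERMs of lengths 2, 3 and 1, so sum_term_len = 6.
flat = ["a0", "a1", "b0", "b1", "b2", "c0"]
term_lens = [2, 3, 1]

padded = split_flat_terms(flat, term_lens)
n_terms, max_term_len = len(padded), len(padded[0])
```

Here padded has n_terms = 3 rows of max_term_len = 3 slots each, with None marking padded positions. Under this reading, each TERM contributes a max_term_len x max_term_len grid of residue pairs, and each pair holds an n_hidden x n_hidden matrix, which is consistent with the documented shape of cov_mat.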