terminator.models.layers.utils.aggregate_edges

terminator.models.layers.utils.aggregate_edges(edge_embeddings, E_idx, max_seq_len)

Aggregate TERM edge embeddings into a dense, sequence-level edge feature tensor.

Parameters:
  • edge_embeddings (torch.Tensor) – TERM edge feature tensor. Shape: n_batch x n_terms x n_aa x n_neighbors x n_hidden

  • E_idx (torch.LongTensor) – TERM edge indices. Shape: n_batch x n_terms x n_aa x n_neighbors

  • max_seq_len (int) – Maximum sequence length in the batch.

Returns:

Dense sequence-level edge features. Shape: n_batch x max_seq_len x max_seq_len x n_hidden

Return type:

torch.Tensor
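The aggregation can be sketched in PyTorch as below. This is a minimal, hypothetical illustration, not the TERMinator implementation: the name aggregate_edges_sketch is invented here, and the sketch assumes that E_idx stores sequence-level positions, that E_idx[..., 0] is each residue's own position (self edge first), and that an edge (i, j) occurring in several TERMs is averaged over its occurrences. Padded TERM positions are not masked.

```python
import torch


def aggregate_edges_sketch(edge_embeddings: torch.Tensor,
                           E_idx: torch.Tensor,
                           max_seq_len: int) -> torch.Tensor:
    """Scatter TERM edge features into a dense n_batch x L x L x n_hidden tensor.

    HYPOTHETICAL sketch. Assumptions (not taken from the TERMinator source):
      * E_idx[..., k] is the sequence position of neighbor k, and
        E_idx[..., 0] is the position of the residue itself.
      * An edge (i, j) seen in several TERMs is averaged over its occurrences.
    """
    n_batch, _, _, _, n_hidden = edge_embeddings.shape
    dev, dtype = edge_embeddings.device, edge_embeddings.dtype

    # Row index: the residue that owns each edge, broadcast over its neighbors.
    self_idx = E_idx[..., :1].expand_as(E_idx)
    # Batch index with the same shape as E_idx.
    batch_idx = torch.arange(n_batch, device=dev).view(-1, 1, 1, 1).expand_as(E_idx)

    # Sum edge features that land on the same (batch, i, j) cell.
    summed = torch.zeros(n_batch, max_seq_len, max_seq_len, n_hidden,
                         dtype=dtype, device=dev)
    summed.index_put_((batch_idx, self_idx, E_idx), edge_embeddings, accumulate=True)

    # Count how many TERM edges contributed to each cell.
    counts = torch.zeros(n_batch, max_seq_len, max_seq_len, dtype=dtype, device=dev)
    counts.index_put_((batch_idx, self_idx, E_idx),
                      torch.ones_like(E_idx, dtype=dtype), accumulate=True)

    # Cells covered by no TERM edge keep count 0; clamp to avoid dividing by zero.
    return summed / counts.clamp(min=1).unsqueeze(-1)


# Toy batch: 1 structure, 2 TERMs of 2 residues each, 2 neighbors, 3 hidden dims.
E_idx = torch.tensor([[[[0, 1], [1, 0]],     # TERM 0 covers positions 0 and 1
                       [[1, 2], [2, 1]]]])   # TERM 1 covers positions 1 and 2
edges = torch.cat([torch.full((1, 1, 2, 2, 3), 1.0),    # TERM 0 features = 1
                   torch.full((1, 1, 2, 2, 3), 3.0)],   # TERM 1 features = 3
                  dim=1)
out = aggregate_edges_sketch(edges, E_idx, max_seq_len=3)
# out.shape == (1, 3, 3, 3); out[0, 1, 1] averages both TERMs -> 2.0
```

In the toy batch, the self edge at position 1 appears in both TERMs, so its dense entry is the mean (1 + 3) / 2 = 2, while the pair (0, 2) shares no TERM and stays zero. Using index_put_ with accumulate=True lets duplicate (i, j) indices sum instead of overwrite, which is what makes the averaging possible.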