terminator.models.layers.utils.aggregate_edges
- terminator.models.layers.utils.aggregate_edges(edge_embeddings, E_idx, max_seq_len)
Aggregate TERM edge embeddings into a dense, sequence-level edge feature tensor.
- Parameters:
edge_embeddings (torch.Tensor) – TERM edge feature tensor. Shape: n_batch x n_terms x n_aa x n_neighbors x n_hidden
E_idx (torch.LongTensor) – TERM edge indices. Shape: n_batch x n_terms x n_aa x n_neighbors
max_seq_len (int) – Maximum length of a sequence in the batch
- Returns:
Dense sequence-level edge features. Shape: n_batch x max_seq_len x max_seq_len x n_hidden
- Return type:
torch.Tensor
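To illustrate the shape transformation, here is a minimal sketch, not the library's implementation. It assumes that E_idx stores sequence-level residue indices, that E_idx[..., 0] is the index of the residue each edge originates from (the self edge of a k-nearest-neighbor graph), and that an edge covered by several TERMs is averaged over its occurrences. The function name aggregate_edges_sketch is hypothetical.

```python
import torch


def aggregate_edges_sketch(edge_embeddings: torch.Tensor,
                           E_idx: torch.Tensor,
                           max_seq_len: int) -> torch.Tensor:
    """Hypothetical sketch: scatter TERM edges into an L x L grid and average duplicates.

    edge_embeddings: n_batch x n_terms x n_aa x n_neighbors x n_hidden
    E_idx:           n_batch x n_terms x n_aa x n_neighbors (LongTensor)
    returns:         n_batch x max_seq_len x max_seq_len x n_hidden
    """
    n_batch, n_terms, n_aa, n_neighbors, n_hidden = edge_embeddings.shape
    dev, dtype = edge_embeddings.device, edge_embeddings.dtype

    # Assumption: column 0 of E_idx is the source residue; repeat it for every neighbor.
    self_idx = E_idx[..., :1].expand(-1, -1, -1, n_neighbors)
    # Batch index broadcast to the same shape as E_idx.
    batch_idx = torch.arange(n_batch, device=dev).view(n_batch, 1, 1, 1).expand_as(E_idx)
    index = (batch_idx, self_idx, E_idx)

    # Sum every edge embedding into its (batch, i, j) cell; duplicates accumulate.
    out = torch.zeros(n_batch, max_seq_len, max_seq_len, n_hidden, device=dev, dtype=dtype)
    out.index_put_(index, edge_embeddings, accumulate=True)

    # Count how many TERM edges landed in each cell so the sums become means.
    counts = torch.zeros(n_batch, max_seq_len, max_seq_len, device=dev, dtype=dtype)
    counts.index_put_(index, torch.ones_like(E_idx, dtype=dtype), accumulate=True)

    # Cells never written keep count 0; clamp to 1 to avoid division by zero.
    return out / counts.clamp(min=1).unsqueeze(-1)


if __name__ == "__main__":
    n_batch, n_terms, n_aa, n_neighbors, n_hidden, L = 2, 3, 4, 5, 8, 10
    emb = torch.randn(n_batch, n_terms, n_aa, n_neighbors, n_hidden)
    idx = torch.randint(0, L, (n_batch, n_terms, n_aa, n_neighbors))
    print(aggregate_edges_sketch(emb, idx, L).shape)  # torch.Size([2, 10, 10, 8])
```

Because several TERMs can cover the same residue pair, `index_put_` with `accumulate=True` sums the duplicate entries, and a parallel count tensor converts those sums into means. Residue pairs that no TERM covers remain zero.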