terminator.models.layers.utils.merge_duplicate_edges

terminator.models.layers.utils.merge_duplicate_edges(h_E_update, E_idx)

Average embeddings across bidirectional edges.

TERMinator represents each undirected edge as two directed edges (i → j and j → i). To let the two directions exchange information, this function averages their embeddings.

Parameters:
  • h_E_update (torch.Tensor) – Update tensor for edge embeddings in kNN sparse form. Shape: n_batch x n_res x k x n_hidden

  • E_idx (torch.LongTensor) – kNN sparse edge indices. Shape: n_batch x n_res x k

Returns:

merged_E_updates – Edge updates in which the two directions of each bidirectional edge have been merged. Shape: n_batch x n_res x k x n_hidden

Return type:

torch.Tensor
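The following is a minimal PyTorch sketch of how such a merge could be computed. It is not the TERMinator source, and the function name merge_duplicate_edges_sketch is hypothetical. It assumes that the edge i → j (stored at neighbor slot a of residue i) has a reverse partner only when residue i also appears in j's kNN list, and that edges without a partner are returned unchanged. The real function may handle unpaired edges differently.

```python
import torch


def merge_duplicate_edges_sketch(h_E_update: torch.Tensor, E_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: average each edge update i -> j with its reverse j -> i.

    h_E_update: n_batch x n_res x k x n_hidden
    E_idx:      n_batch x n_res x k, where E_idx[b, i, a] = j is the a-th neighbor of i
    """
    n_batch, n_res = h_E_update.shape[:2]
    device = E_idx.device

    # Broadcastable batch index for advanced indexing: n_batch x 1 x 1
    batch_idx = torch.arange(n_batch, device=device).view(n_batch, 1, 1)

    # nbrs_of_j[b, i, a, :] is the kNN list of j = E_idx[b, i, a]; shape n_batch x n_res x k x k
    nbrs_of_j = E_idx[batch_idx, E_idx]

    # reverse_mask[b, i, a, c] is True when j's c-th neighbor is residue i
    self_idx = torch.arange(n_res, device=device).view(1, n_res, 1, 1)
    reverse_mask = nbrs_of_j == self_idx

    # Assumption: an edge whose reverse is absent from j's kNN list keeps its own update
    has_reverse = reverse_mask.any(dim=-1)            # n_batch x n_res x k
    rev_slot = reverse_mask.float().argmax(dim=-1)    # slot c of the reverse edge (0 if none)

    # rev_update[b, i, a] = h_E_update[b, j, c]; shape n_batch x n_res x k x n_hidden
    rev_update = h_E_update[batch_idx, E_idx, rev_slot]

    merged = (h_E_update + rev_update) / 2
    return torch.where(has_reverse.unsqueeze(-1), merged, h_E_update)
```

In this sketch, an edge and its reverse end up with the same averaged update, while self-edges (j = i) are averaged with themselves and therefore stay unchanged.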