terminator.models.layers.utils.merge_duplicate_term_edges
- terminator.models.layers.utils.merge_duplicate_term_edges(h_E_update, E_idx)
Average embeddings across bidirectional TERM edges.
TERMinator represents each edge as a pair of directed edges, one in each direction. To let the two directions communicate, their embeddings are averaged.
- Parameters:
h_E_update (torch.Tensor) – Update tensor for edge embeddings in kNN sparse form. Shape: n_batch x n_terms x n_res x k x n_hidden
E_idx (torch.LongTensor) – kNN sparse edge indices. Shape: n_batch x n_terms x n_res x k
- Returns:
merged_E_updates – Edge update tensor in which the updates for each pair of bidirectional edges have been merged. Shape: n_batch x n_terms x n_res x k x n_hidden
- Return type:
torch.Tensor
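
The averaging described above can be sketched in plain PyTorch. The block below is a minimal illustration under stated assumptions, not the library's source: the name `merge_bidirectional_edges_sketch` is hypothetical, each residue's neighbor list in `E_idx` is assumed to contain no repeated indices, and an edge whose reverse direction is absent from `E_idx` is assumed to be left unchanged.

```python
import torch


def merge_bidirectional_edges_sketch(h_E_update, E_idx):
    """Hypothetical sketch: average each edge i->j with its reverse j->i.

    h_E_update: n_batch x n_terms x n_res x k x n_hidden
    E_idx:      n_batch x n_terms x n_res x k (int64 neighbor indices)
    """
    n_batch, n_terms, n_res, k, n_hidden = h_E_update.shape

    # Scatter the sparse edges into a dense n_res x n_res grid so that
    # dense[b, t, i, j] holds the update for the directed edge i -> j.
    dense = h_E_update.new_zeros((n_batch, n_terms, n_res, n_res, n_hidden))
    idx = E_idx.unsqueeze(-1).expand(-1, -1, -1, -1, n_hidden)
    dense.scatter_(3, idx, h_E_update)

    # Record which (i, j) slots actually hold an edge.
    present = h_E_update.new_zeros((n_batch, n_terms, n_res, n_res))
    present.scatter_(3, E_idx, 1.0)

    # Swapping the two residue axes turns i -> j into j -> i; gathering
    # at E_idx then lines each edge up with its reverse direction.
    reverse = dense.transpose(2, 3).gather(3, idx)
    has_reverse = (present.transpose(2, 3).gather(3, E_idx) > 0).unsqueeze(-1)

    # Assumption: edges without a reverse partner keep their own update.
    return torch.where(has_reverse, (h_E_update + reverse) / 2, h_E_update)


# Toy graph with 3 residues and k = 1: edges 0->1, 1->0 and 2->0.
E_idx = torch.tensor([1, 0, 0]).view(1, 1, 3, 1)
h = torch.tensor([1.0, 3.0, 5.0]).view(1, 1, 3, 1, 1)
merged = merge_bidirectional_edges_sketch(h, E_idx)
print(merged.flatten().tolist())  # [2.0, 2.0, 5.0]
```

In the toy graph, edges 0->1 and 1->0 form a bidirectional pair and both receive the mean of their updates, while 2->0 has no reverse edge and passes through unchanged. The dense intermediate grid costs memory quadratic in n_res, which is a simplification chosen here for readability.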