terminator.models.layers.utils.merge_duplicate_edges_geometric
- terminator.models.layers.utils.merge_duplicate_edges_geometric(h_E_update, edge_index)
Average embeddings across bidirectional edges for Torch Geometric graphs.
TERMinator represents each edge as a pair of directed edges, one in each direction. To allow communication between the two, their embeddings are averaged.
- Parameters:
h_E_update (torch.Tensor) – Update tensor for edge embeddings in Torch Geometric sparse form. Shape: n_edge x n_hidden
edge_index (torch.LongTensor) – Torch Geometric sparse edge indices. Shape: 2 x n_edge
- Returns:
merged_E_updates – Edge updates in which each edge and its reverse edge share their averaged update. Shape: n_edge x n_hidden
- Return type:
torch.Tensor
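A minimal sketch of the averaging described above, assuming that for every directed edge (i, j) in edge_index the reverse edge (j, i) is also present. This is not TERMinator's implementation: the function name merge_duplicate_edges_sketch is hypothetical, and pairing each edge with its reverse via sorted integer keys and torch.searchsorted is an illustrative choice.

```python
import torch


def merge_duplicate_edges_sketch(h_E_update: torch.Tensor,
                                 edge_index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: average each edge's update with its reverse edge's.

    Assumes every directed edge (i, j) in edge_index has a matching (j, i).
    h_E_update: n_edge x n_hidden, edge_index: 2 x n_edge (int64).
    """
    src, dst = edge_index[0], edge_index[1]
    num_nodes = int(edge_index.max()) + 1

    # Encode each directed edge (i, j) as one integer key, and likewise its reverse.
    keys = src * num_nodes + dst
    reverse_keys = dst * num_nodes + src

    # Sort the keys so the position of each reverse edge can be binary-searched.
    sorted_keys, order = torch.sort(keys)
    pos = torch.searchsorted(sorted_keys, reverse_keys)
    pos = pos.clamp(max=sorted_keys.numel() - 1)  # keep indices in range
    reverse_idx = order[pos]

    # Verify the assumption that every reverse edge actually exists.
    if not torch.equal(keys[reverse_idx], reverse_keys):
        raise ValueError("edge_index is missing some reverse edges")

    # Each edge and its reverse now receive the same (averaged) update.
    return (h_E_update + h_E_update[reverse_idx]) / 2


if __name__ == "__main__":
    # Two undirected edges, 0-1 and 1-2, each stored in both directions.
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    h = torch.tensor([[1.0], [3.0], [10.0], [20.0]])
    merged = merge_duplicate_edges_sketch(h, edge_index)
    # Edges 0-1 share the mean 2.0; edges 1-2 share the mean 15.0.
    print(merged)
```

Encoding (i, j) as i * num_nodes + j turns reverse-edge lookup into a sort plus a binary search, so the sketch does not depend on the order in which edges appear in edge_index.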