terminator.models.layers.utils.merge_duplicate_edges_geometric

terminator.models.layers.utils.merge_duplicate_edges_geometric(h_E_update, edge_index)

Average edge embeddings across bidirectional edge pairs in Torch Geometric graphs.

In TERMinator, each edge is represented as two directed edges, one in each direction. To allow communication between the two directions, their embeddings are averaged.

Parameters:
  • h_E_update (torch.Tensor) – Update tensor for edge embeddings in Torch Geometric sparse form. Shape: n_edge x n_hidden

  • edge_index (torch.LongTensor) – Torch Geometric sparse edge indices. Shape: 2 x n_edge

Returns:

merged_E_updates – Edge updates, with the updates of each pair of bidirectional edges merged. Shape: n_edge x n_hidden

Return type:

torch.Tensor
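The averaging described above can be sketched in plain PyTorch. This is a hypothetical re-implementation for illustration, not the TERMinator source: the function name merge_bidirectional_edges_sketch is invented here, reverse edges are located by encoding each directed edge (i, j) as the integer key i * n_node + j, and an edge whose reverse is missing is assumed to keep its own update.

```python
import torch


def merge_bidirectional_edges_sketch(h_E_update: torch.Tensor,
                                     edge_index: torch.Tensor) -> torch.Tensor:
    """Average each directed edge's update with the update of its reverse edge.

    Hypothetical sketch, not the library implementation.
    h_E_update: n_edge x n_hidden; edge_index: 2 x n_edge (long).
    Assumption: edges without a reverse partner keep their own update.
    """
    src, dst = edge_index[0], edge_index[1]
    n_node = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    # Encode each directed edge (i, j) as one integer key, and its reverse (j, i) likewise.
    key = src * n_node + dst
    rev_key = dst * n_node + src
    # Sort the keys so each reverse key can be found with a binary search.
    sorted_key, order = torch.sort(key)
    pos = torch.searchsorted(sorted_key, rev_key)
    pos = pos.clamp(max=max(sorted_key.numel() - 1, 0))
    found = sorted_key[pos] == rev_key
    # Index of the reverse edge, or the edge itself when no reverse exists.
    self_idx = torch.arange(key.numel(), device=key.device)
    rev_idx = torch.where(found, order[pos], self_idx)
    # Both directions of a pair receive the same averaged update.
    return (h_E_update + h_E_update[rev_idx]) / 2
```

For the edge pairs (0, 1)/(1, 0) and (1, 2)/(2, 1) with one-dimensional updates 1, 3, 5 and 7, the sketch returns 2, 2, 6 and 6: each pair ends up with identical, averaged rows, while the output keeps the n_edge x n_hidden shape of the input.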