terminator.models.layers.utils.cat_term_edge_endpoints

terminator.models.layers.utils.cat_term_edge_endpoints(h_edges, h_nodes, E_idx)

Concatenate the features of both endpoint nodes onto the gathered edge features, given kNN sparse edge indices.

Parameters:
  • h_edges (torch.Tensor) – The gathered edge features. Shape: n_batch x n_terms x n_res x k x n_hidden

  • h_nodes (torch.Tensor) – The node features for all nodes. Shape: n_batch x n_terms x n_res x n_hidden

  • E_idx (torch.LongTensor) – kNN sparse edge indices. Shape: n_batch x n_terms x n_res x k

Returns:

h_nn – The gathered edge features with both endpoint node features concatenated along the last dimension. Shape: n_batch x n_terms x n_res x k x 3 * n_hidden

Return type:

torch.Tensor
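To illustrate the shape bookkeeping, here is a minimal PyTorch sketch of the operation described above. It is not the TERMinator implementation: the helper names `gather_term_nodes_sketch` and `cat_term_edge_endpoints_sketch` are hypothetical, and two behaviors are assumptions. First, the source endpoint of edge (n, k) is taken to be residue n itself. Second, the concatenation order is assumed to be [source node, neighbor node, edge].

```python
import torch


def gather_term_nodes_sketch(h_nodes: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather node features at kNN indices.

    h_nodes: [B, T, N, C], idx: [B, T, N, K] -> returns [B, T, N, K, C]
    """
    B, T, N, K = idx.shape
    C = h_nodes.shape[-1]
    # Flatten the (N, K) index grid so a single gather along the residue axis suffices
    flat_idx = idx.reshape(B, T, N * K, 1).expand(-1, -1, -1, C)
    gathered = torch.gather(h_nodes, 2, flat_idx)  # [B, T, N*K, C]
    return gathered.reshape(B, T, N, K, C)


def cat_term_edge_endpoints_sketch(h_edges, h_nodes, E_idx):
    """Hypothetical re-implementation; endpoint choice and concat order are assumptions."""
    K = E_idx.shape[-1]
    # Source endpoint i: assumed to be residue n itself, repeated over its k neighbors
    h_i = h_nodes.unsqueeze(3).expand(-1, -1, -1, K, -1)  # [B, T, N, K, C]
    # Target endpoint j: the neighbor residue selected by E_idx
    h_j = gather_term_nodes_sketch(h_nodes, E_idx)  # [B, T, N, K, C]
    # Assumed order: [source node, neighbor node, edge] along the feature axis
    return torch.cat([h_i, h_j, h_edges], dim=-1)


if __name__ == "__main__":
    B, T, N, K, C = 2, 3, 5, 4, 8
    h_nodes = torch.randn(B, T, N, C)
    h_edges = torch.randn(B, T, N, K, C)
    E_idx = torch.randint(0, N, (B, T, N, K))
    h_nn = cat_term_edge_endpoints_sketch(h_edges, h_nodes, E_idx)
    print(h_nn.shape)  # torch.Size([2, 3, 5, 4, 24])
```

Because the two node vectors and the edge vector are stacked along the feature axis, the last dimension grows to 2 * (node hidden size) + (edge hidden size), which is 3 * n_hidden when nodes and edges share n_hidden.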