terminator.models.layers.utils.gather_term_edges

terminator.models.layers.utils.gather_term_edges(edges, neighbor_idx)

Gather the TERM edge features of the nearest neighbors.

From https://github.com/jingraham/neurips19-graph-protein-design

Parameters:
  • edges (torch.Tensor) – The edge features in dense form. Shape: n_batch x n_terms x n_res x n_res x n_hidden

  • neighbor_idx (torch.LongTensor) – kNN sparse edge indices. Shape: n_batch x n_terms x n_res x k

Returns:

edge_features – The gathered edge features. Shape: n_batch x n_terms x n_res x k x n_hidden

Return type:

torch.Tensor
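
Example:

The shapes above suggest a gather along the second residue dimension. The following is a minimal sketch of that idea using torch.gather, not necessarily the library's actual implementation; the function name gather_term_edges_sketch and the tensor sizes are hypothetical and chosen only for illustration.

```python
import torch


def gather_term_edges_sketch(edges, neighbor_idx):
    """Hypothetical re-implementation, written only to illustrate the shapes.

    edges:        n_batch x n_terms x n_res x n_res x n_hidden
    neighbor_idx: n_batch x n_terms x n_res x k
    returns:      n_batch x n_terms x n_res x k x n_hidden
    """
    n_hidden = edges.size(-1)
    # Add a trailing axis and repeat each index across the hidden dimension
    # so the index tensor has the same rank as `edges`.
    idx = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, -1, n_hidden)
    # For residue i, keep only the edges (i, j) where j is one of its k neighbors.
    return torch.gather(edges, 3, idx)


# Illustrative sizes (assumptions, not taken from the library).
n_batch, n_terms, n_res, k, n_hidden = 2, 3, 5, 4, 8
edges = torch.randn(n_batch, n_terms, n_res, n_res, n_hidden)
neighbor_idx = torch.randint(0, n_res, (n_batch, n_terms, n_res, k))

edge_features = gather_term_edges_sketch(edges, neighbor_idx)
print(edge_features.shape)
```

In this sketch, edge_features[b, t, i, j] equals edges[b, t, i, neighbor_idx[b, t, i, j]], so each residue keeps only the features of the edges to its k neighbors.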