terminator.models.layers.utils.gather_term_nodes
- terminator.models.layers.utils.gather_term_nodes(nodes, neighbor_idx)
Gather the node features of each residue's k nearest neighbors within each TERM.
Adapted from https://github.com/jingraham/neurips19-graph-protein-design
- Parameters:
nodes (torch.Tensor) – The node features for all nodes. Shape: n_batch x n_terms x n_res x n_hidden
neighbor_idx (torch.LongTensor) – kNN sparse edge indices. Shape: n_batch x n_terms x n_res x k
- Returns:
neighbor_features – The gathered neighbor node features. Shape: n_batch x n_terms x n_res x k x n_hidden
- Return type:
torch.Tensor
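The gather step can be sketched in plain PyTorch. This is a minimal sketch, assuming the function follows the same flatten, `torch.gather`, reshape pattern as `gather_nodes` in the neurips19-graph-protein-design repository, extended with an extra TERM axis. `gather_term_nodes_sketch` is a hypothetical name, not the terminator source, and the sketch assumes every entry of `neighbor_idx` indexes the `n_res` axis of the same TERM.

```python
import torch


def gather_term_nodes_sketch(nodes: torch.Tensor, neighbor_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation for illustration only.

    nodes:        [n_batch, n_terms, n_res, n_hidden]
    neighbor_idx: [n_batch, n_terms, n_res, k], indices into the n_res axis
                  of the same TERM (assumption)
    returns:      [n_batch, n_terms, n_res, k, n_hidden]
    """
    n_batch, n_terms, n_res, k = neighbor_idx.shape
    n_hidden = nodes.size(-1)
    # Flatten each residue's neighbor list: [B, T, N, K] -> [B, T, N*K]
    idx_flat = neighbor_idx.reshape(n_batch, n_terms, n_res * k)
    # Repeat each index across the hidden dimension: [B, T, N*K, H]
    idx_flat = idx_flat.unsqueeze(-1).expand(-1, -1, -1, n_hidden)
    # Select rows along the residue axis (dim=2) within each TERM
    gathered = torch.gather(nodes, 2, idx_flat)
    # Restore the separate neighbor axis: [B, T, N, K, H]
    return gathered.view(n_batch, n_terms, n_res, k, n_hidden)


# Usage: 2 structures, 3 TERMs, 5 residues per TERM, 4 neighbors, 8 features
nodes = torch.randn(2, 3, 5, 8)
neighbor_idx = torch.randint(0, 5, (2, 3, 5, 4))
out = gather_term_nodes_sketch(nodes, neighbor_idx)
print(out.shape)  # torch.Size([2, 3, 5, 4, 8])
```

Flattening the k axis lets a single `torch.gather` call along the residue dimension do the whole lookup. The result matches advanced indexing `nodes[b, t, neighbor_idx]` with broadcast batch and TERM index tensors, which can be a handy cross-check.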