Source code for terminator.models.layers.utils

""" Util functions useful in TERMinator modules """
import sys

import torch
from torch.nn.utils.rnn import pad_sequence

# pylint: disable=no-member

# batchify functions



def pad_sequence_12(sequences, padding_value=0):
    """Batch tensors together, padding both dim 0 and dim 1 to the batch maximum.

    Args
    ----
    sequences : list of torch.Tensor
        Tensors with ``N >= 2`` axes. Axes after the first two are taken from
        ``sequences[0]``; the other tensors are expected to match them.
    padding_value : int, default=0
        Value written into every padded position.

    Returns
    -------
    out_tensor : torch.Tensor
        Tensor of shape ``(n_batch, max_dim1, max_dim2, ...)`` with the dtype and
        device of ``sequences[0]``.
    """
    template = sequences[0]
    longest_rows = max(seq.size(0) for seq in sequences)
    longest_cols = max(seq.size(1) for seq in sequences)
    batch_shape = [len(sequences), longest_rows, longest_cols, *template.shape[2:]]
    padded = template.new_full(batch_shape, padding_value)
    for slot, seq in enumerate(sequences):
        # index assignment copies the data into the batch rather than aliasing it
        padded[slot, :seq.size(0), :seq.size(1), ...] = seq
    return padded
def batchify(batched_flat_terms, term_lens):
    """ Split flat per-protein TERM tensors into a padded (protein, TERM, residue) batch.

    The TERM information condensor stores the TERMs of each protein concatenated
    end to end along one axis. Message passing is easier when every TERM is its own
    row, so this splits each protein's flat tensor by ``term_lens`` and pads the
    result into a single dense tensor (padding value 0).

    Args
    ----
    batched_flat_terms : torch.Tensor
        Tensor with shape :code:`(n_batch, sum_term_len, ...)`
    term_lens : list of (list of int)
        Length of each TERM for each protein; ``term_lens[b]`` must sum to
        ``sum_term_len`` for protein ``b``.

    Returns
    -------
    batchify_terms : torch.Tensor
        Tensor with shape :code:`(n_batch, max_num_terms, max_term_len, ...)`
    """
    per_protein = []
    for b, flat in enumerate(torch.unbind(batched_flat_terms)):
        term_chunks = torch.split(flat, term_lens[b])
        # (n_terms, max_term_len, ...) for this protein
        per_protein.append(pad_sequence(term_chunks, batch_first=True))
    return pad_sequence_12(per_protein)
# gather and cat functions # struct level
def gather_edges(edges, neighbor_idx):
    """ Select, for every residue, the dense edge features of its kNN neighbors.

    From https://github.com/jingraham/neurips19-graph-protein-design

    Args
    ----
    edges : torch.Tensor
        The edge features in dense form
        Shape: n_batch x n_res x n_res x n_hidden
    neighbor_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_res x k

    Returns
    -------
    edge_features : torch.Tensor
        ``edge_features[b, i, m] == edges[b, i, neighbor_idx[b, i, m]]``
        Shape : n_batch x n_res x k x n_hidden
    """
    n_channels = edges.size(-1)
    # broadcast the neighbor index across the feature channels: [B,N,K] -> [B,N,K,C]
    index = neighbor_idx[..., None].expand(-1, -1, -1, n_channels)
    return torch.gather(edges, dim=2, index=index)
def gather_nodes(nodes, neighbor_idx):
    """ Select, for every residue, the node features of its kNN neighbors.

    From https://github.com/jingraham/neurips19-graph-protein-design

    Args
    ----
    nodes : torch.Tensor
        The node features for all nodes
        Shape: n_batch x n_res x n_hidden
    neighbor_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_res x k

    Returns
    -------
    neighbor_features : torch.Tensor
        ``neighbor_features[b, i, m] == nodes[b, neighbor_idx[b, i, m]]``
        Shape : n_batch x n_res x k x n_hidden
    """
    batch_size = neighbor_idx.shape[0]
    # flatten the (residue, neighbor) axes so a single gather along dim 1 suffices:
    # [B,N,K] -> [B,NK] -> [B,NK,C]
    flat_index = neighbor_idx.view(batch_size, -1)
    flat_index = flat_index[..., None].expand(-1, -1, nodes.size(2))
    picked = torch.gather(nodes, 1, flat_index)
    # restore the [B,N,K,C] layout
    return picked.view([*neighbor_idx.shape[:3], -1])
def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    """ Append the neighbor node's features to each gathered kNN edge feature.

    From https://github.com/jingraham/neurips19-graph-protein-design

    Args
    ----
    h_nodes : torch.Tensor
        The node features for all nodes
        Shape: n_batch x n_res x n_hidden_nodes
    h_neighbors : torch.Tensor
        The gathered edge features
        Shape: n_batch x n_res x k x n_hidden_edges
    E_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_res x k

    Returns
    -------
    h_nn : torch.Tensor
        Edge features followed by neighbor node features along the last axis
        Shape : n_batch x n_res x k x (n_hidden_edges + n_hidden_nodes)
    """
    neighbor_nodes = gather_nodes(h_nodes, E_idx)
    return torch.cat((h_neighbors, neighbor_nodes), dim=-1)
def cat_edge_endpoints(h_edges, h_nodes, E_idx):
    """ Concatenate both endpoint node features onto each gathered kNN edge feature.

    Args
    ----
    h_edges : torch.Tensor
        The gathered edge features
        Shape: n_batch x n_res x k x n_hidden_edges
    h_nodes : torch.Tensor
        The node features for all nodes
        Shape: n_batch x n_res x n_hidden_nodes
    E_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_res x k

    Returns
    -------
    h_nn : torch.Tensor
        ``[h_i, h_j, h_edges]`` concatenated along the last axis
        Shape : n_batch x n_res x k x (2 * n_hidden_nodes + n_hidden_edges)
    """
    num_neighbors = E_idx.shape[-1]
    # The "source" endpoint is read from the first neighbor slot, presumably the
    # residue itself (self at distance 0) — NOTE(review): confirm with E_idx producer.
    # contiguous() is needed because gather_nodes calls .view on the index.
    source_idx = E_idx[..., 0:1].expand(-1, -1, num_neighbors).contiguous()
    h_i = gather_nodes(h_nodes, source_idx)
    h_j = gather_nodes(h_nodes, E_idx)
    return torch.cat([h_i, h_j, h_edges], dim=-1)
def gather_pairEs(pairEs, neighbor_idx):
    """ Select, for every residue, the pair-energy tables of its kNN neighbors.

    From https://github.com/jingraham/neurips19-graph-protein-design

    Args
    ----
    pairEs : torch.Tensor
        The pair energies in dense form
        Shape: n_batch x n_res x n_res x n_aa x n_aa
    neighbor_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_res x k

    Returns
    -------
    pairE_features : torch.Tensor
        ``pairE_features[b, i, m] == pairEs[b, i, neighbor_idx[b, i, m]]``
        Shape : n_batch x n_res x k x n_aa x n_aa
    """
    aa = pairEs.size(-1)
    # [B,N,K] -> [B,N,K,A,A] so every table entry follows its neighbor index
    index = neighbor_idx[..., None, None].expand(-1, -1, -1, aa, aa)
    return pairEs.gather(2, index)
# term level
def gather_term_nodes(nodes, neighbor_idx):
    """ Select, within each TERM, the node features of every residue's kNN neighbors.

    Adapted from https://github.com/jingraham/neurips19-graph-protein-design

    Args
    ----
    nodes : torch.Tensor
        The node features for all nodes
        Shape: n_batch x n_terms x n_res x n_hidden
    neighbor_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_terms x n_res x k

    Returns
    -------
    neighbor_features : torch.Tensor
        ``neighbor_features[b, t, i, m] == nodes[b, t, neighbor_idx[b, t, i, m]]``
        Shape : n_batch x n_terms x n_res x k x n_hidden
    """
    n_batch, n_terms = neighbor_idx.shape[:2]
    # [B,T,N,K] -> [B,T,NK] -> [B,T,NK,C] so one gather along dim 2 does the work
    index = neighbor_idx.view(n_batch, n_terms, -1)[..., None]
    index = index.expand(-1, -1, -1, nodes.size(3))
    gathered = nodes.gather(2, index)
    # unflatten back to [B,T,N,K,C]
    return gathered.view([*neighbor_idx.shape[:4], -1])
def gather_term_edges(edges, neighbor_idx):
    """ Select, within each TERM, the dense edge features of every residue's kNN neighbors.

    From https://github.com/jingraham/neurips19-graph-protein-design

    Args
    ----
    edges : torch.Tensor
        The edge features in dense form
        Shape: n_batch x n_terms x n_res x n_res x n_hidden
    neighbor_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_terms x n_res x k

    Returns
    -------
    edge_features : torch.Tensor
        ``edge_features[b, t, i, m] == edges[b, t, i, neighbor_idx[b, t, i, m]]``
        Shape : n_batch x n_terms x n_res x k x n_hidden
    """
    channels = edges.size(-1)
    # [B,T,N,K] -> [B,T,N,K,C]
    index = neighbor_idx[..., None].expand(-1, -1, -1, -1, channels)
    return edges.gather(3, index)
def cat_term_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    """ Append the neighbor node's features to each gathered TERM kNN edge feature.

    From https://github.com/jingraham/neurips19-graph-protein-design

    Args
    ----
    h_nodes : torch.Tensor
        The node features for all nodes
        Shape: n_batch x n_terms x n_res x n_hidden_nodes
    h_neighbors : torch.Tensor
        The gathered edge features
        Shape: n_batch x n_terms x n_res x k x n_hidden_edges
    E_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_terms x n_res x k

    Returns
    -------
    h_nn : torch.Tensor
        Edge features followed by neighbor node features along the last axis
        Shape : n_batch x n_terms x n_res x k x (n_hidden_edges + n_hidden_nodes)
    """
    neighbor_nodes = gather_term_nodes(h_nodes, E_idx)
    return torch.cat((h_neighbors, neighbor_nodes), dim=-1)
def cat_term_edge_endpoints(h_edges, h_nodes, E_idx):
    """ Concatenate both endpoint node features onto each gathered TERM kNN edge feature.

    Args
    ----
    h_edges : torch.Tensor
        The gathered edge features (already in kNN form; used as-is)
        Shape: n_batch x n_terms x n_res x k x n_hidden_edges
    h_nodes : torch.Tensor
        The node features for all nodes
        Shape: n_batch x n_terms x n_res x n_hidden_nodes
    E_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_terms x n_res x k

    Returns
    -------
    h_nn : torch.Tensor
        ``[h_i, h_j, h_edges]`` concatenated along the last axis
        Shape : n_batch x n_terms x n_res x k x (2 * n_hidden_nodes + n_hidden_edges)
    """
    num_neighbors = E_idx.shape[-1]
    # The "source" endpoint is read from the first neighbor slot, presumably the
    # residue itself — NOTE(review): confirm with E_idx producer.
    # contiguous() is needed because gather_term_nodes calls .view on the index.
    source_idx = E_idx[..., 0:1].expand(-1, -1, -1, num_neighbors).contiguous()
    h_i = gather_term_nodes(h_nodes, source_idx)
    h_j = gather_term_nodes(h_nodes, E_idx)
    return torch.cat([h_i, h_j, h_edges], dim=-1)
# merge edge fns
def merge_duplicate_edges(h_E_update, E_idx):
    """ Average embeddings across bidirectional edges.

    TERMinator edges are represented as two bidirectional edges, and to allow for
    communication between these edges we average the embeddings.

    The batch is flattened into one disjoint graph: each edge ``(b, i, m)`` becomes
    the column ``(E_idx[b, i, m] + b * n_res, i + b * n_res)`` of a 2 x n_edge index,
    and the merge is delegated to :func:`merge_duplicate_edges_geometric`.

    Args
    ----
    h_E_update : torch.Tensor
        Update tensor for edges embeddings in kNN sparse form
        Shape : n_batch x n_res x k x n_hidden
    E_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_res x k

    Returns
    -------
    merged_E_updates : torch.Tensor
        Edge update with merged updates for bidirectional edges
        Shape : n_batch x n_res x k x n_hidden

    Notes
    -----
    The geometric merge receives a flattened *view* of ``h_E_update``, so
    ``h_E_update`` may be modified in place.
    """
    dev = h_E_update.device
    n_batch, n_res = h_E_update.shape[:2]
    h_dim = h_E_update.shape[-1]
    k = E_idx.shape[-1]  # neighbor count taken from the data, not hard-coded
    h_E_geometric = h_E_update.view([-1, h_dim])

    # row 0: neighbor node of each edge, shifted so each protein has its own node range
    offsets = torch.arange(n_batch, device=dev).view(n_batch, 1, 1) * n_res
    edge_index_row = (E_idx.to(dev) + offsets).reshape(-1)
    # row 1: node that owns each edge; edge (b, i, m) belongs to global node b * n_res + i
    edge_index_col = torch.arange(n_batch * n_res, device=dev).repeat_interleave(k)
    edge_index = torch.stack([edge_index_row, edge_index_col])

    merge = merge_duplicate_edges_geometric(h_E_geometric, edge_index)
    return merge.view(h_E_update.shape)
def merge_duplicate_edges_geometric(h_E_update, edge_index):
    """ Average embeddings across bidirectional edges for Torch Geometric graphs

    TERMinator edges are represented as two bidirectional edges, and to allow for
    communication between these edges we average the embeddings.

    Each directed edge ``(a, b)`` is encoded as the integer key ``a + b * num_nodes``.
    A lookup table from the reversed key to the edge id then finds every edge's
    reverse partner. Edges without a reverse are left unchanged; self-loops are their
    own reverse and so are also unchanged. Duplicate (identical) edges are not
    supported.

    Args
    ----
    h_E_update : torch.Tensor
        Update tensor for edges embeddings in Torch Geometric sparse form
        Shape : n_edge x n_hidden
    edge_index : torch.LongTensor
        Torch Geometric sparse edge indices
        Shape : 2 x n_edge

    Returns
    -------
    merged_E_updates : torch.Tensor
        Edge update with merged updates for bidirectional edges. This is
        ``h_E_update`` itself, modified in place.
        Shape : n_edge x n_hidden
    """
    dev = h_E_update.device
    n_edges = edge_index.shape[1]
    if n_edges == 0:
        # nothing to merge (and .max() below would fail on an empty tensor)
        return h_E_update
    # keep all index bookkeeping on the feature device; indexing a CPU tensor
    # with CUDA indices (or vice versa for the lookup) raises
    edge_index = edge_index.to(dev)
    num_nodes = edge_index.max() + 1
    row_idx = edge_index[0] + edge_index[1] * num_nodes
    col_idx = edge_index[1] + edge_index[0] * num_nodes
    table_size = int(max(row_idx.max(), col_idx.max())) + 1
    # mapping[key] = id of the edge whose reversed key is `key`, or -1 if none
    mapping = torch.full((table_size,), -1, dtype=torch.long, device=dev)
    mapping[col_idx] = torch.arange(n_edges, device=dev)
    reverse_idx = mapping[row_idx]
    mask = reverse_idx >= 0
    reverse_idx = reverse_idx[mask]
    reverse_h_E = h_E_update[mask]
    # RHS is fully materialised before assignment, so both directions use original values
    h_E_update[reverse_idx] = (h_E_update[reverse_idx] + reverse_h_E) / 2
    return h_E_update
def merge_duplicate_term_edges(h_E_update, E_idx):
    """ Average embeddings across bidirectional TERM edges.

    TERMinator edges are represented as two bidirectional edges, and to allow for
    communication between these edges we average the embeddings.

    The kNN edges are scattered into a dense n_res x n_res buffer per TERM. The
    buffer is transposed over the residue axes, and the reverse edges are gathered
    back into kNN form.

    Args
    ----
    h_E_update : torch.Tensor
        Update tensor for edges embeddings in kNN sparse form
        Shape : n_batch x n_terms x n_res x k x n_hidden
    E_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_terms x n_res x k

    Returns
    -------
    merged_E_updates : torch.Tensor
        Edge update with merged updates for bidirectional edges
        Shape : n_batch x n_terms x n_res x k x n_hidden

    Notes
    -----
    Averaging is applied element-wise only where the gathered reverse value is
    non-zero. Elements whose reverse value is exactly zero, including all elements of
    edges with no reverse edge, keep their original value.
    """
    dev = h_E_update.device
    E_idx = E_idx.to(dev)
    n_batch, n_terms, n_aa, _, hidden_dim = h_E_update.shape
    # dense buffer in the input dtype: scatter_ requires src and self dtypes to match
    collection = torch.zeros((n_batch, n_terms, n_aa, n_aa, hidden_dim),
                             dtype=h_E_update.dtype, device=dev)
    neighbor_idx = E_idx.unsqueeze(-1).expand(-1, -1, -1, -1, hidden_dim)
    collection.scatter_(3, neighbor_idx, h_E_update)
    # transpose to get same edge in reverse direction
    collection = collection.transpose(2, 3)
    # gather reverse edges
    reverse_E_update = gather_term_edges(collection, E_idx)
    # average h_E_update and reverse_E_update at non-zero positions
    return torch.where(reverse_E_update != 0, (h_E_update + reverse_E_update) / 2, h_E_update)
def merge_duplicate_pairE(h_E, E_idx):
    """ Average pair energy tables across bidirectional edges.

    TERMinator edges are represented as two bidirectional edges, and to allow for
    communication between these edges we average the embeddings. In the case for
    pair energies, we transpose the tables to ensure that the pair energy table is
    symmetric upon inverse (e.g. the pair energy between i and j should be the same
    as the pair energy between j and i)

    The batch is flattened into one disjoint graph and merged with
    :func:`merge_duplicate_pairE_geometric`. If that raises a ``RuntimeError``, the
    error is assumed to be out-of-memory: it is reported on stderr and the
    (slower) :func:`merge_duplicate_pairE_sparse` is used instead.

    Args
    ----
    h_E : torch.Tensor
        Pair energies in kNN sparse form
        Shape : n_batch x n_res x k x n_aa x n_aa
    E_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_res x k

    Returns
    -------
    torch.Tensor
        Pair energies with merged energies for bidirectional edges
        Shape : n_batch x n_res x k x n_aa x n_aa
    """
    try:
        dev = h_E.device
        n_batch, n_res = h_E.shape[:2]
        k = E_idx.shape[-1]  # neighbor count taken from the data, not hard-coded
        table_size = h_E.shape[-2] * h_E.shape[-1]  # n_aa ** 2, not hard-coded
        h_E_geometric = h_E.view([-1, table_size])
        # row 0: neighbor node of each edge, shifted so each protein has its own node range
        offsets = torch.arange(n_batch, device=dev).view(n_batch, 1, 1) * n_res
        edge_index_row = (E_idx.to(dev) + offsets).reshape(-1)
        # row 1: node that owns each edge; edge (b, i, m) belongs to global node b * n_res + i
        edge_index_col = torch.arange(n_batch * n_res, device=dev).repeat_interleave(k)
        edge_index = torch.stack([edge_index_row, edge_index_col])
        merge = merge_duplicate_pairE_geometric(h_E_geometric, edge_index)
        return merge.view(h_E.shape)
    except RuntimeError as err:
        print(err, file=sys.stderr)
        print("We're handling this error as if it's an out-of-memory error", file=sys.stderr)
        torch.cuda.empty_cache()  # this is probably unnecessary but just in case
        return merge_duplicate_pairE_sparse(h_E, E_idx)
def merge_duplicate_pairE_dense(h_E, E_idx):
    """ Dense method to average pair energy tables across bidirectional edges.

    TERMinator edges are represented as two bidirectional edges, and to allow for
    communication between these edges we average the embeddings. In the case for
    pair energies, we transpose the tables to ensure that the pair energy table is
    symmetric upon inverse (e.g. the pair energy between i and j should be the same
    as the pair energy between j and i)

    Memory scales with n_batch x n_res x n_res x n_aa x n_aa because the kNN tables
    are scattered into a dense buffer.

    Args
    ----
    h_E : torch.Tensor
        Pair energies in kNN sparse form
        Shape : n_batch x n_res x k x n_aa x n_aa
    E_idx : torch.LongTensor
        kNN sparse edge indices
        Shape : n_batch x n_res x k

    Returns
    -------
    torch.Tensor
        Pair energies with merged energies for bidirectional edges
        Shape : n_batch x n_res x k x n_aa x n_aa

    Notes
    -----
    Averaging is applied element-wise only where the gathered reverse value is
    non-zero; all other elements keep their original value.
    """
    dev = h_E.device
    E_idx = E_idx.to(dev)
    n_batch, n_nodes, _, n_aa, _ = h_E.shape
    # dense buffer in the input dtype: scatter_ requires src and self dtypes to match
    collection = torch.zeros((n_batch, n_nodes, n_nodes, n_aa, n_aa), dtype=h_E.dtype, device=dev)
    neighbor_idx = E_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, n_aa, n_aa)
    collection.scatter_(2, neighbor_idx, h_E)
    # transpose to get same edge in reverse direction
    collection = collection.transpose(1, 2)
    # transpose each pair energy table as well
    collection = collection.transpose(-2, -1)
    # gather reverse edges
    reverse_E = gather_pairEs(collection, E_idx)
    # average h_E and reverse_E at non-zero positions
    return torch.where(reverse_E != 0, (h_E + reverse_E) / 2, h_E)
# TODO: rigorous test that this is equiv to the dense version
def merge_duplicate_pairE_sparse(h_E, E_idx):
    """ Sparse method to average pair energy tables across bidirectional edges.

    Note: This method involves a significant slowdown so it's only worth using if
    memory is an issue.

    TERMinator edges are represented as two bidirectional edges, and to allow for
    communication between these edges we average the embeddings. In the case for
    pair energies, we transpose the tables to ensure that the pair energy table is
    symmetric upon inverse (e.g. the pair energy between i and j should be the same
    as the pair energy between j and i)

    Each edge ``i -> j`` of protein ``b`` is given the flat key
    ``b * n_res**2 + i * n_res + j``. The forward tables are inserted at their own
    keys and the transposed tables at the reversed key ``j * n_res + i``; the two
    sparse tensors are summed. Each edge's result is the sum at its forward key
    divided by the number of contributions there: 2 when the reverse edge exists
    (or for a self-edge), otherwise 1.

    Args
    ----
    h_E : torch.Tensor
        Pair energies in kNN sparse form
        Shape : n_batch x n_res x k x n_aa x n_aa
    E_idx : torch.LongTensor
        kNN sparse edge indices. The first neighbor slot is used as the node that
        owns the edge (presumably the residue itself).
        Shape : n_batch x n_res x k

    Returns
    -------
    torch.Tensor
        Pair energies with merged energies for bidirectional edges
        Shape : n_batch x n_res x k x n_aa x n_aa
    """
    dev = h_E.device
    n_batch, n_nodes, k, n_aa, _ = h_E.shape
    n_flat = n_batch * n_nodes * n_nodes
    table_size = n_aa * n_aa
    E_idx = E_idx.to(dev)

    # node that owns each edge, broadcast over the neighbor axis
    self_idx = E_idx[:, :, 0:1].expand(-1, -1, k)
    # each protein owns an n_nodes x n_nodes block of keys; a stride of n_nodes * k
    # would let keys of neighbouring proteins collide whenever k < n_nodes
    batch_offset = (torch.arange(n_batch, device=dev) * n_nodes * n_nodes).view(n_batch, 1, 1)
    flat_idx_f = (self_idx * n_nodes + E_idx + batch_offset).reshape(-1)
    flat_idx_r = (E_idx * n_nodes + self_idx + batch_offset).reshape(-1)

    # forward tables, and transposed tables destined for the reverse keys
    flat_h_E_f = h_E.reshape(-1, table_size)
    flat_h_E_r = h_E.transpose(-2, -1).reshape(-1, table_size)
    ones = torch.ones_like(flat_idx_f)

    # merge
    sparse_etab = (torch.sparse_coo_tensor(flat_idx_f.unsqueeze(0), flat_h_E_f, (n_flat, table_size))
                   + torch.sparse_coo_tensor(flat_idx_r.unsqueeze(0), flat_h_E_r, (n_flat, table_size)))
    sparse_etab = sparse_etab.coalesce()
    count = (torch.sparse_coo_tensor(flat_idx_f.unsqueeze(0), ones, (n_flat, ))
             + torch.sparse_coo_tensor(flat_idx_r.unsqueeze(0), ones, (n_flat, )))
    count = count.coalesce()

    # this step is very slow, but implementing something faster is probably a lot of work
    # requires pytorch 1.10 to be fast enough to be usable
    collect = sparse_etab.index_select(0, flat_idx_f).to_dense()
    weight = count.index_select(0, flat_idx_f).to_dense()
    flat_merged_etab = collect / weight.unsqueeze(-1)
    return flat_merged_etab.view(h_E.shape)
def merge_duplicate_pairE_geometric(h_E, edge_index):
    """ Sparse method to average pair energy tables across bidirectional edges with Torch Geometric.

    TERMinator edges are represented as two bidirectional edges, and to allow for
    communication between these edges we average the embeddings. In the case for
    pair energies, we transpose the tables to ensure that the pair energy table is
    symmetric upon inverse (e.g. the pair energy between i and j should be the same
    as the pair energy between j and i)

    Each directed edge ``(a, b)`` is encoded as the integer key ``a + b * num_nodes``.
    A lookup table from the reversed key to the edge id finds every edge's reverse
    partner, so no particular ordering of ``edge_index`` is required. For an edge
    with a reverse partner, the merged table is ``(E + E_rev^T) / 2``. Edges without
    a reverse are left unchanged. Duplicate (identical) edges are not supported.

    Args
    ----
    h_E : torch.Tensor
        Pair energies in Torch Geometric sparse form; each row is a flattened
        n_aa x n_aa table (n_aa = 20 gives 400 columns)
        Shape : n_edge x n_aa**2
    edge_index : torch.LongTensor
        Torch Geometric sparse edge indices
        Shape : 2 x n_edge

    Returns
    -------
    torch.Tensor
        Pair energies with merged energies for bidirectional edges. This is ``h_E``
        itself, modified in place.
        Shape : n_edge x n_aa**2
    """
    dev = h_E.device
    n_edges = edge_index.shape[1]
    if n_edges == 0:
        # nothing to merge (and .max() below would fail on an empty tensor)
        return h_E
    # table side length from the row width; .view below raises if it is not square
    n_aa = int(round(h_E.shape[-1] ** 0.5))
    # keep all index bookkeeping on the feature device
    edge_index = edge_index.to(dev)
    num_nodes = edge_index.max() + 1
    row_idx = edge_index[0] + edge_index[1] * num_nodes
    col_idx = edge_index[1] + edge_index[0] * num_nodes
    table_size = int(max(row_idx.max(), col_idx.max())) + 1
    # mapping[key] = id of the edge whose reversed key is `key`, or -1 if none
    mapping = torch.full((table_size,), -1, dtype=torch.long, device=dev)
    mapping[col_idx] = torch.arange(n_edges, device=dev)
    reverse_idx = mapping[row_idx]
    mask = reverse_idx >= 0
    reverse_idx = reverse_idx[mask]
    reverse_h_E = h_E[mask]
    # transpose each partner table so (i, j) energies line up with (j, i)
    transpose_h_E = reverse_h_E.view([-1, n_aa, n_aa]).transpose(-1, -2).reshape([-1, n_aa * n_aa])
    # RHS is fully materialised before assignment, so both directions use original values
    h_E[reverse_idx] = (h_E[reverse_idx] + transpose_h_E) / 2
    return h_E
# edge aggregation fns
def aggregate_edges(edge_embeddings, E_idx, max_seq_len):
    """ Aggregate TERM edge embeddings into a sequence-level dense edge features tensor

    Every TERM edge ``(self_idx, neighbor_idx)`` is accumulated into a dense
    ``max_seq_len x max_seq_len`` grid per protein. The owning residue of an edge is
    read from the first neighbor slot of ``E_idx``. Each grid cell is then divided by
    the number of TERM edges that landed in it (cells with no edges stay 0).

    Args
    ----
    edge_embeddings : torch.Tensor
        TERM edge features tensor
        Shape : n_batch x n_terms x n_aa x n_neighbors x n_hidden
    E_idx : torch.LongTensor
        TERM edge indices
        Shape : n_batch x n_terms x n_aa x n_neighbors
    max_seq_len : int
        Max length of a sequence in the batch

    Returns
    -------
    torch.Tensor
        Dense sequence-level edge features, in the dtype of ``edge_embeddings``
        Shape : n_batch x max_seq_len x max_seq_len x n_hidden
    """
    dev = edge_embeddings.device
    dtype = edge_embeddings.dtype
    n_batch, _, _, n_neighbors, hidden_dim = edge_embeddings.shape
    E_idx = E_idx.to(dev)
    # accumulation buffer in the input dtype: index_put_ requires matching dtypes
    collection = torch.zeros((n_batch, max_seq_len, max_seq_len, hidden_dim), dtype=dtype, device=dev)
    # (self, neighbor) residue pair of every TERM edge
    self_idx = E_idx[:, :, :, 0].unsqueeze(-1).expand(-1, -1, -1, n_neighbors)
    neighbor_idx = E_idx
    # batch index broadcast to the edge shape, needed for the 3-way index_put_
    layer = torch.arange(n_batch, device=dev).view([n_batch, 1, 1, 1]).expand(neighbor_idx.shape)
    collection.index_put_((layer, self_idx, neighbor_idx), edge_embeddings, accumulate=True)
    # per-cell counts for averaging (kept in float32 so small integers stay exact)
    count = torch.zeros((n_batch, max_seq_len, max_seq_len), device=dev)
    count_idx = torch.ones_like(neighbor_idx, dtype=count.dtype)
    count.index_put_((layer, self_idx, neighbor_idx), count_idx, accumulate=True)
    # we need to set all 0s to 1s so we dont get nans
    count[count == 0] = 1
    return (collection / count.unsqueeze(-1)).to(dtype)