terminator.models.layers.utils

Utility functions for TERMinator modules.

Functions

aggregate_edges(edge_embeddings, E_idx, ...)

Aggregate TERM edge embeddings into a sequence-level dense edge feature tensor.

batchify(batched_flat_terms, term_lens)

Take a flat representation of TERM information and batch it into a stacked representation.
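A minimal sketch of what batchify does, assuming (this page does not confirm it) that batched_flat_terms has shape [B, N, H] with each structure's TERM residues laid end to end along N, and that term_lens[b] lists the TERM lengths for structure b. The hypothetical batchify_sketch splits each row by those lengths and zero-pads the result into a [B, T_max, L_max, H] tensor:

```python
import torch


def batchify_sketch(batched_flat_terms, term_lens):
    """Hypothetical sketch: [B, N, H] flat TERMs -> [B, T_max, L_max, H], zero-padded."""
    n_batch, _, hidden = batched_flat_terms.shape
    t_max = max(len(lens) for lens in term_lens)  # most TERMs in any structure
    l_max = max(max(lens) for lens in term_lens)  # longest single TERM
    out = batched_flat_terms.new_zeros(n_batch, t_max, l_max, hidden)
    for b, lens in enumerate(term_lens):
        start = 0
        for t, length in enumerate(lens):
            # copy this TERM's residues into its own slot; the remainder stays zero
            out[b, t, :length] = batched_flat_terms[b, start:start + length]
            start += length
    return out


flat = torch.arange(10, dtype=torch.float).view(2, 5, 1)
stacked = batchify_sketch(flat, [[2, 3], [1]])
print(stacked.shape)  # torch.Size([2, 2, 3, 1])
```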

cat_edge_endpoints(h_edges, h_nodes, E_idx)

Concatenate the features of both endpoint nodes onto the ends of gathered edge features, given kNN sparse edge indices.
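A sketch of the idea, assuming h_edges already holds gathered kNN edge features [B, N, K, C_e], h_nodes is [B, N, C], and E_idx is [B, N, K]. The source node i is broadcast across its K edges, the target node j is gathered through E_idx, and both are concatenated with the edge features. The name cat_edge_endpoints_sketch and the [h_i, h_j, h_edges] ordering are assumptions, not necessarily what TERMinator uses:

```python
import torch


def cat_edge_endpoints_sketch(h_edges, h_nodes, E_idx):
    """Hypothetical sketch: returns [B, N, K, 2*C + C_e]."""
    n_batch, n_nodes, k = E_idx.shape
    # source endpoint i: repeat each node's features along its K edges
    h_i = h_nodes.unsqueeze(2).expand(-1, -1, k, -1)
    # target endpoint j: gather node features at the neighbor indices
    flat_idx = E_idx.reshape(n_batch, n_nodes * k, 1).expand(-1, -1, h_nodes.size(-1))
    h_j = torch.gather(h_nodes, 1, flat_idx).view(n_batch, n_nodes, k, -1)
    return torch.cat([h_i, h_j, h_edges], dim=-1)


h_nodes = torch.arange(6, dtype=torch.float).view(1, 3, 2)  # node i -> [2i, 2i + 1]
E_idx = torch.tensor([[[0, 2], [1, 0], [2, 1]]])
h_edges = torch.zeros(1, 3, 2, 4)
out = cat_edge_endpoints_sketch(h_edges, h_nodes, E_idx)
print(out.shape)  # torch.Size([1, 3, 2, 8])
```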

cat_neighbors_nodes(h_nodes, h_neighbors, E_idx)

Concatenate the features of neighboring nodes onto the ends of gathered edge features, given kNN sparse edge indices.
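A sketch, assuming h_nodes is [B, N, C], h_neighbors holds gathered edge features [B, N, K, C_e], and E_idx is [B, N, K]: every edge gets the features of its neighbor node j appended. The function name and the [h_neighbors, h_j] ordering are assumptions:

```python
import torch


def cat_neighbors_nodes_sketch(h_nodes, h_neighbors, E_idx):
    """Hypothetical sketch: returns [B, N, K, C_e + C]."""
    n_batch, n_nodes, k = E_idx.shape
    flat_idx = E_idx.reshape(n_batch, n_nodes * k, 1).expand(-1, -1, h_nodes.size(-1))
    # features of each neighbor node j, laid out like the kNN edges
    h_j = torch.gather(h_nodes, 1, flat_idx).view(n_batch, n_nodes, k, -1)
    return torch.cat([h_neighbors, h_j], dim=-1)


h_nodes = torch.tensor([[[10.0], [20.0], [30.0]]])  # [1, 3, 1]
E_idx = torch.tensor([[[1, 2], [0, 2], [0, 1]]])
h_neighbors = torch.ones(1, 3, 2, 2)
out = cat_neighbors_nodes_sketch(h_nodes, h_neighbors, E_idx)
print(out[0, :, :, -1])  # neighbor-node feature appended to each edge
```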

cat_term_edge_endpoints(h_edges, h_nodes, E_idx)

Concatenate the features of both endpoint nodes onto the ends of gathered TERM edge features, given kNN sparse edge indices.

cat_term_neighbors_nodes(h_nodes, ...)

Concatenate the features of neighboring nodes onto the ends of gathered TERM edge features, given kNN sparse edge indices.

gather_edges(edges, neighbor_idx)

Gather the edge features of the nearest neighbors.
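A minimal sketch, assuming dense edge features of shape [B, N, N, C] and kNN indices neighbor_idx of shape [B, N, K]: torch.gather along dim 2 picks out each node's K neighbor columns. gather_edges_sketch is a hypothetical name:

```python
import torch


def gather_edges_sketch(edges, neighbor_idx):
    """Hypothetical sketch: [B, N, N, C] at [B, N, K] -> [B, N, K, C]."""
    # repeat each index across the channel dim so gather selects whole feature vectors
    idx = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1))
    return torch.gather(edges, 2, idx)


edges = torch.arange(18, dtype=torch.float).view(1, 3, 3, 2)  # edges[0, i, j] = [6i + 2j, 6i + 2j + 1]
neighbor_idx = torch.tensor([[[2, 0], [1, 2], [0, 1]]])
out = gather_edges_sketch(edges, neighbor_idx)
print(out[0, 0, 0])  # features of edge 0 -> 2
```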

gather_nodes(nodes, neighbor_idx)

Gather node features of nearest neighbors.
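A minimal sketch, assuming node features [B, N, C] and kNN indices [B, N, K]. Flattening the indices to [B, N*K] lets a single torch.gather along the node dimension do the lookup, after which the result is reshaped back to one row per edge. The name is hypothetical:

```python
import torch


def gather_nodes_sketch(nodes, neighbor_idx):
    """Hypothetical sketch: [B, N, C] at [B, N, K] -> [B, N, K, C]."""
    n_batch, n_nodes, k = neighbor_idx.shape
    # flatten the K neighbors of every node into one index list per batch element
    flat_idx = neighbor_idx.reshape(n_batch, n_nodes * k, 1).expand(-1, -1, nodes.size(-1))
    gathered = torch.gather(nodes, 1, flat_idx)  # [B, N*K, C]
    return gathered.view(n_batch, n_nodes, k, -1)


nodes = torch.arange(6, dtype=torch.float).view(1, 3, 2)  # node i -> [2i, 2i + 1]
neighbor_idx = torch.tensor([[[2, 0], [1, 2], [0, 1]]])
out = gather_nodes_sketch(nodes, neighbor_idx)
print(out.shape)  # torch.Size([1, 3, 2, 2])
```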

gather_pairEs(pairEs, neighbor_idx)

Gather the pair energy features of the nearest neighbors.
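A sketch assuming pair energies are stored densely as [B, N, N, A, A], one A×A table per residue pair with A the alphabet size (an assumption about the layout). The only change from gathering ordinary edge features is that each index must be broadcast over both trailing table dims:

```python
import torch


def gather_pairEs_sketch(pairEs, neighbor_idx):
    """Hypothetical sketch: [B, N, N, A, A] at [B, N, K] -> [B, N, K, A, A]."""
    a1, a2 = pairEs.shape[-2:]
    # broadcast each neighbor index over the whole A x A energy table
    idx = neighbor_idx[..., None, None].expand(-1, -1, -1, a1, a2)
    return torch.gather(pairEs, 2, idx)


pairEs = torch.randn(1, 3, 3, 4, 4)
neighbor_idx = torch.tensor([[[1, 2], [0, 2], [1, 0]]])
out = gather_pairEs_sketch(pairEs, neighbor_idx)
print(out.shape)  # torch.Size([1, 3, 2, 4, 4])
```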

gather_term_edges(edges, neighbor_idx)

Gather the TERM edge features of the nearest neighbors.

gather_term_nodes(nodes, neighbor_idx)

Gather TERM node features of nearest neighbors.

merge_duplicate_edges(h_E_update, E_idx)

Average embeddings across bidirectional edges.
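One way to do this, sketched under the assumption that h_E_update is [B, N, K, C] in kNN layout with E_idx [B, N, K]: scatter into a dense [B, N, N, C] tensor, add its transpose over the two node axes, divide by how many directions are present, and gather back. Edges whose reverse is not among the kNN keep their own embedding. The dense detour and the function name are illustrative choices, not necessarily how TERMinator implements it:

```python
import torch


def merge_duplicate_edges_sketch(h_E, E_idx):
    """Hypothetical sketch: h_E is [B, N, K, C], E_idx is [B, N, K]."""
    n_batch, n_nodes, _, channels = h_E.shape
    idx = E_idx.unsqueeze(-1).expand(-1, -1, -1, channels)
    dense = h_E.new_zeros(n_batch, n_nodes, n_nodes, channels).scatter(2, idx, h_E)
    present = h_E.new_zeros(n_batch, n_nodes, n_nodes).scatter(2, E_idx, 1.0)
    # edge (i, j) plus edge (j, i); missing directions contribute zero
    summed = dense + dense.transpose(1, 2)
    count = (present + present.transpose(1, 2)).clamp(min=1).unsqueeze(-1)
    return torch.gather(summed / count, 2, idx)


h_E = torch.tensor([[[[1.0, 1.0], [2.0, 2.0]], [[3.0, 3.0], [4.0, 4.0]]]])
E_idx = torch.tensor([[[0, 1], [1, 0]]])
out = merge_duplicate_edges_sketch(h_E, E_idx)
print(out[0, 0, 1])  # edge 0 -> 1 averaged with edge 1 -> 0
```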

merge_duplicate_edges_geometric(h_E_update, ...)

Average embeddings across bidirectional edges for PyTorch Geometric graphs.
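PyTorch Geometric stores connectivity as an edge_index tensor of shape [2, E] (source row, target row). A deliberately simple sketch, assuming h_E_update is [E, C] and aligned with edge_index: look up each edge's reverse in a Python dict and average the two, leaving edges without a reverse unchanged. The function name and the dict lookup are illustrative; a real implementation would likely vectorize this:

```python
import torch


def merge_duplicate_edges_geometric_sketch(h_E, edge_index):
    """Hypothetical sketch: h_E is [E, C], edge_index is [2, E] (source, target)."""
    pairs = edge_index.t().tolist()
    position = {(src, dst): e for e, (src, dst) in enumerate(pairs)}
    # index of each edge's reverse; fall back to the edge itself if no reverse exists
    reverse = torch.tensor([position.get((dst, src), e) for e, (src, dst) in enumerate(pairs)])
    return (h_E + h_E[reverse]) / 2


edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]])
h_E = torch.tensor([[2.0], [4.0], [6.0]])
out = merge_duplicate_edges_geometric_sketch(h_E, edge_index)
print(out.squeeze(-1).tolist())  # [3.0, 3.0, 6.0]
```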

merge_duplicate_pairE(h_E, E_idx)

Average pair energy tables across bidirectional edges.

merge_duplicate_pairE_dense(h_E, E_idx)

Dense method to average pair energy tables across bidirectional edges.
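Averaging pair energy tables differs from averaging embeddings in one detail: the table for edge (i, j) is indexed [residue at i, residue at j], so the reverse edge's table must be transposed before it is added. A sketch, assuming h_E is [B, N, K, A, A] in kNN layout and E_idx is [B, N, K]; the name and the dense [B, N, N, A, A] intermediate are assumptions:

```python
import torch


def merge_duplicate_pairE_dense_sketch(h_E, E_idx):
    """Hypothetical sketch: h_E is [B, N, K, A, A], E_idx is [B, N, K]."""
    n_batch, n_nodes, _, a, _ = h_E.shape
    idx = E_idx[..., None, None].expand(-1, -1, -1, a, a)
    dense = h_E.new_zeros(n_batch, n_nodes, n_nodes, a, a).scatter(2, idx, h_E)
    present = h_E.new_zeros(n_batch, n_nodes, n_nodes).scatter(2, E_idx, 1.0)
    # reverse[b, i, j] = dense[b, j, i]^T, so residue axes line up with edge (i, j)
    reverse = dense.transpose(1, 2).transpose(3, 4)
    count = (present + present.transpose(1, 2)).clamp(min=1)[..., None, None]
    return torch.gather((dense + reverse) / count, 2, idx)


h_E = torch.randn(1, 2, 2, 3, 3)
E_idx = torch.tensor([[[0, 1], [1, 0]]])
out = merge_duplicate_pairE_dense_sketch(h_E, E_idx)
```

After merging, the tables for (i, j) and (j, i) are exact transposes of each other.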

merge_duplicate_pairE_geometric(h_E, edge_index)

Sparse method to average pair energy tables across bidirectional edges with PyTorch Geometric.

merge_duplicate_pairE_sparse(h_E, E_idx)

Sparse method to average pair energy tables across bidirectional edges.

merge_duplicate_term_edges(h_E_update, E_idx)

Average embeddings across bidirectional TERM edges.

pad_sequence_12(sequences[, padding_value])

Given a sequence of tensors, batch them together, padding both dims 1 and 2 to the maximum length.
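A sketch, assuming each input tensor has shape [n_i, m_i, *] with matching trailing dims, so that the padded dims are 1 and 2 of the stacked [B, n_max, m_max, *] output. padding_value mirrors the documented optional argument, but its default of 0.0 here is an assumption:

```python
import torch


def pad_sequence_12_sketch(sequences, padding_value=0.0):
    """Hypothetical sketch: list of [n_i, m_i, *] -> [B, n_max, m_max, *]."""
    n_max = max(s.shape[0] for s in sequences)
    m_max = max(s.shape[1] for s in sequences)
    trailing = sequences[0].shape[2:]
    out = sequences[0].new_full((len(sequences), n_max, m_max, *trailing), padding_value)
    for i, s in enumerate(sequences):
        # copy each tensor into the top-left corner of its padded slot
        out[i, : s.shape[0], : s.shape[1]] = s
    return out


a = torch.ones(2, 3)
b = torch.full((1, 4), 2.0)
padded = pad_sequence_12_sketch([a, b], padding_value=-1.0)
print(padded.shape)  # torch.Size([2, 2, 4])
```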