Source code for terminator.models.layers.energies.s2s

""" GNN Potts Model Encoder modules

This file contains the GNN Potts Model Encoder, as well as an ablated version of
that encoder. """
from __future__ import print_function

import torch
from torch import nn

from terminator.models.layers.graph_features import MultiChainProteinFeatures
from terminator.models.layers.s2s_modules import (EdgeMPNNLayer, EdgeTransformerLayer, NodeMPNNLayer,
                                                  NodeTransformerLayer)
from terminator.models.layers.utils import (cat_edge_endpoints, cat_neighbors_nodes, gather_edges, gather_nodes,
                                            merge_duplicate_pairE)

# pylint: disable=no-member, not-callable


class AblatedPairEnergies(nn.Module):
    """Ablated GNN Potts Model Encoder

    Builds a kNN energy table directly from the incoming TERM embeddings. The backbone
    featurizer is only used to obtain the neighbor index `E_idx`; its node and edge
    features are discarded and no graph update layers are applied.

    Attributes
    ----------
    hparams: dict
        Dictionary of parameter settings (see :code:`terminator/utils/model/default_hparams.py`)
    features : MultiChainProteinFeatures
        Module that featurizes a protein backbone (including multimeric proteins)
    W : nn.Linear
        Output layer that projects edge embeddings to proper output dimensionality
    """
    def __init__(self, hparams):
        """ Graph labeling network

        Args
        ----
        hparams : dict
            Parameter settings. The keys read here are :code:`energies_hidden_dim`,
            :code:`k_neighbors`, :code:`energies_protein_features`, :code:`energies_augment_eps`,
            :code:`energies_dropout`, :code:`energies_input_dim` and :code:`energies_output_dim`.
        """
        super().__init__()
        hdim = hparams['energies_hidden_dim']
        self.hparams = hparams

        # Featurization layers
        self.features = MultiChainProteinFeatures(node_features=hdim,
                                                  edge_features=hdim,
                                                  top_k=hparams['k_neighbors'],
                                                  features_type=hparams['energies_protein_features'],
                                                  augment_eps=hparams['energies_augment_eps'],
                                                  dropout=hparams['energies_dropout'])

        # Input width is 3x the TERM embedding width -- presumably because cat_edge_endpoints
        # joins each edge embedding with the embeddings of its two endpoint nodes (TODO confirm).
        self.W = nn.Linear(hparams['energies_input_dim'] * 3, hparams['energies_output_dim'])

    def forward(self, V_embed, E_embed, X, x_mask, chain_idx):
        """ Create kNN etab from TERM features, then project to proper output dimensionality.

        Args
        ----
        V_embed : torch.Tensor
            TERM node embeddings
            Shape: n_batch x n_res x n_hidden
        E_embed : torch.Tensor
            TERM edge embeddings
            Shape : n_batch x n_res x n_res x n_hidden
        X : torch.Tensor
            Backbone coordinates
            Shape: n_batch x n_res x 4 x 3
        x_mask : torch.ByteTensor
            Mask for X.
            Shape: n_batch x n_res
        chain_idx : torch.LongTensor
            Indices such that each chain is assigned a unique integer and each residue in that chain
            is assigned that integer.
            Shape: n_batch x n_res

        Returns
        -------
        etab : torch.Tensor
            Energy table in kNN dense form
            Shape: n_batch x n_res x k x n_out, where n_out is :code:`hparams['energies_output_dim']`
            (must equal 400, since the table is reshaped to 20 x 20 internally)
        E_idx : torch.LongTensor
            Edge index for `etab`
            Shape: n_batch x n_res x k
        """
        # compute the kNN graph; notably, we throw away the backbone features
        _, _, E_idx = self.features(X, chain_idx, x_mask)
        E_embed_neighbors = gather_edges(E_embed, E_idx)
        h_E = cat_edge_endpoints(E_embed_neighbors, V_embed, E_idx)
        etab = self.W(h_E)

        n_batch, n_res, k, out_dim = etab.shape
        # ensure output etab is masked properly
        etab = etab * x_mask.view(n_batch, n_res, 1, 1)
        # unflatten each pair energy into a 20 x 20 matrix (presumably one row/column per amino acid)
        etab = etab.view(n_batch, n_res, k, 20, 20)
        # zero off-diagonal energies of neighbor slot 0; the identity is allocated directly on
        # etab's device instead of being built on the CPU and copied over on every call
        diag_only = torch.eye(20, device=etab.device, dtype=etab.dtype)
        etab[:, :, 0] = etab[:, :, 0] * diag_only
        # merge duplicate pairEs
        etab = merge_duplicate_pairE(etab, E_idx)
        # flatten back to the kNN output format
        etab = etab.view(n_batch, n_res, k, out_dim)
        return etab, E_idx
class PairEnergies(nn.Module):
    """GNN Potts Model Encoder

    Fuses backbone features with (optional) TERM embeddings, runs alternating edge and node
    graph updates over the kNN graph, and projects the final edge embeddings to an energy table.

    Attributes
    ----------
    hparams: dict
        Dictionary of parameter settings (see :code:`terminator/utils/model/default_hparams.py`)
    node_features, edge_features, input_dim : int
        All set to :code:`hparams['energies_hidden_dim']`
    features : MultiChainProteinFeatures
        Module that featurizes a protein backbone (including multimeric proteins)
    W_v : nn.Linear
        Embedding layer for incoming TERM node embeddings
    W_e : nn.Linear
        Embedding layer for incoming TERM edge embeddings
    edge_encoder : nn.ModuleList of EdgeTransformerLayer or EdgeMPNNLayer
        Edge graph update layers
    node_encoder : nn.ModuleList of NodeTransformerLayer or NodeMPNNLayer
        Node graph update layers
    W_out : nn.Linear
        Output layer that projects edge embeddings to proper output dimensionality
    W_proj : nn.Linear (optional)
        Output layer that projects node embeddings to proper output dimensionality.
        Enabled when :code:`hparams["node_self_sub"]=True`
    """
    def __init__(self, hparams):
        """ Graph labeling network

        Args
        ----
        hparams : dict
            Parameter settings. The keys read here are :code:`energies_hidden_dim`,
            :code:`energies_output_dim`, :code:`energies_dropout`, :code:`energies_encoder_layers`,
            :code:`k_neighbors`, :code:`energies_protein_features`, :code:`energies_augment_eps`,
            :code:`energies_input_dim`, :code:`energies_use_mpnn` and, optionally,
            :code:`node_self_sub`.
        """
        super().__init__()
        self.hparams = hparams

        hdim = hparams['energies_hidden_dim']

        # Hyperparameters
        self.node_features = hdim
        self.edge_features = hdim
        self.input_dim = hdim
        hidden_dim = hdim
        output_dim = hparams['energies_output_dim']
        dropout = hparams['energies_dropout']
        num_encoder_layers = hparams['energies_encoder_layers']

        # NOTE: submodules are created in a fixed order so that parameter registration (and thus
        # state_dict layout and RNG consumption during init) stays stable.

        # Featurization layers
        self.features = MultiChainProteinFeatures(node_features=hdim,
                                                  edge_features=hdim,
                                                  top_k=hparams['k_neighbors'],
                                                  features_type=hparams['energies_protein_features'],
                                                  augment_eps=hparams['energies_augment_eps'],
                                                  dropout=hparams['energies_dropout'])

        # Embedding layers: input is backbone features concatenated with TERM embeddings
        # (TERM width is 0 when no TERM information is used)
        self.W_v = nn.Linear(hdim + hparams['energies_input_dim'], hdim, bias=True)
        self.W_e = nn.Linear(hdim + hparams['energies_input_dim'], hdim, bias=True)

        # choose between transformer-style and MPNN-style graph update layers
        if hparams['energies_use_mpnn']:
            edge_layer, node_layer = EdgeMPNNLayer, NodeMPNNLayer
        else:
            edge_layer, node_layer = EdgeTransformerLayer, NodeTransformerLayer

        # Encoder layers
        self.edge_encoder = nn.ModuleList(
            [edge_layer(hidden_dim, hidden_dim * 3, dropout=dropout) for _ in range(num_encoder_layers)])
        self.node_encoder = nn.ModuleList(
            [node_layer(hidden_dim, hidden_dim * 2, dropout=dropout) for _ in range(num_encoder_layers)])

        # if enabled, generate self energies in etab from node embeddings
        if hparams.get("node_self_sub") is True:
            self.W_proj = nn.Linear(hidden_dim, 20)

        # project edges to proper output dimensionality
        self.W_out = nn.Linear(hidden_dim, output_dim, bias=True)

        # Initialization: Xavier-uniform for every weight matrix (biases are left as-is)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, V_embed, E_embed, X, x_mask, chain_idx):
        """ Create kNN etab from backbone and TERM features, then project to proper output dimensionality.

        Args
        ----
        V_embed : torch.Tensor or None
            TERM node embeddings. None only accepted if :code:`hparams['energies_input_dim']=0`.
            Shape: n_batch x n_res x n_hidden
        E_embed : torch.Tensor or None
            TERM edge embeddings. None only accepted if :code:`hparams['energies_input_dim']=0`.
            Shape : n_batch x n_res x n_res x n_hidden
        X : torch.Tensor
            Backbone coordinates
            Shape: n_batch x n_res x 4 x 3
        x_mask : torch.ByteTensor
            Mask for X.
            Shape: n_batch x n_res
        chain_idx : torch.LongTensor
            Indices such that each chain is assigned a unique integer and each residue in that chain
            is assigned that integer.
            Shape: n_batch x n_res

        Returns
        -------
        etab : torch.Tensor
            Energy table in kNN dense form
            Shape: n_batch x n_res x k x n_out, where n_out is :code:`hparams['energies_output_dim']`
            (must equal 400, since the table is reshaped to 20 x 20 internally)
        E_idx : torch.LongTensor
            Edge index for `etab`
            Shape: n_batch x n_res x k
        """
        # Prepare node and edge embeddings; the backbone is always featurized since it
        # provides the kNN edge index
        V, E, E_idx = self.features(X, chain_idx, x_mask)
        if self.hparams['energies_input_dim'] != 0:
            # 'use_coords' is only consulted when TERM embeddings are in use
            if not self.hparams['use_coords']:
                # ablate backbone features by zeroing them (the featurizer is still run for E_idx)
                V = torch.zeros_like(V)
                E = torch.zeros_like(E)
            # fuse backbone and TERM embeddings
            h_V = self.W_v(torch.cat([V, V_embed], dim=-1))
            E_embed_neighbors = gather_edges(E_embed, E_idx)
            h_E = self.W_e(torch.cat([E, E_embed_neighbors], dim=-1))
        else:
            # just use backbone features
            h_V = self.W_v(V)
            h_E = self.W_e(E)

        # Graph updates: an edge may be attended to only if both of its endpoints are unmasked
        neighbor_mask = gather_nodes(x_mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = x_mask.unsqueeze(-1) * neighbor_mask
        for edge_layer, node_layer in zip(self.edge_encoder, self.node_encoder):
            h_EV_edges = cat_edge_endpoints(h_E, h_V, E_idx)
            h_E = edge_layer(h_E, h_EV_edges, E_idx, mask_E=x_mask, mask_attend=mask_attend)

            h_EV_nodes = cat_neighbors_nodes(h_V, h_E, E_idx)
            h_V = node_layer(h_V, h_EV_nodes, mask_V=x_mask, mask_attend=mask_attend)

        # project to output and merge duplicate pairEs
        h_E = self.W_out(h_E)
        n_batch, n_res, k, out_dim = h_E.shape
        # ensure output etab is masked properly
        h_E = h_E * x_mask.view(n_batch, n_res, 1, 1)
        # unflatten each pair energy into a 20 x 20 matrix (presumably one row/column per amino acid)
        h_E = h_E.view(n_batch, n_res, k, 20, 20)
        # zero off-diagonal energies of neighbor slot 0; the identity is allocated directly on
        # h_E's device instead of being built on the CPU and copied over on every call
        diag_only = torch.eye(20, device=h_E.device, dtype=h_E.dtype)
        h_E[:, :, 0] = h_E[:, :, 0] * diag_only
        h_E = merge_duplicate_pairE(h_E, E_idx)

        # if specified, generate self energies from node embeddings: the projected node
        # embedding is placed on the diagonal of neighbor slot 0, replacing what was there
        if self.hparams.get("node_self_sub") is True:
            self_energies = self.W_proj(h_V)
            h_E[..., 0, :, :] = torch.diag_embed(self_energies, dim1=-2, dim2=-1)

        # reshape to fit kNN output format
        h_E = h_E.view(n_batch, n_res, k, out_dim)
        return h_E, E_idx