Source code for scripts.data.postprocessing.to_etab

"""Parse output of TERMinator into :code:`.etab` files for use in MST.

Usage:
    .. code-block::

        python to_etab.py \\
            --output_dir <folder_with_net.out> \\
            --dtermen_data <dtermen_data_root> \\
            --num_cores <num_processes> \\
            [-u]

See :code:`python to_etab.py --help` for more info.
"""
import argparse
import multiprocessing as mp
import os
import pickle
import sys
import time
import traceback

import numpy as np
from tqdm import tqdm

from terminator.utils.common import int_to_3lt_AA

# pylint: disable=wrong-import-position,wrong-import-order,redefined-outer-name,unspecified-encoding

# for autosummary import purposes
sys.path.insert(0, os.path.dirname(__file__))
from search_utils import find_dtermen_folder


# print to stderr
[docs]def eprint(*args, **kwargs): """Print to stderr rather than stdout""" print(*args, file=sys.stderr, **kwargs)
# pylint: disable=broad-except def _to_etab_file_wrapper(etab_matrix, E_idx, idx_dict, out_path): """Wrapper for _to_etab_file that does error handling""" try: return to_etab_file(etab_matrix, E_idx, idx_dict, out_path) except Exception: eprint(out_path) eprint(idx_dict) traceback.print_exc() return False, out_path # should work for multi-chain proteins now
[docs]def to_etab_file(etab_matrix, E_idx, idx_dict, out_path): """Write an :code:`.etab` file based on the fed in matrix and other indexing factors. Args ==== etab_matrix : np.ndarray Etab outputted by TERMinator E_idx : np.ndarray Indexing matrix associated with :code:`etab_matrix` idx_dict : dict Index conversion dictionary outputted by :code:`get_idx_dict` out_path : str Path to write the etab to Returns ======= bool Whether or not the parsing occured without errors out_path : str The output path fed in """ out_file = open(out_path, 'w') # etab matrix: l x k x 20 x 20 self_etab = etab_matrix[:, 0] pair_etab = etab_matrix[:, 1:] E_idx = E_idx[:, 1:] # l x 20 self_nrgs = np.diagonal(self_etab, offset=0, axis1=-2, axis2=-1) for aa_idx, aa_nrgs in enumerate(self_nrgs): # pylint: disable=broad-except try: chain, resid = idx_dict[aa_idx] except Exception: eprint("num residues: ", len(self_nrgs)) eprint(out_path) eprint(idx_dict) traceback.print_exc() return False, out_path for aa_int_id, nrg in enumerate(aa_nrgs): aa_3lt_id = int_to_3lt_AA[aa_int_id] out_file.write('{},{} {} {}\n'.format(chain, resid, aa_3lt_id, nrg)) pair_nrgs = {} # l x k-1 x 20 x 20 for i_idx, nrg_slice in enumerate(pair_etab): for k, k_slice in enumerate(nrg_slice): j_idx = E_idx[i_idx][k] chain_i, i_resid = idx_dict[i_idx] chain_j, j_resid = idx_dict[j_idx] for i, i_slice in enumerate(k_slice): i_3lt_id = int_to_3lt_AA[i] for j, nrg in enumerate(i_slice): j_3lt_id = int_to_3lt_AA[j] # every etab has two entries i, j and j, i # average these nrgs key = [(chain_i, i_resid, i_3lt_id), (chain_j, j_resid, j_3lt_id)] key.sort(key=lambda x: x[1]) key = tuple(key) if key not in pair_nrgs.keys(): pair_nrgs[key] = nrg else: current_nrg = pair_nrgs[key] pair_nrgs[key] = (current_nrg + nrg) / 2 for key, nrg in sorted(pair_nrgs.items(), key=lambda pair: pair[0][0][1]): chain_i, i_resid, i_3lt_id = key[0] chain_j, j_resid, j_3lt_id = key[1] out_file.write('{},{} {},{} {} {} {}\n'.format(chain_i, i_resid, chain_j, j_resid, i_3lt_id, j_3lt_id, nrg)) out_file.close() return True, out_path
[docs]def get_idx_dict(pdb, chain_filter=None): """From a :code:`.red.pdb` file, generate a dictionary mapping indices used within TERMinator to indices used by the :code:`.red.pdb` file. Args ==== pdb : str path to :code:`.red.pdb` file chain_filter : list of str or None only parse chains from :code:`chain_filter`. If :code:`None`, parse all chains Returns ======= idx_dict : dict Dictionary mapping indices used within TERMinator to indices used by the :code:`.red.pdb` file. """ idx_dict = {} with open(pdb, 'r') as fp: current_idx = 0 for line in fp: data = line.strip() if data == 'TER' or data == 'END': continue try: chain = data[21] # residx = int(data[22:26].strip()) # icode = data[26] # if icode != ' ': # residx = str(residx) + icode residx = data[22:27].strip() # rip i didn't know about icodes except Exception as e: print(data) raise e if chain_filter: if chain not in chain_filter: continue if (chain, residx) not in idx_dict.values(): idx_dict[current_idx] = (chain, residx) current_idx += 1 return idx_dict
if __name__ == '__main__': parser = argparse.ArgumentParser('Generate etabs') parser.add_argument('--output_dir', help='output directory', required=True) parser.add_argument("--dtermen_data", help="Root directory for all dTERMen runs", required=True) parser.add_argument('--num_cores', help='number of processes for parallelization', default=1) parser.add_argument('-u', dest='update', help='flag for force updating etabs', default=False, action='store_true') args = parser.parse_args() if not os.path.isdir(os.path.join(args.output_dir, 'etabs')): os.mkdir(os.path.join(args.output_dir, 'etabs')) print("made etabs dir") with open(os.path.join(args.output_dir, 'net.out'), 'rb') as fp: dump = pickle.load(fp) print("loaded dump") pool = mp.Pool(int(args.num_cores)) start = time.time() pbar = tqdm(total=len(dump)) not_worked = [] print("starting etab dump") for data in dump: pdb = data['ids'][0] E_idx = data['idx'][0].copy() etab = data['out'][0].copy() print(pdb) pdb_path = find_dtermen_folder(pdb, args.dtermen_data) idx_dict = get_idx_dict(os.path.join(pdb_path, f'{pdb}.red.pdb')) out_path = os.path.join(args.output_dir, 'etabs/' + pdb + '.etab') if os.path.exists(out_path) and not args.update: print(f"{pdb} already exists, skipping") pbar.update() continue def check_worked(res): """Update progress bar per iteration""" worked, out_path = res pbar.update() if not worked: not_worked.append(out_path) def raise_error(error): """Propogate error upwards""" raise error res = pool.apply_async(_to_etab_file_wrapper, args=(etab, E_idx, idx_dict, out_path), callback=check_worked, error_callback=raise_error) pool.close() pool.join() pbar.close() print(f"errors in {not_worked}") for path in not_worked: os.remove(path) end = time.time() print(f"done, took {end - start} seconds")