terminator.data.data.TERMBatchSampler

class terminator.data.data.TERMBatchSampler(dataset, batch_size=4, sort_data=False, shuffle=True, semi_shuffle=False, semi_shuffle_cluster_size=500, batch_shuffle=True, drop_last=False, max_term_res=55000, max_seq_tokens=None)

Bases: Sampler

BatchSampler/DataLoader helper class for TERM data loaded from a TERMDataset.

Variables:
  • size (int) – Length of the dataset

  • dataset (List) – List of features from TERM dataset

  • total_term_lengths (List) – List of TERM lengths from the given dataset

  • seq_lengths (List) – List of sequence lengths from the given dataset

  • lengths (List) – TERM lengths or sequence lengths, depending on whether max_term_res or max_seq_tokens is set.

  • batch_size (int or None, default=4) – Size of the batches created. Set to None for variable-sized batches.

  • sort_data (bool, default=False) – Create deterministic batches by sorting the data according to the specified length metric and creating batches from the sorted data. Incompatible with shuffle=True and semi_shuffle=True.

  • shuffle (bool, default=True) – Shuffle the data completely before creating batches. Incompatible with sort_data=True and semi_shuffle=True.

  • semi_shuffle (bool, default=False) – Sort the data according to the specified length metric, split the sorted data into semi_shuffle_cluster_size-sized partitions, and shuffle completely within each partition. Batching datapoints of similar length reduces padding and makes computation more efficient, at the cost of a less thorough shuffle.

  • semi_shuffle_cluster_size (int, default=500) – Size of partition to use when semi_shuffle=True.

  • batch_shuffle (bool, default=True) – If set to True, shuffle samples within a batch.

  • drop_last (bool, default=False) – If set to True, drop the last samples if they don’t form a complete batch.

  • max_term_res (int or None, default=55000) – When batch_size=None, max_term_res>0, and max_seq_tokens=None, create batches by fitting as many datapoints as possible while keeping the total number of TERM residues in each batch below max_term_res. The default was calibrated using nn.DataParallel on two V100 GPUs.

  • max_seq_tokens (int or None, default=None) – When batch_size=None, max_term_res=None, and max_seq_tokens>0, create batches by fitting as many datapoints as possible while keeping the total number of sequence residues in each batch below max_seq_tokens.
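The budget-based batching described for max_term_res and max_seq_tokens can be pictured as greedy packing over an already-ordered list of datapoints. The sketch below is a minimal illustration under stated assumptions: pack_by_budget is a hypothetical helper, not part of terminator, and both the strict "below the budget" rule and the choice to give an oversized datapoint its own batch are guesses at the behavior rather than the library's code.

```python
from typing import List


def pack_by_budget(order: List[int], lengths: List[int], budget: int) -> List[List[int]]:
    """Greedily cut `order` into batches whose summed lengths stay below `budget`.

    Hypothetical helper for illustration; not part of terminator.
    `lengths[i]` is the TERM (or sequence) length of datapoint i.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    total = 0
    for idx in order:
        length = lengths[idx]
        # Close the current batch if adding this datapoint would reach the budget.
        if current and total + length >= budget:
            batches.append(current)
            current, total = [], 0
        # A datapoint larger than the budget still gets its own batch (assumption).
        current.append(idx)
        total += length
    if current:
        batches.append(current)
    return batches


print(pack_by_budget([0, 1, 2, 3, 4], [30, 20, 40, 10, 50], budget=60))
# [[0, 1], [2, 3], [4]]
```

Because packing follows the given order, sorting or semi-shuffling by length first tends to produce batches whose members have similar lengths, which is where the padding savings come from.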

__init__(dataset, batch_size=4, sort_data=False, shuffle=True, semi_shuffle=False, semi_shuffle_cluster_size=500, batch_shuffle=True, drop_last=False, max_term_res=55000, max_seq_tokens=None)

Reads in and processes a given dataset.

Load all the data from the provided dataset, then cluster it into batches using the selected method: sorted by length (sort_data), sorted and then shuffled within partitions (semi_shuffle), or fully shuffled.

Parameters:
  • dataset (TERMDataset) – Dataset to batch.

  • batch_size (int or None, default=4) – Size of the batches created. Set to None for variable-sized batches.

  • sort_data (bool, default=False) – Create deterministic batches by sorting the data according to the specified length metric and creating batches from the sorted data. Incompatible with shuffle=True and semi_shuffle=True.

  • shuffle (bool, default=True) – Shuffle the data completely before creating batches. Incompatible with sort_data=True and semi_shuffle=True.

  • semi_shuffle (bool, default=False) – Sort the data according to the specified length metric, split the sorted data into semi_shuffle_cluster_size-sized partitions, and shuffle completely within each partition. Batching datapoints of similar length reduces padding and makes computation more efficient, at the cost of a less thorough shuffle.

  • semi_shuffle_cluster_size (int, default=500) – Size of partition to use when semi_shuffle=True.

  • batch_shuffle (bool, default=True) – If set to True, shuffle samples within a batch.

  • drop_last (bool, default=False) – If set to True, drop the last samples if they don’t form a complete batch.

  • max_term_res (int or None, default=55000) – When batch_size=None, max_term_res>0, and max_seq_tokens=None, create batches by fitting as many datapoints as possible while keeping the total number of TERM residues in each batch below max_term_res. The default was calibrated using nn.DataParallel on two V100 GPUs.

  • max_seq_tokens (int or None, default=None) – When batch_size=None, max_term_res=None, and max_seq_tokens>0, create batches by fitting as many datapoints as possible while keeping the total number of sequence residues in each batch below max_seq_tokens. Exactly one of max_term_res and max_seq_tokens must be None.
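The parameter descriptions imply a few constraints between options. The sketch below encodes them as a hypothetical check_sampler_options function; it is not part of terminator, and the real constructor may validate these options differently or not at all. It only restates the documented rules: sort_data, shuffle, and semi_shuffle are mutually incompatible; exactly one of max_term_res and max_seq_tokens is None; and variable-sized batches need a positive budget.

```python
from typing import Optional


def check_sampler_options(
    batch_size: Optional[int] = 4,
    sort_data: bool = False,
    shuffle: bool = True,
    semi_shuffle: bool = False,
    max_term_res: Optional[int] = 55000,
    max_seq_tokens: Optional[int] = None,
) -> str:
    """Hypothetical check of the documented option constraints.

    Returns the ordering mode the documented rules imply. Not the
    library's code: the real TERMBatchSampler may validate differently.
    """
    # sort_data, shuffle and semi_shuffle are documented as mutually incompatible.
    if sum([sort_data, shuffle, semi_shuffle]) > 1:
        raise ValueError("sort_data, shuffle and semi_shuffle are mutually exclusive")
    # Exactly one of max_term_res and max_seq_tokens must be None.
    if (max_term_res is None) == (max_seq_tokens is None):
        raise ValueError("exactly one of max_term_res and max_seq_tokens must be None")
    if batch_size is None:
        # Variable-sized batches need a positive length budget.
        budget = max_term_res if max_term_res is not None else max_seq_tokens
        if budget is None or budget <= 0:
            raise ValueError("the active length budget must be positive")
    if sort_data:
        return "sort"
    if semi_shuffle:
        return "semi_shuffle"
    # Per the _cluster description, data is otherwise shuffled randomly.
    return "shuffle"


print(check_sampler_options(shuffle=False, semi_shuffle=True))  # semi_shuffle
```

Under this reading, because shuffle defaults to True, enabling sort_data or semi_shuffle also requires passing shuffle=False, and switching to max_seq_tokens requires passing max_term_res=None explicitly.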

Methods

__init__(dataset[, batch_size, sort_data, ...])

Reads in and processes a given dataset.

package(b_idx)

Package the given datapoints into tensors based on provided indices.

_cluster()

Shuffle data and make clusters of indices corresponding to batches of data.

If sort_data or semi_shuffle is enabled, this method speeds up training by grouping data points of similar length together. With sort_data, the data is sorted by length. With semi_shuffle, the data is split into clusters by length and shuffled within each cluster. Otherwise, the data is shuffled randomly. The data is then divided into batches; when batch_size=None, each batch holds as many proteins as fit within the max_term_res or max_seq_tokens limit, so that the GPU is not overloaded.
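The three ordering modes described for _cluster can be sketched as follows. order_indices is a hypothetical helper, not the method's actual code; ascending sort order and leaving the partition order unshuffled under semi_shuffle are assumptions, since the description only says the data is shuffled within each cluster.

```python
import random
from typing import List, Optional


def order_indices(
    lengths: List[int],
    sort_data: bool = False,
    semi_shuffle: bool = False,
    cluster_size: int = 500,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Return dataset indices in the order batches would be cut from (hypothetical sketch)."""
    rng = rng or random.Random()
    by_length = sorted(range(len(lengths)), key=lambda i: lengths[i])
    if sort_data:
        # Deterministic: ascending length (assumed), no randomness.
        return by_length
    if semi_shuffle:
        # Cut the sorted indices into cluster_size-sized partitions and shuffle inside each.
        semi: List[int] = []
        for start in range(0, len(by_length), cluster_size):
            part = by_length[start:start + cluster_size]
            rng.shuffle(part)
            semi.extend(part)
        return semi
    # Default: complete random shuffle of all indices.
    full = list(range(len(lengths)))
    rng.shuffle(full)
    return full


print(order_indices([50, 10, 40, 20, 30, 60], sort_data=True))  # [1, 3, 4, 2, 0, 5]
```

The resulting order would then be cut into batches, either every batch_size indices or by a length budget when batch_size=None.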

package(b_idx)

Package the given datapoints into tensors based on provided indices.

Tensors are extracted from the data and padded. Coordinates are featurized, and TERM lengths and chain IDs are added to the output.

Parameters:

b_idx (list of tuples (dict, int, int)) – One tuple per datapoint to package, containing its feature dictionary, the total length of all its TERMs, and its total sequence length.

Returns:

Collection of batched features required for running TERMinator. This contains:

  • msas - the sequences for each TERM match to the target structure

  • features - the \(\phi, \psi, \omega\), and environment values of the TERM matches

  • ppoe - the \(\phi, \psi, \omega\), and environment values of the target structure

  • seq_lens - lengths of the target sequences

  • focuses - for each TERM residue, the index of the corresponding residue in the target structure

  • contact_idxs - contact indices for each TERM residue

  • src_key_mask - mask for TERM residue padding

  • X - coordinates

  • x_mask - mask for the target structure

  • seqs - the target sequences

  • ids - the PDB IDs

  • chain_idx - the chain IDs

Return type:

dict
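Padding variable-length per-protein arrays into a single batch, together with a mask marking real positions, is the core of what package does for fields such as X and x_mask. The NumPy sketch below illustrates the idea only: pad_and_mask is a hypothetical helper, the real method returns PyTorch tensors, and the mask convention here (True marks real, non-padding positions) is an assumption; src_key_mask in particular may use the opposite convention.

```python
from typing import List, Tuple

import numpy as np


def pad_and_mask(seqs: List[np.ndarray], pad_value: float = 0.0) -> Tuple[np.ndarray, np.ndarray]:
    """Stack arrays of shape (L_i, ...) into (B, L_max, ...) plus a (B, L_max) boolean mask."""
    max_len = max(s.shape[0] for s in seqs)
    trailing = seqs[0].shape[1:]
    padded = np.full((len(seqs), max_len) + trailing, pad_value, dtype=seqs[0].dtype)
    mask = np.zeros((len(seqs), max_len), dtype=bool)
    for i, s in enumerate(seqs):
        padded[i, : s.shape[0]] = s
        mask[i, : s.shape[0]] = True  # True marks real (non-padding) positions
    return padded, mask


coords = [np.ones((3, 3)), np.ones((5, 3))]  # two toy "proteins" with 3 and 5 residues
X, x_mask = pad_and_mask(coords)
print(X.shape, x_mask.sum(axis=1).tolist())  # (2, 5, 3) [3, 5]
```

Summing the mask along the residue axis recovers per-protein lengths, which is one way a field like seq_lens relates to the padded batch.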