terminator.data.data.TERMBatchSampler
- class terminator.data.data.TERMBatchSampler(dataset, batch_size=4, sort_data=False, shuffle=True, semi_shuffle=False, semi_shuffle_cluster_size=500, batch_shuffle=True, drop_last=False, max_term_res=55000, max_seq_tokens=None)
Bases: Sampler
BatchSampler/DataLoader helper class for TERM data using TERMDataset.
- Variables:
size (int) – Length of the dataset
dataset (List) – List of features from TERM dataset
total_term_lengths (List) – List of TERM lengths from the given dataset
seq_lengths (List) – List of sequence lengths from the given dataset
lengths (List) – TERM lengths or sequence lengths, depending on whether max_term_res or max_seq_tokens is set.
batch_size (int or None, default=4) – Size of batches created. If variable sized batches are desired, set to None.
sort_data (bool, default=False) – Create deterministic batches by sorting the data according to the specified length metric and creating batches from the sorted data. Incompatible with shuffle=True and semi_shuffle=True.
shuffle (bool, default=True) – Shuffle the data completely before creating batches. Incompatible with sort_data=True and semi_shuffle=True.
semi_shuffle (bool, default=False) – Sort the data according to the specified length metric, then partition the data into semi_shuffle_cluster_size-sized partitions. Within each partition perform a complete shuffle. The upside is that batching with similar lengths reduces padding, making for more efficient computation, but the downside is that it does a less complete shuffle.
semi_shuffle_cluster_size (int, default=500) – Size of partition to use when semi_shuffle=True.
batch_shuffle (bool, default=True) – If set to True, shuffle samples within a batch.
drop_last (bool, default=False) – If set to True, drop the last samples if they don’t form a complete batch.
max_term_res (int or None, default=55000) – When batch_size=None, max_term_res>0, max_seq_tokens=None, batch by fitting as many datapoints as possible with the total number of TERM residues included below max_term_res. Calibrated using nn.DataParallel on two V100 GPUs.
max_seq_tokens (int or None, default=None) – When batch_size=None, max_term_res=None, max_seq_tokens>0, batch by fitting as many datapoints as possible with the total number of sequence residues included below max_seq_tokens.
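The variable-size batching described for max_term_res and max_seq_tokens can be pictured with a small greedy sketch. This is not the library's code: the function name greedy_budget_batches and the use of a plain sum of lengths as the cost are assumptions made for illustration, and the real sampler may measure a batch differently (for example by its padded size).

```python
from typing import List, Sequence


def greedy_budget_batches(lengths: Sequence[int], budget: int) -> List[List[int]]:
    """Group indices into consecutive batches whose summed lengths fit `budget`.

    Hypothetical helper for illustration only; not part of terminator.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    current_total = 0
    for idx, length in enumerate(lengths):
        # Close the current batch when adding this item would exceed the budget.
        if current and current_total + length > budget:
            batches.append(current)
            current, current_total = [], 0
        # An item larger than the budget still gets a batch of its own.
        current.append(idx)
        current_total += length
    if current:
        batches.append(current)
    return batches


term_lengths = [30000, 20000, 10000, 50000, 4000]
print(greedy_budget_batches(term_lengths, budget=55000))
```

Because the sketch walks the indices in order, sorting or semi-shuffling the data first changes which datapoints end up sharing a batch.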
- __init__(dataset, batch_size=4, sort_data=False, shuffle=True, semi_shuffle=False, semi_shuffle_cluster_size=500, batch_shuffle=True, drop_last=False, max_term_res=55000, max_seq_tokens=None)
Reads in and processes a given dataset.
Given the provided dataset, load all the data. Then cluster the data using the provided method, either shuffled or sorted and then shuffled.
- Parameters:
dataset (TERMDataset) – Dataset to batch.
batch_size (int or None, default=4) – Size of batches created. If variable sized batches are desired, set to None.
sort_data (bool, default=False) – Create deterministic batches by sorting the data according to the specified length metric and creating batches from the sorted data. Incompatible with shuffle=True and semi_shuffle=True.
shuffle (bool, default=True) – Shuffle the data completely before creating batches. Incompatible with sort_data=True and semi_shuffle=True.
semi_shuffle (bool, default=False) – Sort the data according to the specified length metric, then partition the data into semi_shuffle_cluster_size-sized partitions. Within each partition perform a complete shuffle. The upside is that batching with similar lengths reduces padding, making for more efficient computation, but the downside is that it does a less complete shuffle.
semi_shuffle_cluster_size (int, default=500) – Size of partition to use when semi_shuffle=True.
batch_shuffle (bool, default=True) – If set to True, shuffle samples within a batch.
drop_last (bool, default=False) – If set to True, drop the last samples if they don’t form a complete batch.
max_term_res (int or None, default=55000) – When batch_size=None, max_term_res>0, max_seq_tokens=None, batch by fitting as many datapoints as possible with the total number of TERM residues included below max_term_res. Calibrated using nn.DataParallel on two V100 GPUs.
max_seq_tokens (int or None, default=None) – When batch_size=None, max_term_res=None, max_seq_tokens>0, batch by fitting as many datapoints as possible with the total number of sequence residues included below max_seq_tokens. Exactly one of max_term_res and max_seq_tokens must be None.
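The compatibility rules stated in the parameter list can be summarized as a small check. The helper below is hypothetical and not part of terminator; it only encodes the constraints as documented here (at most one of sort_data, shuffle, and semi_shuffle enabled, and exactly one of max_term_res and max_seq_tokens set to None). Whether the library enforces them in this way is an assumption.

```python
def check_sampler_options(sort_data=False, shuffle=True, semi_shuffle=False,
                          max_term_res=55000, max_seq_tokens=None):
    """Raise ValueError if the documented option constraints are violated.

    Hypothetical helper for illustration only; defaults mirror the signature above.
    """
    # sort_data, shuffle, and semi_shuffle are documented as mutually incompatible.
    if sum([bool(sort_data), bool(shuffle), bool(semi_shuffle)]) > 1:
        raise ValueError("enable at most one of sort_data, shuffle, semi_shuffle")
    # Exactly one of the two length budgets must be None.
    if (max_term_res is None) == (max_seq_tokens is None):
        raise ValueError("exactly one of max_term_res and max_seq_tokens must be None")


check_sampler_options()  # the defaults satisfy both constraints
check_sampler_options(shuffle=False, semi_shuffle=True)
```

Note that with the default shuffle=True, turning on sort_data or semi_shuffle also requires passing shuffle=False.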
Methods
__init__(dataset[, batch_size, sort_data, ...]) – Reads in and processes a given dataset.
package(b_idx) – Package the given datapoints into tensors based on provided indices.
- _cluster()
Shuffle data and make clusters of indices corresponding to batches of data.
This method speeds up training by grouping data points with similar TERM lengths together, if sort_data or semi_shuffle is on. Under sort_data, the data is sorted by length. Under semi_shuffle, the data is sorted by length, broken up into clusters, and shuffled within the clusters. Otherwise, it is randomly shuffled. Data is then loaded into batches based on the number of proteins that will fit into the GPU without overloading it, based on max_term_res or max_seq_tokens.
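The semi-shuffle ordering described above can be sketched in a few lines. This is an illustration under stated assumptions rather than the library's implementation: the function name semi_shuffle_indices and the seed argument are hypothetical, and cluster_size plays the role of semi_shuffle_cluster_size.

```python
import random
from typing import List, Optional, Sequence


def semi_shuffle_indices(lengths: Sequence[int], cluster_size: int,
                         seed: Optional[int] = None) -> List[int]:
    """Sort indices by length, then shuffle only within fixed-size clusters.

    Hypothetical helper for illustration only; not part of terminator.
    """
    rng = random.Random(seed)
    # Indices ordered from shortest to longest datapoint.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    result: List[int] = []
    for start in range(0, len(order), cluster_size):
        cluster = order[start:start + cluster_size]
        rng.shuffle(cluster)  # randomize order inside this cluster only
        result.extend(cluster)
    return result


lengths = [50, 10, 40, 20, 30, 60]
print(semi_shuffle_indices(lengths, cluster_size=3, seed=0))
```

Datapoints never move between clusters, so neighbours in the resulting order have similar lengths while the order inside each cluster still varies from epoch to epoch.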
- package(b_idx)
Package the given datapoints into tensors based on provided indices.
Tensors are extracted from the data and padded. Coordinates are featurized and the length of TERMs and chain IDs are added to the data.
- Parameters:
b_idx (list of tuples (dicts, int, int)) – The feature dictionaries, the sum of the lengths of all TERMs, and the sum of all sequence lengths for each datapoint to package.
- Returns:
Collection of batched features required for running TERMinator. This contains:
msas - the sequences for each TERM match to the target structure
features - the \(\phi, \psi, \omega\), and environment values of the TERM matches
ppoe - the \(\phi, \psi, \omega\), and environment values of the target structure
seq_lens - lengths of the target sequences
focuses - the corresponding target structure residue index for each TERM residue
contact_idxs - contact indices for each TERM residue
src_key_mask - mask for TERM residue padding
X - coordinates
x_mask - mask for the target structure
seqs - the target sequences
ids - the PDB IDs
chain_idx - the chain IDs
- Return type:
dict
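As an illustration of the padding step that package performs, the sketch below pads variable-length sequences to a common length and builds a matching mask. It uses plain Python lists instead of tensors, and the helper name pad_with_mask, the pad value, and the mask convention (1 for real entries, 0 for padding) are all assumptions; the convention the library uses for src_key_mask and x_mask is not stated on this page.

```python
from typing import List, Sequence, Tuple


def pad_with_mask(seqs: Sequence[Sequence[int]],
                  pad_value: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad each sequence to the longest length and return (padded, mask).

    Hypothetical helper for illustration only; not part of terminator.
    """
    max_len = max((len(s) for s in seqs), default=0)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    # Assumed convention: 1 marks a real entry, 0 marks padding.
    mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in seqs]
    return padded, mask


padded, mask = pad_with_mask([[7, 8, 9], [4], [5, 6]])
print(padded)
print(mask)
```

The amount of padding in a batch is set by its longest member, which is why grouping datapoints of similar length reduces wasted computation.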