"""Datasets and dataloaders for loading TERMs.
This file contains dataset and dataloader classes
to be used when interacting with TERMs.
"""
import glob
import math
import multiprocessing as mp
import os
import pickle
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch_cluster
import torch_geometric
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
# pylint: disable=no-member, not-callable
# Jing featurization functions
def _normalize(tensor, dim=-1):
'''Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.'''
return torch.nan_to_num(torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))
def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
'''Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
That is, if `D` has shape [...dims], then the returned tensor will have
shape [...dims, D_count].
From https://github.com/jingraham/neurips19-graph-protein-design
'''
D_mu = torch.linspace(D_min, D_max, D_count, device=device)
D_mu = D_mu.view([1, -1])
D_sigma = (D_max - D_min) / D_count
D_expand = torch.unsqueeze(D, -1)
rbf = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
return rbf
def _dihedrals(X, eps=1e-7):
    """ Compute dihedral angles between residues given atomic backbone coordinates

    Args
    ----
    X : torch.FloatTensor
        Tensor specifying atomic backbone coordinates; only the first three
        atoms of each residue are used.
        Shape: num_res x 4 x 3
    eps : float, default=1e-7
        Margin used to clamp cosines away from +-1 before ``acos``.

    Returns
    -------
    D_features : torch.FloatTensor
        cos and sin of the three dihedral angles per residue
        (dihedral angles lifted to the 3-torus).
        Shape: num_res x 6
    """
    # From https://github.com/jingraham/neurips19-graph-protein-design
    # Flatten the first three atoms of every residue into one chain of points.
    X = torch.reshape(X[:, :3], [3 * X.shape[0], 3])
    dX = X[1:] - X[:-1]
    U = _normalize(dX, dim=-1)
    u_2 = U[:-2]
    u_1 = U[1:-1]
    u_0 = U[2:]
    # Backbone normals. dim=-1 must be explicit: without it torch.cross uses the
    # first dimension of size 3, which is the row axis when exactly three bond
    # vectors remain (num_res == 2), silently producing wrong normals.
    n_2 = _normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
    n_1 = _normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
    # Angle between normals
    cosD = torch.sum(n_2 * n_1, -1)
    cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
    D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)
    # This scheme will remove phi[0], psi[-1], omega[-1]
    D = F.pad(D, [1, 2])
    D = torch.reshape(D, [-1, 3])
    # Lift angle representations to the circle
    D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
    return D_features
def _positional_embeddings(edge_index, num_embeddings=16, dev='cpu'):
""" Sinusoidally encode sequence distances for edges.
Args
----
edge_index : torch.LongTensor
Edge indices for sparse representation of protein graph
Shape: 2 x num_edges
num_embeddings : int or None, default=128
Dimensionality of sinusoidal embedding.
Returns
-------
E : torch.FloatTensor
Sinusoidal encoding of sequence distances
Shape: num_edges x num_embeddings
"""
# From https://github.com/jingraham/neurips19-graph-protein-design
d = edge_index[0] - edge_index[1]
frequency = torch.exp(
torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=dev) * -(np.log(10000.0) / num_embeddings))
angles = d.unsqueeze(-1) * frequency
E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
return E
def _orientations(X_ca):
    """ Unit vectors pointing to the next and previous residue of every residue.

    Args
    ----
    X_ca : torch.FloatTensor
        CA atom coordinates, one row per residue.
        Shape: num_res x 3

    Returns
    -------
    torch.FloatTensor
        ``[:, 0]`` points toward the next residue (zero for the last residue);
        ``[:, 1]`` points toward the previous residue (zero for the first).
        Shape: num_res x 2 x 3
    """
    # From https://github.com/drorlab/gvp-pytorch
    to_next = _normalize(X_ca[1:] - X_ca[:-1])
    to_prev = _normalize(X_ca[:-1] - X_ca[1:])
    # Pad the residue axis: a zero row after `to_next`, a zero row before `to_prev`.
    to_next = F.pad(to_next, [0, 0, 0, 1])
    to_prev = F.pad(to_prev, [0, 0, 1, 0])
    return torch.stack([to_next, to_prev], dim=-2)
def _sidechains(X):
    """ Compute vectors pointing in the approximate direction of the sidechain.

    Args
    ----
    X : torch.FloatTensor
        Tensor specifying atomic backbone coordinates. Atoms 0, 1 and 2 of each
        residue are used, with atom 1 as the origin (presumably N, CA, C).
        Shape: num_res x 4 x 3

    Returns
    -------
    vec : torch.FloatTensor
        Sidechain vectors.
        Shape: num_res x 3
    """
    # From https://github.com/drorlab/gvp-pytorch
    n, origin, c = X[:, 0], X[:, 1], X[:, 2]
    c, n = _normalize(c - origin), _normalize(n - origin)
    bisector = _normalize(c + n)
    # dim=-1 must be explicit: otherwise torch.cross picks the residue axis
    # when num_res == 3 (the first dimension of size 3).
    perp = _normalize(torch.cross(c, n, dim=-1))
    vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
    return vec
def _jing_featurize(protein, dev='cpu'):
    """ Featurize individual proteins for use in torch_geometric Data objects,
    as done in https://github.com/drorlab/gvp-pytorch

    Args
    ----
    protein : dict
        Dictionary of protein features

        - :code:`name` - PDB ID of the protein
        - :code:`coords` - backbone coordinates indexed as residue x atom x xyz
          (presumably num_res x 4 x 3; atom index 1 is used as CA). Residues with
          any non-finite coordinate are masked.
        - :code:`seq` - protein sequence (converted to a LongTensor)
        - :code:`chain_idx` - an integer per residue such that each unique integer represents a unique chain
    dev : str, default='cpu'
        device on which the tensors are created

    Returns
    -------
    torch_geometric.data.Data
        Data object containing

        - :code:`x` - CA atomic coordinates
        - :code:`seq` - sequence of protein
        - :code:`name` - PDB ID of protein
        - :code:`node_s` - Node scalar features
        - :code:`node_v` - Node vector features
        - :code:`edge_s` - Edge scalar features
        - :code:`edge_v` - Edge vector features
        - :code:`edge_index` - Sparse representation of edge
        - :code:`mask` - boolean per residue, True where all of the residue's coordinates are finite
    """
    name = protein['name']
    with torch.no_grad():
        # clone(): torch.as_tensor can share memory with a float32 numpy array, and the
        # in-place write below must not overwrite the caller's coordinates.
        coords = torch.as_tensor(protein['coords'], device=dev, dtype=torch.float32).clone()
        seq = torch.as_tensor(protein['seq'], device=dev, dtype=torch.long)
        chain_idx = torch.as_tensor(protein['chain_idx'], device=dev)
        mask = torch.isfinite(coords.sum(dim=(1, 2)))
        # mark incomplete residues as inf; derived features go through nan_to_num below
        coords[~mask] = np.inf
        X_ca = coords[:, 1]
        edge_index = torch_cluster.knn_graph(X_ca, k=30, loop=True)  # TODO: make param
        pos_embeddings = _positional_embeddings(edge_index, dev=dev)
        # Sequence offsets between residues on different chains are meaningless, so keep
        # positional embeddings only for intrachain edges and zero out interchain ones.
        # NOTE(review): earlier code used `!=` here, which zeroed the intrachain embeddings
        # instead; checkpoints trained with that version saw the inverted pattern.
        edge_chains = chain_idx[edge_index.view(-1)].view(2, -1)
        same_chain = edge_chains[0] == edge_chains[1]
        pos_embeddings = same_chain.unsqueeze(-1) * pos_embeddings
        E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
        rbf = _rbf(E_vectors.norm(dim=-1), D_count=16, device=dev)  # TODO: make param
        dihedrals = _dihedrals(coords)
        orientations = _orientations(X_ca)
        sidechains = _sidechains(coords)
        node_s = dihedrals
        node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
        edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
        edge_v = _normalize(E_vectors).unsqueeze(-2)
        node_s, node_v, edge_s, edge_v = map(torch.nan_to_num, (node_s, node_v, edge_s, edge_v))
        data = torch_geometric.data.Data(x=X_ca,
                                         seq=seq,
                                         name=name,
                                         node_s=node_s,
                                         node_v=node_v,
                                         edge_s=edge_s,
                                         edge_v=edge_v,
                                         edge_index=edge_index,
                                         mask=mask)
    return data
# Ingraham featurization functions
def _ingraham_featurize(batch, device="cpu"):
""" Pack and pad coords in batch into torch tensors
as done in https://github.com/jingraham/neurips19-graph-protein-design
Args
----
batch : list of dict
list of protein backbone coordinate dictionaries,
in the format of that outputted by :code:`parseCoords.py`
device : str
device to place torch tensors
Returns
-------
X : torch.Tensor
Batched coordinates tensor
mask : torch.Tensor
Mask for X
lengths : np.ndarray
Array of lengths of batched proteins
"""
B = len(batch)
lengths = np.array([b.shape[0] for b in batch], dtype=np.int32)
l_max = max(lengths)
X = np.zeros([B, l_max, 4, 3])
# Build the batch
for i, x in enumerate(batch):
l = x.shape[0]
x_pad = np.pad(x, [[0, l_max - l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan, ))
X[i, :, :, :] = x_pad
# Mask
isnan = np.isnan(X)
mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
X[isnan] = 0.
# Conversion
X = torch.from_numpy(X).to(dtype=torch.float32, device=device)
mask = torch.from_numpy(mask).to(dtype=torch.float32, device=device)
return X, mask, lengths
def _quaternions(R, eps=1e-10):
""" Convert a batch of 3D rotations [R] to quaternions [Q]
R [...,3,3]
Q [...,4]
"""
def _R(i, j):
return R[..., i, j]
# Simple Wikipedia version
# en.wikipedia.org/wiki/Rotation_matrix#Quaternion
# For other options see math.stackexchange.com/questions/2074316/calculating-rotation-axis-from-rotation-matrix
diag = torch.diagonal(R, dim1=-2, dim2=-1)
Rxx, Ryy, Rzz = diag.unbind(-1)
magnitudes = 0.5 * torch.sqrt(
torch.abs(1 + torch.stack([Rxx - Ryy - Rzz, -Rxx + Ryy - Rzz, -Rxx - Ryy + Rzz], -1)) + eps)
signs = torch.sign(torch.stack([_R(2, 1) - _R(1, 2), _R(0, 2) - _R(2, 0), _R(1, 0) - _R(0, 1)], -1))
xyz = signs * magnitudes
# The relu enforces a non-negative trace
w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
Q = torch.cat((xyz, w), -1)
Q = F.normalize(Q, dim=-1)
return Q
def _orientations_coarse(X, edge_index, eps=1e-6):
    """ Compute Ingraham-style backbone angle features and per-edge orientation features.

    Args
    ----
    X : torch.FloatTensor
        One point per residue (called with CA coordinates in this module).
        Shape: num_res x 3
    edge_index : torch.LongTensor
        Edge indices for sparse representation of protein graph
        Shape: 2 x num_edges
    eps : float, default=1e-6
        Margin used to clamp cosines away from +-1 before ``acos``.

    Returns
    -------
    AD_features : torch.FloatTensor
        Bond-angle / dihedral features, zero-padded by one row before and
        two rows after. Shape: num_res x 3
    O_features : torch.FloatTensor
        Per edge: the normalized displacement ``X[edge_index[0]] - X[edge_index[1]]``
        rotated by the orientation matrix of ``edge_index[1]`` (3 values),
        followed by the relative-rotation quaternion (4 values).
        Shape: num_edges x 7
    """
    # Pair features
    # Shifted slices of unit vectors
    dX = X[1:, :] - X[:-1, :]
    U = F.normalize(dX, dim=-1)
    u_2 = U[:-2, :]
    u_1 = U[1:-1, :]
    u_0 = U[2:, :]
    # Backbone normals. dim=-1 must be explicit: without it torch.cross uses the first
    # dimension of size 3, i.e. the row axis when exactly three vectors remain (num_res == 6).
    n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
    n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
    # Bond angle calculation
    cosA = -(u_1 * u_0).sum(-1)
    cosA = torch.clamp(cosA, -1 + eps, 1 - eps)
    A = torch.acos(cosA)
    # Angle between normals
    cosD = (n_2 * n_1).sum(-1)
    cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
    D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
    # Backbone features
    AD_features = torch.stack((torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), -1)
    AD_features = F.pad(AD_features, (0, 0, 1, 2), 'constant', 0)
    # Build relative orientations: per-residue 3x3 matrices whose columns are
    # (o_1, n_2, o_1 x n_2), flattened to 9 values and padded to num_res rows
    o_1 = F.normalize(u_2 - u_1, dim=-1)
    O = torch.stack((o_1, n_2, torch.cross(o_1, n_2, dim=-1)), 2)
    O = O.view(list(O.shape[:1]) + [9])
    O = F.pad(O, (0, 0, 1, 2), 'constant', 0)
    # Gather endpoints: index 0 -> edge_index[0], index 1 -> edge_index[1]
    O_pairs = O[edge_index]
    X_pairs = X[edge_index]
    # Re-view as rotation matrices
    O_pairs = O_pairs.view(list(O_pairs.shape[:-1]) + [3, 3])
    # Rotate into local reference frames
    dX = X_pairs[0] - X_pairs[1]
    dU = torch.matmul(O_pairs[1], dX.unsqueeze(-1)).squeeze(-1)
    dU = F.normalize(dU, dim=-1)
    R = torch.matmul(O_pairs[1].transpose(-1, -2), O_pairs[0])
    Q = _quaternions(R)
    # Orientation features
    O_features = torch.cat((dU, Q), dim=-1)
    return AD_features, O_features
def _ingraham_geometric_featurize(protein, dev='cpu'):
    """ Featurize individual proteins for use in torch_geometric Data objects,
    combining positional, RBF-distance and Ingraham-style orientation edge features
    (graph construction mirrors https://github.com/drorlab/gvp-pytorch).

    Args
    ----
    protein : dict
        Dictionary of protein features

        - :code:`name` - PDB ID of the protein
        - :code:`coords` - backbone coordinates indexed as residue x atom x xyz
          (presumably num_res x 4 x 3; atom index 1 is used as CA). Residues with
          any non-finite coordinate are masked.
        - :code:`seq` - protein sequence (converted to a LongTensor)
        - :code:`chain_idx` - an integer per residue such that each unique integer represents a unique chain
    dev : str, default='cpu'
        device on which the tensors are created

    Returns
    -------
    torch_geometric.data.Data
        Data object containing

        - :code:`x` - CA atomic coordinates
        - :code:`seq` - sequence of protein
        - :code:`name` - PDB ID of protein
        - :code:`node_features` - dihedral features per residue (num_res x 6)
        - :code:`edge_features` - per edge, positional embeddings (16), RBF distance
          encoding (16) and orientation features (7), concatenated in that order
        - :code:`edge_index` - Sparse representation of edge
        - :code:`mask` - boolean per residue, True where all of the residue's coordinates are finite
    """
    name = protein['name']
    with torch.no_grad():
        # clone(): torch.as_tensor can share memory with a float32 numpy array, and the
        # in-place write below must not overwrite the caller's coordinates.
        coords = torch.as_tensor(protein['coords'], device=dev, dtype=torch.float32).clone()
        seq = torch.as_tensor(protein['seq'], device=dev, dtype=torch.long)
        chain_idx = torch.as_tensor(protein['chain_idx'], device=dev)
        mask = torch.isfinite(coords.sum(dim=(1, 2)))
        # mark incomplete residues as inf; derived features go through nan_to_num below
        coords[~mask] = np.inf
        X_ca = coords[:, 1]
        edge_index = torch_cluster.knn_graph(X_ca, k=30, loop=True)  # TODO: make param
        pos_embeddings = _positional_embeddings(edge_index, dev=dev)
        # Sequence offsets between residues on different chains are meaningless, so keep
        # positional embeddings only for intrachain edges and zero out interchain ones.
        # NOTE(review): earlier code used `!=` here, which zeroed the intrachain embeddings
        # instead; checkpoints trained with that version saw the inverted pattern.
        edge_chains = chain_idx[edge_index.view(-1)].view(2, -1)
        same_chain = edge_chains[0] == edge_chains[1]
        pos_embeddings = same_chain.unsqueeze(-1) * pos_embeddings
        E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
        rbf = _rbf(E_vectors.norm(dim=-1), D_count=16, device=dev)  # TODO: make param
        dihedrals = _dihedrals(coords)
        _, orientations = _orientations_coarse(X_ca, edge_index)
        node_features = dihedrals
        edge_features = torch.cat([pos_embeddings, rbf, orientations], dim=-1)
        node_features, edge_features = map(torch.nan_to_num, (node_features, edge_features))
        data = torch_geometric.data.Data(x=X_ca,
                                         seq=seq,
                                         name=name,
                                         node_features=node_features,
                                         edge_features=edge_features,
                                         edge_index=edge_index,
                                         mask=mask)
    return data
# Batching functions
def convert(tensor):
    """Wrap a numpy array as a torch tensor via ``torch.from_numpy`` (memory is shared, not copied)."""
    as_torch = torch.from_numpy(tensor)
    return as_torch
def _package(b_idx):
    """Package the given datapoints into tensors based on provided indices.

    Tensors are extracted from the data and padded. Coordinates are featurized
    and the length of TERMs and chain IDs are added to the data.

    Args
    ----
    b_idx : list of tuples (dict, int)
        The feature dictionaries, as well as an int for the sum of the lengths of all TERMs,
        for each datapoint to package.

    Returns
    -------
    dict
        Collection of batched features required for running TERMinator. This contains:

        - :code:`msas` - the sequences for each TERM match to the target structure
        - :code:`features` - the :math:`\\phi, \\psi, \\omega`, and environment values of the TERM matches
        - :code:`ppoe` - the :math:`\\phi, \\psi, \\omega`, and environment values of the target structure
        - :code:`seq_lens` - lengths of the target sequences
        - :code:`focuses` - the corresponding target structure residue index for each TERM residue
        - :code:`contact_idxs` - contact indices for each TERM residue
        - :code:`src_key_mask` - mask for TERM residue padding (True at padded positions)
        - :code:`term_lens` - per-protein TERM lengths, right-padded with -1
        - :code:`X` - coordinates
        - :code:`x_mask` - mask for the target structure
        - :code:`seqs` - the target sequences
        - :code:`ids` - the PDB ids
        - :code:`chain_idx` - the chain IDs
        - :code:`gvp_data` - list of per-protein Data objects from :code:`_jing_featurize`
        - :code:`sortcery_seqs` - SORTCERY sequences with a leading batch dim of 1, or [] if absent
        - :code:`sortcery_nrgs` - SORTCERY energies with a leading batch dim of 1, or [] if absent
        - :code:`geometric_data` - list of per-protein Data objects from :code:`_ingraham_geometric_featurize`
    """
    # wrap up all the tensors with proper padding and masks
    batch = [data[0] for data in b_idx]
    focus_lens = [data[1] for data in b_idx]
    features, msas, focuses, seq_lens, coords = [], [], [], [], []
    term_lens = []
    seqs = []
    ids = []
    ppoe = []
    contact_idxs = []
    chain_idx = []
    gvp_data = []
    geometric_data = []
    sortcery_seqs = []
    sortcery_nrgs = []
    for data in batch:
        # have to transpose these two because then we can use pad_sequence for padding
        features.append(convert(data['features']).transpose(0, 1))
        msas.append(convert(data['msas']).transpose(0, 1))
        ppoe.append(convert(data['ppoe']))
        focuses.append(convert(data['focuses']))
        contact_idxs.append(convert(data['contact_idxs']))
        seq_lens.append(data['seq_len'])
        term_lens.append(data['term_lens'].tolist())
        coords.append(data['coords'])
        seqs.append(convert(data['sequence']))
        ids.append(data['pdb'])
        if 'sortcery_seqs' in data:
            assert len(batch) == 1, "batch_size for SORTCERY fine-tuning should be set to 1"
            sortcery_seqs = convert(data['sortcery_seqs']).unsqueeze(0)
        if 'sortcery_nrgs' in data:
            sortcery_nrgs = convert(data['sortcery_nrgs']).unsqueeze(0)
        # chain id per residue: every residue of the i-th chain gets the value i.
        # Computed once and reused for featurization and for the batched chain_idx.
        protein_chain_idx = torch.cat(
            [torch.ones(c_len) * i for i, c_len in enumerate(data['chain_lens'])], dim=0)
        chain_idx.append(protein_chain_idx)
        protein = {
            'name': data['pdb'],
            'coords': data['coords'],
            'seq': data['sequence'],
            'chain_idx': protein_chain_idx
        }
        gvp_data.append(_jing_featurize(protein))
        geometric_data.append(_ingraham_geometric_featurize(protein))
    # transpose back after padding
    features = pad_sequence(features, batch_first=True).transpose(1, 2)
    msas = pad_sequence(msas, batch_first=True).transpose(1, 2).long()
    # we can pad these using standard pad_sequence
    ppoe = pad_sequence(ppoe, batch_first=True)
    focuses = pad_sequence(focuses, batch_first=True)
    contact_idxs = pad_sequence(contact_idxs, batch_first=True)
    src_key_mask = pad_sequence([torch.zeros(l) for l in focus_lens], batch_first=True, padding_value=1).bool()
    seqs = pad_sequence(seqs, batch_first=True)
    # Extend each protein's TERM-length list so that it sums to max_aa, the padded
    # number of TERM residues, presumably so reshaping during batchifyTERM can split
    # the padded TERM-residue axis into per-TERM chunks. The filler is whole chunks
    # of the protein's longest TERM length plus one remainder chunk (possibly 0).
    max_aa = focuses.size(-1)
    for lens in term_lens:
        max_term_len = max(lens)
        diff = max_aa - sum(lens)
        lens += [max_term_len] * (diff // max_term_len)
        lens.append(diff % max_term_len)
    # featurize coordinates same way as ingraham et al
    X, x_mask, _ = _ingraham_featurize(coords)
    seq_lens = torch.tensor(seq_lens)
    # pad with -1 so we can store term_lens in a tensor
    max_all_term_lens = max(len(lens) for lens in term_lens)
    for lens in term_lens:
        lens += [-1] * (max_all_term_lens - len(lens))
    term_lens = torch.tensor(term_lens)
    chain_idx = pad_sequence(chain_idx, batch_first=True)
    return {
        'msas': msas,
        'features': features.float(),
        'ppoe': ppoe.float(),
        'seq_lens': seq_lens,
        'focuses': focuses,
        'contact_idxs': contact_idxs,
        'src_key_mask': src_key_mask,
        'term_lens': term_lens,
        'X': X,
        'x_mask': x_mask,
        'seqs': seqs,
        'ids': ids,
        'chain_idx': chain_idx,
        'gvp_data': gvp_data,
        'sortcery_seqs': sortcery_seqs,
        'sortcery_nrgs': sortcery_nrgs,
        'geometric_data': geometric_data
    }
# Non-lazy data loading functions
def load_file(in_folder, pdb_id, min_protein_len=30):
    """Read the pickled ``.features`` file of one PDB entry.

    The file is expected at ``{in_folder}/{pdb_id}/{pdb_id}.features``.

    Args
    ----
    in_folder : str
        folder to find TERM file.
    pdb_id : str
        PDB ID to load.
    min_protein_len : int
        minimum cutoff for loading TERM file.

    Returns
    -------
    tuple or None
        ``(data, total_term_length, seq_len)``: the unpickled dict, the sum of
        ``data['term_lens']``, and ``data['seq_len']``. ``None`` when the sequence
        is shorter than :code:`min_protein_len`.
    """
    path = f"{in_folder}/{pdb_id}/{pdb_id}.features"
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(path, 'rb') as fp:
        data = pickle.load(fp)
    seq_len, total_term_length = data['seq_len'], data['term_lens'].sum()
    if seq_len >= min_protein_len:
        return data, total_term_length, seq_len
    return None
class TERMDataset(Dataset):
    """TERM Dataset that loads all feature files into a Pytorch Dataset-like structure.

    Attributes
    ----------
    dataset : list
        list of tuples containing features, TERM length, and sequence length
    shuffle_idx : np.ndarray
        array of indices for the dataset, for shuffling
    """

    def __init__(self, in_folder, pdb_ids=None, min_protein_len=30, num_processes=32):
        """
        Initializes current TERM dataset by reading in feature files.

        Reads in all feature files from the given directory, using multiprocessing
        with the provided number of processes. Stores the features, the TERM length,
        and the sequence length as a tuple representing the data. Can read from PDB ids or
        file paths directly. Uses the given protein length as a cutoff.

        Args
        ----
        in_folder : str
            path to directory containing feature files generated by :code:`scripts/data/preprocessing/generateDataset.py`
        pdb_ids: list, optional
            list of pdbs from `in_folder` to include in the dataset. When falsy
            (None or empty), every ``{in_folder}/*/*.features`` file is loaded.
        min_protein_len: int, default=30
            minimum length of a protein in the dataset
        num_processes: int, default=32
            number of processes to use during dataloading
        """
        self.dataset = []
        with mp.Pool(num_processes) as pool:
            if pdb_ids:
                print("Loading feature files")
            else:
                print("Loading feature file paths")
                filelist = list(glob.glob(f'{in_folder}/*/*.features'))
                # strip the ".features" suffix to recover each PDB id
                pdb_ids = [os.path.basename(path)[:-len(".features")] for path in filelist]
            progress = tqdm(total=len(pdb_ids))

            def update_progress(res):
                del res
                progress.update(1)

            res_list = [
                pool.apply_async(load_file, (in_folder, pdb_id),
                                 kwds={"min_protein_len": min_protein_len},
                                 callback=update_progress) for pdb_id in pdb_ids
            ]
            pool.close()
            pool.join()
            progress.close()
            for res in res_list:
                data = res.get()
                # load_file returns None for proteins shorter than min_protein_len
                if data is not None:
                    features, total_term_length, seq_len = data
                    self.dataset.append((features, total_term_length, seq_len))
        self.shuffle_idx = np.arange(len(self.dataset))

    def shuffle(self):
        """Shuffle the current dataset."""
        np.random.shuffle(self.shuffle_idx)

    def __len__(self):
        """Returns length of the given dataset.

        Returns
        -------
        int
            length of dataset
        """
        return len(self.dataset)

    def __getitem__(self, idx):
        """Extract a given item with provided index.

        Args
        ----
        idx : int, list, np.ndarray or slice
            Index (or indices) into the shuffled order.

        Returns
        -------
        tuple or list of tuples
            ``(data, total_term_len, seq_len)`` for an integer index, where
            ``data`` is the dict from the TERM file; a list of such tuples for
            a list, array or slice of indices.
        """
        data_idx = self.shuffle_idx[idx]
        # indexing the numpy shuffle_idx with a list/array/slice returns an ndarray
        if isinstance(data_idx, (list, np.ndarray)):
            return [self.dataset[i] for i in data_idx]
        return self.dataset[data_idx]
class TERMBatchSampler(Sampler):
    """BatchSampler/Dataloader helper class for TERM data using TERMDataset.

    Attributes
    ----------
    size: int
        Length of the dataset
    dataset: tuple
        Feature dictionaries from the TERM dataset
    total_term_lengths: tuple
        TERM lengths from the given dataset
    seq_lengths: tuple
        Sequence lengths from the given dataset
    lengths: tuple
        TERM lengths or sequence lengths, depending on
        whether :code:`max_term_res` or :code:`max_seq_tokens` is set.
    clusters: list of list
        Dataset indices grouped into batches by :code:`_cluster`.
    batch_size : int or None, default=4
        Size of batches created. If variable sized batches are desired, set to None.
    sort_data : bool, default=False
        Create deterministic batches by sorting the data according to the
        specified length metric and creating batches from the sorted data.
        Incompatible with :code:`shuffle=True` and :code:`semi_shuffle=True`.
    shuffle : bool, default=True
        Shuffle the data completely before creating batches.
        Incompatible with :code:`sort_data=True` and :code:`semi_shuffle=True`.
    semi_shuffle : bool, default=False
        Sort the data according to the specified length metric,
        then partition the data into :code:`semi_shuffle_cluster_size`-sized partitions.
        Within each partition perform a complete shuffle. The upside is that
        batching with similar lengths reduces padding making for more efficient computation,
        but the downside is that it does a less complete shuffle.
    semi_shuffle_cluster_size : int, default=500
        Size of partition to use when :code:`semi_shuffle=True`.
    batch_shuffle : bool, default=True
        Intended to shuffle samples within a batch. NOTE: stored on the
        instance but not read anywhere in this class.
    drop_last : bool, default=False
        If set to :code:`True`, drop the last batch when it holds fewer than
        :code:`batch_size` samples. With variable-size batching
        (:code:`batch_size=None`) the final batch is always dropped.
    max_term_res : int or None, default=55000
        When :code:`batch_size=None, max_term_res>0, max_seq_tokens=None`,
        batch by fitting as many datapoints as possible with the total number of
        TERM residues included below `max_term_res`.
        Calibrated using :code:`nn.DataParallel` on two V100 GPUs.
    max_seq_tokens : int or None, default=None
        When :code:`batch_size=None, max_term_res=None, max_seq_tokens>0`,
        batch by fitting as many datapoints as possible with the total number of
        sequence residues included below `max_seq_tokens`.
    """

    def __init__(self,
                 dataset,
                 batch_size=4,
                 sort_data=False,
                 shuffle=True,
                 semi_shuffle=False,
                 semi_shuffle_cluster_size=500,
                 batch_shuffle=True,
                 drop_last=False,
                 max_term_res=55000,
                 max_seq_tokens=None):
        """
        Reads in and processes a given dataset.

        Given the provided dataset, load all the data. Then cluster the data using
        the provided method, either shuffled or sorted and then shuffled.

        Args
        ----
        dataset : TERMDataset
            Dataset to batch; each item is a (features, total_term_length, seq_len) tuple.
        batch_size : int or None, default=4
            Size of batches created. If variable sized batches are desired, set to None.
        sort_data : bool, default=False
            Create deterministic batches by sorting the data by the length metric.
        shuffle : bool, default=True
            Shuffle the data completely before creating batches.
        semi_shuffle : bool, default=False
            Sort by the length metric, then shuffle within
            :code:`semi_shuffle_cluster_size`-sized partitions.
        semi_shuffle_cluster_size : int, default=500
            Size of partition to use when :code:`semi_shuffle=True`.
        batch_shuffle : bool, default=True
            Stored but not used by this class.
        drop_last : bool, default=False
            Drop the final batch when it is smaller than :code:`batch_size`
            (always dropped when :code:`batch_size=None`).
        max_term_res : int or None, default=55000
            TERM-residue cap per batch when :code:`batch_size=None`.
            Calibrated using :code:`nn.DataParallel` on two V100 GPUs.
        max_seq_tokens : int or None, default=None
            Sequence-residue cap per batch when :code:`batch_size=None`. Exactly one of
            :code:`max_term_res` and :code:`max_seq_tokens` must be None.

        Raises
        ------
        ValueError
            If the :code:`max_term_res` / :code:`max_seq_tokens` combination is invalid.
        """
        super().__init__(dataset)
        self.size = len(dataset)
        self.dataset, self.total_term_lengths, self.seq_lengths = zip(*dataset)
        assert not (max_term_res is None
                    and max_seq_tokens is None), "Exactly one of max_term_res and max_seq_tokens must be None"
        if max_term_res is None and max_seq_tokens > 0:
            self.lengths = self.seq_lengths
        elif max_term_res > 0 and max_seq_tokens is None:
            self.lengths = self.total_term_lengths
        else:
            raise ValueError("Exactly one of max_term_res and max_seq_tokens must be None")
        self.shuffle = shuffle
        self.sort_data = sort_data
        self.batch_shuffle = batch_shuffle
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.max_term_res = max_term_res
        self.max_seq_tokens = max_seq_tokens
        self.semi_shuffle = semi_shuffle
        self.semi_shuffle_cluster_size = semi_shuffle_cluster_size
        assert not (shuffle and semi_shuffle), "Lazy Dataloader shuffle and semi shuffle cannot both be set"
        # initialize clusters
        self._cluster()

    def _cluster(self):
        """ Shuffle data and make clusters of indices corresponding to batches of data.

        This method speeds up training by sorting data points with similar TERM lengths
        together, if :code:`sort_data` or :code:`semi_shuffle` are on. Under `sort_data`,
        the data is sorted by length. Under `semi_shuffle`, the data is broken up
        into clusters based on length and shuffled within the clusters. Otherwise,
        it is randomly shuffled. Data is then loaded into batches, either of fixed
        :code:`batch_size` or, when that is None, as many proteins as fit under
        :code:`max_term_res` or :code:`max_seq_tokens`. The result is stored in
        :code:`self.clusters`.
        """
        if self.sort_data:
            idx_list = np.argsort(self.lengths)
        elif self.semi_shuffle:
            # shuffle points with similar lengths together: sort, then shuffle
            # inside consecutive windows of semi_shuffle_cluster_size
            idx_list = np.argsort(self.lengths)
            for start in range(0, len(idx_list), self.semi_shuffle_cluster_size):
                # slicing an ndarray gives a view, so this shuffles idx_list in place
                np.random.shuffle(idx_list[start:start + self.semi_shuffle_cluster_size])
        else:
            idx_list = list(range(len(self.dataset)))
            np.random.shuffle(idx_list)

        clusters, batch = [], []
        if self.batch_size is None:
            # fit as many proteins as possible while the padded batch size
            # (longest length * number of proteins) stays within the cap
            if self.max_term_res is None and self.max_seq_tokens > 0:
                cap_len = self.max_seq_tokens
            elif self.max_term_res > 0 and self.max_seq_tokens is None:
                cap_len = self.max_term_res
            current_batch_lens = []
            for count, idx in enumerate(idx_list):
                current_batch_lens.append(self.lengths[idx])
                padded_size = max(current_batch_lens) * len(current_batch_lens)
                if count != 0 and padded_size > cap_len:
                    clusters.append(batch)
                    batch = [idx]
                    current_batch_lens = [self.lengths[idx]]
                else:
                    batch.append(idx)
        else:  # used fixed batch size
            for count, idx in enumerate(idx_list):
                if count != 0 and count % self.batch_size == 0:
                    clusters.append(batch)
                    batch = [idx]
                else:
                    batch.append(idx)

        # drop_last only discards an incomplete final batch; a full one is always kept
        last_is_incomplete = self.batch_size is None or len(batch) < self.batch_size
        if batch and not (self.drop_last and last_is_incomplete):
            clusters.append(batch)
        self.clusters = clusters

    def package(self, b_idx):
        """Package the given datapoints into tensors based on provided indices.

        Tensors are extracted from the data and padded. Coordinates are featurized
        and the length of TERMs and chain IDs are added to the data.

        Args
        ----
        b_idx : list of tuples (dicts, int, int)
            The feature dictionaries, the sum of the lengths of all TERMs, and the sum of all sequence lengths
            for each datapoint to package.

        Returns
        -------
        dict
            Collection of batched features required for running TERMinator, as
            produced by :code:`_package` (msas, features, ppoe, seq_lens, focuses,
            contact_idxs, src_key_mask, term_lens, X, x_mask, seqs, ids, chain_idx,
            gvp_data, sortcery_seqs, sortcery_nrgs, geometric_data).
        """
        return _package([b[0:2] for b in b_idx])

    def __len__(self):
        """Returns length of dataset, i.e. number of batches.

        Returns
        -------
        int
            length of dataset.
        """
        return len(self.clusters)

    def __iter__(self):
        """Yield one list of dataset indices per batch, re-clustering first when shuffling."""
        if self.shuffle or self.semi_shuffle:
            self._cluster()
            np.random.shuffle(self.clusters)
        for batch in self.clusters:
            yield batch
# Defined at module level (not as a method) so that multiprocessing can pickle it
# when it is submitted to a worker pool via Pool.apply_async.
def read_lens(in_folder, pdb_id, min_protein_len=30):
    """ Read the two-line ``.length`` summary file of one PDB entry.

    The file lives at ``{in_folder}/{pdb_id}/{pdb_id}.length``; its first line is
    the total TERM length and its second line the sequence length, both integers.

    Args
    ----
    in_folder : str
        folder to find TERM file.
    pdb_id : str
        PDB ID to load.
    min_protein_len : int
        minimum cutoff for loading TERM file.

    Returns
    -------
    tuple or None
        ``(pdb_id, total_term_length, seq_len)``, or ``None`` when ``seq_len``
        is less than :code:`min_protein_len`.
    """
    path = f"{in_folder}/{pdb_id}/{pdb_id}.length"
    # pylint: disable=unspecified-encoding
    with open(path, 'rt') as fp:
        total_term_length, seq_len = [int(fp.readline().strip()) for _ in range(2)]
    if seq_len >= min_protein_len:
        return pdb_id, total_term_length, seq_len
    return None
class TERMLazyDataset(Dataset):
    """TERM Dataset that loads all feature files into a Pytorch Dataset-like structure.

    Unlike TERMDataset, this just loads feature filenames, not actual features.

    Attributes
    ----------
    dataset : list
        list of tuples containing feature filenames, TERM length, and sequence length
    shuffle_idx : np.ndarray
        array of indices for the dataset, for shuffling
    """

    def __init__(self, in_folder, pdb_ids=None, min_protein_len=30, num_processes=32):
        """
        Initializes current TERM dataset by reading in feature files.

        Reads in all feature files from the given directory, using multiprocessing
        with the provided number of processes. Stores the feature filenames, the TERM length,
        and the sequence length as a tuple representing the data. Can read from PDB ids or
        file paths directly. Uses the given protein length as a cutoff.

        Args
        ----
        in_folder : str
            path to directory containing feature files generated by :code:`scripts/data/preprocessing/generateDataset.py`
        pdb_ids: list, optional
            list of pdbs from `in_folder` to include in the dataset. When falsy
            (None or empty), every ``{in_folder}/*/*.features`` file is used.
        min_protein_len: int, default=30
            minimum length of a protein in the dataset
        num_processes: int, default=32
            number of processes to use during dataloading
        """
        self.dataset = []
        with mp.Pool(num_processes) as pool:
            print("Loading feature file paths")
            if not pdb_ids:
                filelist = list(glob.glob(f'{in_folder}/*/*.features'))
                # strip the ".features" suffix to recover each PDB id
                pdb_ids = [os.path.basename(path)[:-len(".features")] for path in filelist]
            progress = tqdm(total=len(pdb_ids))

            def update_progress(res):
                del res
                progress.update(1)

            res_list = [
                pool.apply_async(read_lens, (in_folder, pdb_id),
                                 kwds={"min_protein_len": min_protein_len},
                                 callback=update_progress) for pdb_id in pdb_ids
            ]
            pool.close()
            pool.join()
            progress.close()
            for res in res_list:
                data = res.get()
                # read_lens returns None for proteins shorter than min_protein_len
                if data is not None:
                    pdb_id, total_term_length, seq_len = data
                    filename = f"{in_folder}/{pdb_id}/{pdb_id}.features"
                    self.dataset.append((os.path.abspath(filename), total_term_length, seq_len))
        self.shuffle_idx = np.arange(len(self.dataset))

    def shuffle(self):
        """Shuffle the dataset"""
        np.random.shuffle(self.shuffle_idx)

    def __len__(self):
        """Returns length of the given dataset.

        Returns
        -------
        int
            length of dataset
        """
        return len(self.dataset)

    def __getitem__(self, idx):
        """Extract a given item with provided index.

        Args
        ----
        idx : int, list, np.ndarray or slice
            Index (or indices) into the shuffled order.

        Returns
        -------
        tuple or list of tuples
            ``(feature_filename, total_term_len, seq_len)`` for an integer index,
            where ``feature_filename`` is the absolute path of the ``.features``
            file; a list of such tuples for a list, array or slice of indices.
        """
        data_idx = self.shuffle_idx[idx]
        # indexing the numpy shuffle_idx with a list/array/slice returns an ndarray
        if isinstance(data_idx, (list, np.ndarray)):
            return [self.dataset[i] for i in data_idx]
        return self.dataset[data_idx]
class TERMLazyBatchSampler(Sampler):
    """BatchSampler/Dataloader helper class for TERM data using TERMLazyDataset.

    Attributes
    ----------
    dataset : TERMLazyDataset
        Dataset to batch.
    size : int
        Length of dataset
    filepaths, total_term_lengths, seq_lengths : tuple
        Per-datapoint columns unzipped from :code:`dataset`.
    lengths : tuple
        Length metric used for sorting and variable-size batching: :code:`seq_lengths`
        when batching by :code:`max_seq_tokens`, otherwise :code:`total_term_lengths`.
    clusters : list of list
        Batches of dataset indices built by :code:`_cluster`.
    batch_size : int or None, default=4
        Size of batches created. If variable sized batches are desired, set to None.
    sort_data : bool, default=False
        Create deterministic batches by sorting the data according to the
        specified length metric and creating batches from the sorted data.
        Incompatible with :code:`shuffle=True` and :code:`semi_shuffle=True`.
    shuffle : bool, default=True
        Shuffle the data completely before creating batches.
        Incompatible with :code:`sort_data=True` and :code:`semi_shuffle=True`.
    semi_shuffle : bool, default=False
        Sort the data according to the specified length metric,
        then partition the data into :code:`semi_shuffle_cluster_size`-sized partitions.
        Within each partition perform a complete shuffle. The upside is that
        batching with similar lengths reduces padding making for more efficient computation,
        but the downside is that it does a less complete shuffle.
    semi_shuffle_cluster_size : int, default=500
        Size of partition to use when :code:`semi_shuffle=True`.
    batch_shuffle : bool, default=True
        If set to :code:`True`, shuffle samples within a batch.
    drop_last : bool, default=False
        If set to :code:`True`, drop the last samples if they don't form a complete batch.
        With variable-sized batches (:code:`batch_size=None`) the last batch is always dropped.
    max_term_res : int or None, default=55000
        When :code:`batch_size=None, max_term_res>0, max_seq_tokens=None`,
        batch by fitting as many datapoints as possible with the total number of
        TERM residues included below `max_term_res`.
        Calibrated using :code:`nn.DataParallel` on two V100 GPUs.
    max_seq_tokens : int or None, default=None
        When :code:`batch_size=None, max_term_res=None, max_seq_tokens>0`,
        batch by fitting as many datapoints as possible with the total number of
        sequence residues included below `max_seq_tokens`.
    term_matches_cutoff : int or None, default=None
        Use the top :code:`term_matches_cutoff` TERM matches for featurization.
        If :code:`None`, apply no cutoff.
    term_dropout : str or None, default=None
        Let `t` be the number of TERM matches in the given datapoint.
        Select a random int `n` from 1 to `t`, and take a random subset `n`
        of the given TERM matches to keep (sampled independently for every TERM residue).
        If :code:`term_dropout='keep_first'`, keep the first match and choose `n-1` from the rest.
        If :code:`term_dropout='all'`, choose `n` matches from all matches.
    """

    def __init__(self,
                 dataset,
                 batch_size=4,
                 sort_data=False,
                 shuffle=True,
                 semi_shuffle=False,
                 semi_shuffle_cluster_size=500,
                 batch_shuffle=True,
                 drop_last=False,
                 max_term_res=55000,
                 max_seq_tokens=None,
                 term_matches_cutoff=None,
                 term_dropout=None):
        """
        Reads in and processes a given dataset.

        Unzips the :code:`(filepath, total TERM length, sequence length)` tuples of the
        dataset, validates the options, then builds the initial batches via :code:`_cluster`.

        Args
        ----
        dataset : TERMLazyDataset
            Dataset to batch.
        batch_size : int or None, default=4
            Size of batches created. If variable sized batches are desired, set to None.
        sort_data : bool, default=False
            Create deterministic batches by sorting the data according to the
            specified length metric and creating batches from the sorted data.
        shuffle : bool, default=True
            Shuffle the data completely before creating batches.
        semi_shuffle : bool, default=False
            Sort by length, then shuffle within :code:`semi_shuffle_cluster_size`-sized partitions.
        semi_shuffle_cluster_size : int, default=500
            Size of partition to use when :code:`semi_shuffle=True`.
        batch_shuffle : bool, default=True
            If set to :code:`True`, shuffle samples within a batch.
        drop_last : bool, default=False
            If set to :code:`True`, drop the last samples if they don't form a complete batch.
        max_term_res : int or None, default=55000
            TERM-residue cap for variable batching. Exactly one of this and
            :code:`max_seq_tokens` must be None.
        max_seq_tokens : int or None, default=None
            Sequence-residue cap for variable batching.
        term_matches_cutoff : int or None, default=None
            Use the top :code:`term_matches_cutoff` TERM matches for featurization.
        term_dropout : str or None, default=None
            One of :code:`'keep_first'`, :code:`'all'` or :code:`None` (see class docstring).

        Raises
        ------
        ValueError
            If the :code:`max_term_res`/:code:`max_seq_tokens` combination is invalid
            or :code:`batch_size` is not positive.
        AssertionError
            If both length caps are None, :code:`term_dropout` is unknown, or
            :code:`shuffle` and :code:`semi_shuffle` are both set.
        """
        super().__init__(dataset)
        self.dataset = dataset
        self.size = len(dataset)
        if self.size:
            self.filepaths, self.total_term_lengths, self.seq_lengths = zip(*dataset)
        else:
            # zip(*[]) yields nothing, which cannot be unpacked into three columns
            self.filepaths, self.total_term_lengths, self.seq_lengths = (), (), ()
        assert not (max_term_res is None
                    and max_seq_tokens is None), "Exactly one of max_term_res and max_seq_tokens must be None"
        # guard each comparison with an `is not None` check: `None > 0` raises TypeError
        if max_term_res is None and max_seq_tokens > 0:
            self.lengths = self.seq_lengths
        elif max_term_res is not None and max_term_res > 0 and max_seq_tokens is None:
            self.lengths = self.total_term_lengths
        else:
            raise ValueError("Exactly one of max_term_res and max_seq_tokens must be None")
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size={batch_size} must be a positive integer or None")
        self.shuffle = shuffle
        self.sort_data = sort_data
        self.batch_shuffle = batch_shuffle
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.max_term_res = max_term_res
        self.max_seq_tokens = max_seq_tokens
        self.semi_shuffle = semi_shuffle
        self.semi_shuffle_cluster_size = semi_shuffle_cluster_size
        self.term_matches_cutoff = term_matches_cutoff
        assert term_dropout in ["keep_first", "all", None], f"term_dropout={term_dropout} is not a valid argument"
        self.term_dropout = term_dropout
        assert not (shuffle and semi_shuffle), "Lazy Dataloader shuffle and semi shuffle cannot both be set"
        # initialize clusters
        self._cluster()

    def _cluster(self):
        """ Order the data and split the resulting index list into batches (:code:`self.clusters`).

        Under :code:`sort_data`, indices are sorted by :code:`self.lengths`. Under
        :code:`semi_shuffle`, they are sorted and then shuffled within consecutive
        :code:`semi_shuffle_cluster_size`-sized partitions. Otherwise they are fully shuffled.

        With a fixed :code:`batch_size`, consecutive indices are grouped into batches of that
        size. With :code:`batch_size=None`, a batch grows until its padded size
        (longest length * number of datapoints) would exceed :code:`max_term_res` or
        :code:`max_seq_tokens`; a single datapoint larger than the cap forms its own batch.
        """
        n_data = len(self.lengths)
        if self.sort_data:
            idx_list = np.argsort(self.lengths)
        elif self.semi_shuffle:
            # group datapoints with similar lengths, but keep some randomness
            idx_list = np.argsort(self.lengths)
            step = self.semi_shuffle_cluster_size
            for start in range(0, n_data, step):
                # numpy slices are views, so this shuffles idx_list in place
                np.random.shuffle(idx_list[start:start + step])
        else:
            idx_list = list(range(n_data))
            np.random.shuffle(idx_list)

        clusters, batch = [], []
        if self.batch_size is None:
            # fit as many proteins as possible into a batch without overloading the GPU
            cap_len = self.max_seq_tokens if self.max_term_res is None else self.max_term_res
            batch_max = 0
            for idx in idx_list:
                length = self.lengths[idx]
                new_max = max(batch_max, length)
                # padded size the batch would have if this datapoint were added
                if batch and new_max * (len(batch) + 1) > cap_len:
                    clusters.append(batch)
                    batch, batch_max = [idx], length
                else:
                    batch.append(idx)
                    batch_max = new_max
            last_is_complete = False
        else:  # use fixed batch size
            for idx in idx_list:
                if len(batch) == self.batch_size:
                    clusters.append(batch)
                    batch = []
                batch.append(idx)
            last_is_complete = len(batch) == self.batch_size
        # drop_last only discards a trailing batch that is not full
        if batch and (last_is_complete or not self.drop_last):
            clusters.append(batch)
        self.clusters = clusters

    def package(self, b_idx):
        """Load the given datapoints from disk and package them into padded tensors.

        Args
        ----
        b_idx : list of (str, int, int)
            The path to the feature files, the sum of the lengths of all TERMs, and the sequence length
            for each datapoint to package.

        Returns
        -------
        dict
            Collection of batched features required for running TERMinator. This contains:

            - :code:`msas` - the sequences for each TERM match to the target structure
            - :code:`features` - the :math:`\\phi, \\psi, \\omega`, and environment values of the TERM matches
            - :code:`ppoe` - the :math:`\\phi, \\psi, \\omega`, and environment values of the target structure
            - :code:`seq_lens` - lengths of the target sequences
            - :code:`focuses` - the corresponding target structure residue index for each TERM residue
            - :code:`contact_idxs` - contact indices for each TERM residue
            - :code:`src_key_mask` - mask for TERM residue padding
            - :code:`X` - coordinates
            - :code:`x_mask` - mask for the target structure
            - :code:`seqs` - the target sequences
            - :code:`ids` - the PDB ids
            - :code:`chain_idx` - the chain IDs

            :code:`features` and :code:`msas` are truncated by :code:`term_matches_cutoff`
            and subsampled by :code:`term_dropout` when those are set.
        """
        if self.batch_shuffle:
            # shuffle a copy so the caller's list is left untouched
            b_idx = list(b_idx)
            random.shuffle(b_idx)
        # load the files specified by filepaths
        batch = []
        for data in b_idx:
            filepath = data[0]
            # NOTE: pickle.load can execute arbitrary code; only load trusted feature files
            with open(filepath, 'rb') as fp:
                loaded = pickle.load(fp)
            batch.append((loaded, data[1]))
            if 'ppoe' not in loaded:
                print(filepath)
        # package batch
        packaged_batch = _package(batch)
        features = packaged_batch["features"]
        msas = packaged_batch["msas"]
        # apply TERM matches cutoff
        if self.term_matches_cutoff:
            features = features[:, :self.term_matches_cutoff]
            msas = msas[:, :self.term_matches_cutoff]
        # apply TERM matches dropout
        if self.term_dropout:
            features, msas = self._apply_term_dropout(features, msas)
        packaged_batch["features"] = features
        packaged_batch["msas"] = msas
        return packaged_batch

    def _apply_term_dropout(self, features, msas):
        """Keep a random subset of TERM matches, sampled independently per TERM residue.

        A count `n` is drawn uniformly from 1..t (t = number of matches, dim 1). For
        :code:`'keep_first'` the first match is always kept and `n-1` further matches are drawn
        without replacement from the other `t-1`; for :code:`'all'` `n` matches are drawn
        without replacement from all `t`.

        Args
        ----
        features : torch.Tensor
            Shape :code:`[n_batch, n_align, n_terms, n_features]` (unpacked below).
        msas : torch.Tensor
            Indexed like :code:`features` without the last dim: :code:`[n_batch, n_align, n_terms]`.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            :code:`features` and :code:`msas` with dim 1 reduced to `n`.
        """
        n_batch, n_align, n_terms, n_features = features.shape
        if n_align == 0:
            return features, msas
        if self.term_dropout == 'keep_first':
            pool_features, pool_msas = features[:, 1:], msas[:, 1:]
            n_keep = torch.randint(0, n_align, [1]).item()  # extra matches besides the first
            if n_keep == 0:
                return features[:, 0:1], msas[:, 0:1]
        elif self.term_dropout == 'all':
            pool_features, pool_msas = features, msas
            n_keep = torch.randint(1, n_align + 1, [1]).item()  # upper bound is exclusive
        else:
            raise ValueError(f"term_dropout={self.term_dropout} is not a valid argument")

        # uniform weights over every candidate match; sampling without replacement
        # gives one random subset of size n_keep per (datapoint, TERM residue) row
        n_pool = pool_features.shape[1]
        weights = torch.ones([n_batch * n_terms, n_pool], device=features.device)
        sample_idx = torch.multinomial(weights, n_keep)
        sample_idx = sample_idx.view([n_batch, n_terms, n_keep]).transpose(-1, -2)
        sample_idx_features = sample_idx.unsqueeze(-1).expand([n_batch, n_keep, n_terms, n_features])
        sample_features = torch.gather(pool_features, 1, sample_idx_features)
        sample_msas = torch.gather(pool_msas, 1, sample_idx)

        if self.term_dropout == 'keep_first':
            sample_features = torch.cat([features[:, 0:1], sample_features], dim=1)
            sample_msas = torch.cat([msas[:, 0:1], sample_msas], dim=1)
        return sample_features, sample_msas

    def __len__(self):
        """Returns length of dataset, i.e. number of batches.

        Returns
        -------
        int
            length of dataset.
        """
        return len(self.clusters)

    def __iter__(self):
        """Yield batches of dataset indices; re-cluster and shuffle batch order when shuffling."""
        if self.shuffle or self.semi_shuffle:
            self._cluster()
            np.random.shuffle(self.clusters)
        for batch in self.clusters:
            yield batch