# Source code for terminator.utils.model.loop_utils

""" Utilities for running training and evaluation loops """
import torch
from tqdm import tqdm

# pylint: disable=no-member


def _to_dev(data_dict, dev):
    """ Move the tensor entries of a feature dictionary onto a device, in place.

    Args
    ----
    data_dict : dict
        Dictionary of input features to TERMinator. It is modified in place and nothing
        is returned. Values that are :code:`torch.Tensor` objects are replaced by the
        result of :code:`.to(dev)`. The :code:`'gvp_data'` entry, when present, is replaced
        by a list holding :code:`.to(dev)` of each of its elements. All other entries
        are left untouched.
    dev : str
        Device to load tensors onto
    """
    # walk over a snapshot of the keys; only values of existing keys are reassigned
    for name in list(data_dict):
        entry = data_dict[name]
        if isinstance(entry, torch.Tensor):
            data_dict[name] = entry.to(dev)
        if name == 'gvp_data':
            data_dict[name] = [graph.to(dev) for graph in data_dict[name]]


def _ld_item_values(ld):
    """ Convert all 0-dim tensors in a loss dictionary into python native types.

    The input dictionary and its nested sub-dictionaries are left unmodified; a new
    dictionary with new sub-dictionaries is returned.

    Args
    ----
    ld : dict
        Dictionary with keys :code:`loss_fn_name` and values of dictionaries with
        - :code:`loss` corresponding to the loss value
        - :code:`count` corresponding to the normalizing factor
        - :code:`scaling_factor` corresponding to the scaling coefficient in the loss function

    Returns
    -------
    ld_copy : dict
        New loss dictionary with the same structure as :code:`ld`, where every 0-dim tensor
        has been replaced by the result of :code:`.item()`. Non-tensor values are carried
        over as they are.

    Raises
    ------
    RuntimeError
        If any sub-dictionary value is a tensor with one or more dimensions.
    """
    ld_copy = {}
    for loss_fn_name, sub_ld in ld.items():
        # build a fresh sub-dictionary: a shallow copy of `ld` would share the nested
        # dictionaries with the caller, so converting in place would overwrite the
        # caller's tensors
        converted = {}
        for key, val in sub_ld.items():
            if isinstance(val, torch.Tensor):
                if val.dim() != 0:
                    raise RuntimeError("loss dictionary contains non-0-dim tensor value")
                val = val.item()
            converted[key] = val
        ld_copy[loss_fn_name] = converted
    return ld_copy


def _compute_loss(loss_dict):
    """ Combine the sublosses of a loss dictionary into one scaled total.

    Args
    ----
    loss_dict : dict
        Dictionary with keys :code:`loss_fn_name` and values of dictionaries with
        - :code:`loss` corresponding to the loss value
        - :code:`count` corresponding to the normalizing factor
        - :code:`scaling_factor` corresponding to the scaling coefficient in the loss function

    Returns
    -------
    loss : torch.Tensor or number
        Sum over all sublosses of :code:`loss * scaling_factor`. The result has whatever type
        the arithmetic on the stored values produces, and is the integer :code:`0` when
        :code:`loss_dict` is empty. The :code:`count` entries are not used here.
    """
    total = 0
    for entry in loss_dict.values():
        scaled = entry["loss"] * entry["scaling_factor"]
        total += scaled
    return total


def _sum_loss_dicts(total_ld, batch_ld):
    """ Add all values in :code:`batch_ld` into the corresponding values in :code:`total_ld`

    Neither input is modified: the returned dictionary and all of its sub-dictionaries are
    new objects.

    Args
    ----
    total_ld, batch_ld : dict
        Dictionary with keys :code:`loss_fn_name` and values of dictionaries with
        - :code:`loss` corresponding to the loss value
        - :code:`count` corresponding to the normalizing factor
        - :code:`scaling_factor` corresponding to the scaling coefficient in the loss function

    Returns
    -------
    combined_ld : dict
        Combined loss dictionary with the same structure as the input dictionaries. For a loss
        present in both inputs, :code:`loss` is the count-weighted mean of the two losses and
        :code:`count` is the sum of the two counts. A loss present in only one input is copied over.

    Raises
    ------
    AssertionError
        If a loss present in both inputs has different :code:`scaling_factor` values.
    ZeroDivisionError
        If a loss present in both inputs has python-number counts that sum to zero.
    """
    def _weighted_loss_sum(ld1, ld2):
        """ Compute the count-weighted mean loss of two subloss dictionaries """
        c1, c2 = ld1["count"], ld2["count"]
        return (ld1["loss"] * c1 + ld2["loss"] * c2) / (c1 + c2)

    # copy every sub-dictionary: a shallow copy of `total_ld` would share them with the
    # caller, and the in-place updates below would then modify the caller's data
    combined_ld = {loss_fn_name: dict(subld) for loss_fn_name, subld in total_ld.items()}
    for loss_fn_name, batch_subld in batch_ld.items():
        if loss_fn_name not in combined_ld:
            # store a copy so later accumulation does not write into the batch dictionary
            combined_ld[loss_fn_name] = dict(batch_subld)
        else:
            combined_subld = combined_ld[loss_fn_name]
            assert combined_subld["scaling_factor"] == batch_subld["scaling_factor"]
            combined_subld["loss"] = _weighted_loss_sum(combined_subld, batch_subld)
            # rebind rather than `+=` so a count object shared with the input is never mutated
            combined_subld["count"] = combined_subld["count"] + batch_subld["count"]
    return combined_ld


def _freeze_all_but_output(net):
    """ Disable gradients for every child module of :code:`net` except :code:`net.top.W_out`.

    Args
    ----
    net : torch.nn.Module
        The unwrapped network. Children of :code:`net` other than :code:`top` are frozen entirely;
        among the children of :code:`top`, only the one named :code:`W_out` keeps gradients enabled.
        Parameters registered directly on :code:`net` or on :code:`top` are not touched.
    """
    for name, module in net.named_children():
        if name == "top":
            for sub_name, sub_module in module.named_children():
                # Module.requires_grad_ updates the parameters themselves; assigning
                # `module.requires_grad = ...` would only set an unused attribute on the module
                sub_module.requires_grad_(sub_name == "W_out")
        else:
            module.requires_grad_(False)


def run_epoch(model,
              dataloader,
              loss_fn,
              optimizer=None,
              scheduler=None,
              grad=False,
              test=False,
              dev="cuda:0",
              isDataParallel=False,
              finetune=False):
    """ Run :code:`model` on one epoch of :code:`dataloader`

    Args
    ----
    model : terminator.model.TERMinator.TERMinator
        An instance of TERMinator
    dataloader : torch.utils.data.DataLoader
        A torch DataLoader that wraps either terminator.data.data.TERMDataLoader
        or terminator.data.data.TERMLazyDataLoader
    loss_fn : function
        Loss function with signature :code:`loss_fn(etab, E_idx, data)` that returns a loss dictionary, where
        - :code:`etab, E_idx` is the outputted Potts Model
        - :code:`data` is the input data produced by :code:`dataloader`
        - the returned dictionary maps :code:`loss_fn_name` to a dictionary with the keys
          :code:`loss`, :code:`count` and :code:`scaling_factor`
    optimizer : torch optimizer or None
        An optimizer for :code:`model`. Used when :code:`grad=True, test=False`
    scheduler : torch scheduler or None
        The associated scheduler for the given optimizer. When given, it is stepped once at the
        end of the epoch with the epoch loss as its argument.
    grad : bool
        Whether or not to compute gradients. :code:`True` to train the model,
        :code:`False` to use model in evaluation mode.
    test : bool
        Whether or not to save the outputs of the model. Requires :code:`grad=False`.
    dev : str, default="cuda:0"
        What device to compute on
    isDataParallel : bool, default=False
        Whether the network is reached through :code:`model.module` rather than :code:`model`
        itself. Only consulted when :code:`finetune=True`.
    finetune : bool, default=False
        When :code:`True` (and :code:`grad=True`), freeze all child modules of the network except
        :code:`top.W_out` before training.

    Returns
    -------
    epoch_loss : float
        Loss on the run epoch
    running_loss_dict : dict
        Loss breakdown into component sublosses and scaling factors of epoch_loss
    dump : list of dicts or None
        Outputs of the model when :code:`test=True`, otherwise :code:`None`
    """
    # arg checking
    if test:
        assert not grad, "grad should not be on for test set"
    if grad:
        assert optimizer is not None, "require an optimizer if using grads"
    if scheduler is not None:
        assert optimizer is not None, "using a scheduler requires an optimizer"

    running_loss_dict = {}

    # set grads properly
    if grad:
        model.train()
        torch.set_grad_enabled(True)
        if finetune:
            # freeze all but the last output layer
            _freeze_all_but_output(model.module if isDataParallel else model)
    else:
        model.eval()
        torch.set_grad_enabled(False)

    # record inference outputs if necessary
    dump = [] if test else None

    progress = tqdm(total=len(dataloader))
    try:
        for data in dataloader:
            # a small hack for DataParallel to know which device got which proteins
            data['scatter_idx'] = torch.arange(len(data['seq_lens']))
            _to_dev(data, dev)
            max_seq_len = max(data['seq_lens'].tolist())
            ids = data['ids']
            try:
                etab, E_idx = model(data, max_seq_len)
                batch_loss_dict = loss_fn(etab, E_idx, data)
                loss = _compute_loss(batch_loss_dict)
            except Exception:
                # report which inputs triggered the failure, then re-raise unchanged
                print(ids)
                raise

            if grad:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            running_loss_dict = _sum_loss_dicts(running_loss_dict, _ld_item_values(batch_loss_dict))

            if test:
                n_batch, seq_len, n_neighbors = etab.shape[:3]
                dump.append({
                    'loss': loss,
                    'out': etab.view(n_batch, seq_len, n_neighbors, 20, 20).cpu().numpy(),
                    'idx': E_idx.cpu().numpy(),
                    'ids': ids
                })

            # fraction of entries that are False in src_key_mask / nonzero in x_mask
            term_mask_eff = (~data['src_key_mask']).sum().item() / data['src_key_mask'].numel()
            res_mask_eff = data['x_mask'].sum().item() / data['x_mask'].numel()

            # compactify what's printed to stdout
            loss_breakdown = {
                loss_fn_name: round(subloss_dict["loss"], 2)
                for loss_fn_name, subloss_dict in running_loss_dict.items()
            }
            avg_loss = round(_compute_loss(running_loss_dict), 2)

            progress.update(1)
            progress.refresh()
            # format the fractions directly: pasting an integer percentage after "0." would
            # show 5% as 0.5 and 100% as 0.100
            progress.set_description_str(
                f'avg loss {avg_loss} {loss_breakdown} | eff {term_mask_eff:.2f}, {res_mask_eff:.2f}')
    finally:
        # always release the progress bar and restore gradient tracking, even on failure
        progress.close()
        torch.set_grad_enabled(True)  # just to be safe

    epoch_loss = _compute_loss(running_loss_dict)
    if scheduler is not None:
        scheduler.step(epoch_loss)

    return epoch_loss, running_loss_dict, dump