terminator.utils.model.loop_utils.run_epoch

terminator.utils.model.loop_utils.run_epoch(model, dataloader, loss_fn, optimizer=None, scheduler=None, grad=False, test=False, dev='cuda:0', isDataParallel=False, finetune=False)

Run the model over one epoch of dataloader

Parameters:
  • model (terminator.model.TERMinator.TERMinator) – An instance of TERMinator

  • dataloader (torch.utils.data.DataLoader) – A torch DataLoader that wraps either terminator.data.data.TERMDataLoader or terminator.data.data.TERMLazyDataLoader

  • loss_fn (function) – Loss function with signature loss_fn(etab, E_idx, data) that returns loss, batch_count, where
      - etab, E_idx are the Potts model output by model
      - data is the input data produced by dataloader
      - hparams are the model hyperparameters
      - loss is the loss value
      - batch_count is the averaging factor

  • optimizer (torch optimizer or None) – An optimizer for model. Used when grad=True, test=False

  • scheduler (torch scheduler or None) – The associated scheduler for the given optimizer

  • grad (bool) – Whether or not to compute gradients. True to train the model, False to use model in evaluation mode.

  • test (bool) – Whether or not to save the outputs of the model. Requires grad=False.

  • dev (str, default="cuda:0") – What device to compute on
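
The loss_fn contract described above can be made concrete with a minimal sketch: a function with the documented signature that returns loss, batch_count, plus one plausible way an epoch loop could combine the per-batch results. The names toy_loss_fn and average_over_epoch, the "target" key in data, the mean-squared-error computation, and the count-weighted averaging are all assumptions made for illustration. Plain Python lists stand in for tensors, and nothing here is taken from the TERMinator source.

```python
def toy_loss_fn(etab, E_idx, data):
    """Hypothetical loss with the documented signature loss_fn(etab, E_idx, data).

    etab and E_idx stand in for the Potts model output; data stands in for a
    batch produced by the dataloader. The "target" key is an assumption.
    """
    targets = data["target"]
    # Mean squared error between the stand-in etab values and the targets.
    squared_errors = [(e - t) ** 2 for e, t in zip(etab, targets)]
    batch_count = len(squared_errors)
    loss = sum(squared_errors) / batch_count
    # E_idx is unused in this toy example; it is kept to match the signature.
    return loss, batch_count


def average_over_epoch(batch_results):
    """Combine (loss, batch_count) pairs into a count-weighted mean.

    Treating batch_count as a weight is an assumption about how an
    "averaging factor" might be used, not documented behavior.
    """
    total = sum(loss * count for loss, count in batch_results)
    n = sum(count for _, count in batch_results)
    return total / n


# Two stand-in batches: (etab, E_idx, data).
batches = [
    ([1.0, 2.0], None, {"target": [1.0, 4.0]}),  # squared errors 0 and 4
    ([0.0], None, {"target": [3.0]}),            # squared error 9
]
results = [toy_loss_fn(etab, E_idx, data) for etab, E_idx, data in batches]
epoch_loss = average_over_epoch(results)
```

Under this weighting, a batch contributes to the epoch loss in proportion to its batch_count, so a small final batch does not skew the average.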

Returns:

  • epoch_loss (float) – Loss on the run epoch

  • running_loss_dict (dict) – Breakdown of epoch_loss into its component sublosses and their scaling factors

  • dump (list of dicts, conditionally present) – Outputs of the model. Present when test=True
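
Because dump is present only when test=True, calling code has to handle two return shapes. The following is a minimal sketch, assuming the values come back as a tuple in the documented order. The tuple form, the helper name unpack_epoch_result, and the placeholder dictionary contents are assumptions and are not confirmed by this page.

```python
def unpack_epoch_result(result, test=False):
    """Split a run_epoch-style return value into named parts.

    Assumes `result` is a tuple ordered as documented:
    (epoch_loss, running_loss_dict), or
    (epoch_loss, running_loss_dict, dump) when test=True.
    """
    if test:
        epoch_loss, running_loss_dict, dump = result
    else:
        epoch_loss, running_loss_dict = result
        dump = None  # no model outputs are saved outside test mode
    return epoch_loss, running_loss_dict, dump


# Stand-in values shaped like the documented return; not real model output.
val_result = (0.5, {"hypothetical_subloss": 0.5})
val_loss, val_breakdown, val_dump = unpack_epoch_result(val_result, test=False)

test_result = (0.7, {"hypothetical_subloss": 0.7}, [{"id": "example"}])
test_loss, test_breakdown, test_dump = unpack_epoch_result(test_result, test=True)
```

Returning None for dump in the non-test case gives callers a single three-value shape to work with regardless of mode.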