terminator.utils.model.loop_utils.run_epoch
- terminator.utils.model.loop_utils.run_epoch(model, dataloader, loss_fn, optimizer=None, scheduler=None, grad=False, test=False, dev='cuda:0', isDataParallel=False, finetune=False)
Run model on one epoch of dataloader.

- Parameters:
model (terminator.model.TERMinator.TERMinator) – An instance of TERMinator.
dataloader (torch.utils.data.DataLoader) – A torch DataLoader that wraps either terminator.data.data.TERMDataLoader or terminator.data.data.TERMLazyDataLoader.
loss_fn (function) – Loss function with signature loss_fn(etab, E_idx, data) that returns loss, batch_count, where:
  - etab, E_idx are the Potts model output by model
  - data is the input data produced by dataloader
  - hparams is the model hyperparameters
  - loss is the loss value
  - batch_count is the averaging factor
optimizer (torch optimizer or None) – An optimizer for model. Used when grad=True and test=False.
scheduler (torch scheduler or None) – The associated scheduler for the given optimizer.
grad (bool) – Whether or not to compute gradients. True to train the model, False to use the model in evaluation mode.
test (bool) – Whether or not to save the outputs of the model. Requires grad=False.
dev (str, default="cuda:0") – The device to compute on.
- Returns:
epoch_loss (float) – Loss over the epoch that was run.
running_loss_dict (dict) – Breakdown of epoch_loss into its component sublosses and their scaling factors.
dump (list of dicts, conditionally present) – Outputs of the model. Present only when test=True.
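The contract between run_epoch, loss_fn, and the grad/test flags can be illustrated with a minimal, framework-free sketch. This is not TERMinator's implementation: run_epoch_sketch, toy_model, and toy_loss are hypothetical names, plain floats stand in for tensors, and weighting each batch's loss by batch_count to form the epoch average is an assumption about how the averaging factor is used. Gradient computation, optimizer/scheduler steps, and device placement are omitted.

```python
from typing import Any, Callable, Iterable


def run_epoch_sketch(
    model: Callable[[Any], tuple],
    batches: Iterable[Any],
    loss_fn: Callable[[Any, Any, Any], tuple],
    grad: bool = False,
    test: bool = False,
):
    """Hypothetical sketch of one epoch; not TERMinator's actual code."""
    # Documented constraint: saving outputs (test=True) requires grad=False.
    if test and grad:
        raise ValueError("test=True requires grad=False")

    total_loss = 0.0
    total_count = 0
    dump = []
    for data in batches:
        # The model produces a Potts model (etab, E_idx) for each batch.
        etab, E_idx = model(data)
        loss, batch_count = loss_fn(etab, E_idx, data)
        # Assumption: batch_count weights each batch's loss in the epoch average.
        total_loss += loss * batch_count
        total_count += batch_count
        # A real training loop would backpropagate and step the optimizer here when grad=True.
        if test:
            dump.append({"etab": etab, "E_idx": E_idx})

    epoch_loss = total_loss / total_count if total_count else 0.0
    running_loss_dict = {"loss": epoch_loss}  # hypothetical single-entry breakdown
    if test:
        return epoch_loss, running_loss_dict, dump
    return epoch_loss, running_loss_dict


# Usage with toy stand-ins (hypothetical): the "model" passes a value through as etab.
def toy_model(data):
    return data["value"], None


def toy_loss(etab, E_idx, data):
    return etab, data["n"]


batches = [{"value": 1.0, "n": 1}, {"value": 4.0, "n": 3}]
epoch_loss, losses = run_epoch_sketch(toy_model, batches, toy_loss)
print(epoch_loss)  # (1.0 * 1 + 4.0 * 3) / 4 = 3.25
```

The weighted average keeps batches with more averaging units (for example, more residues) from being undercounted relative to small batches, which is one plausible reason for loss_fn to return batch_count alongside loss.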