Source code for engine.training_engine

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import copy
import gc
import shutil
import time
import traceback
from itertools import product
from typing import Dict, Union

import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F

from common import DEFAULT_EPOCHS, DEFAULT_ITERATIONS, DEFAULT_LOG_FREQ, if_test_env
from data.transforms.image_torch import apply_mixing_transforms
from engine.utils import autocast_fn, get_batch_size, get_log_writers, log_metrics
from loss_landscape import landscape_utils as ll_utils
from metrics.stats import Statistics
from options.parse_args import parse_validation_metric_names
from utils import logger
from utils.checkpoint_utils import (
    copy_weights,
    save_checkpoint,
    save_interval_checkpoint,
)
from utils.common_utils import move_to_device, unwrap_model_fn
from utils.ddp_utils import dist_barrier, is_master
from utils.tensor_utils import reduce_tensor_sum


class Trainer(object):
    """
    This class defines the training and validation code for training models with CVNets.
    """

    def __init__(
        self,
        opts,
        model,
        validation_loader,
        training_loader,
        criterion,
        optimizer,
        scheduler,
        gradient_scaler,
        start_epoch: int = 0,
        start_iteration: int = 0,
        best_metric: float = 0.0,
        model_ema=None,
        *args,
        **kwargs,
    ) -> None:
        super(Trainer, self).__init__()

        self.opts = opts

        self.model = model
        self.model_ema = model_ema
        self.criteria = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.gradient_scaler = gradient_scaler

        self.val_loader = validation_loader
        self.train_loader = training_loader

        self.device = getattr(opts, "dev.device", torch.device("cpu"))

        self.start_epoch = start_epoch
        self.best_metric = best_metric
        self.train_iterations = start_iteration
        self.is_master_node = is_master(opts)
        self.max_iterations_reached = False
        self.max_iterations = getattr(
            self.opts, "scheduler.max_iterations", DEFAULT_ITERATIONS
        )
        self.use_distributed = getattr(self.opts, "ddp.use_distributed", False)
        self.log_freq = getattr(self.opts, "common.log_freq", DEFAULT_LOG_FREQ)
        self.accum_freq = getattr(self.opts, "common.accum_freq", 1)
        self.accum_after_epoch = getattr(self.opts, "common.accum_after_epoch", 0)

        self.mixed_precision_training = getattr(opts, "common.mixed_precision", False)
        self.mixed_precision_dtype = getattr(
            opts, "common.mixed_precision_dtype", "float16"
        )

        self.train_metric_names = getattr(opts, "stats.train", ["loss"])
        if isinstance(self.train_metric_names, str):
            self.train_metric_names = [self.train_metric_names]

        assert isinstance(
            self.train_metric_names, list
        ), "Type of metric names should be list. Got: {}".format(
            type(self.train_metric_names)
        )

        if "loss" not in self.train_metric_names:
            self.train_metric_names.append("loss")

        (
            self.val_metric_names,
            self.ckpt_metric,
            self.ckpt_submetric,
        ) = parse_validation_metric_names(self.opts)

        self.save_all_checkpoints = getattr(
            self.opts, "common.save_all_checkpoints", False
        )

        self.save_location = getattr(opts, "common.exp_loc", "results/run_1")

        self.log_writers = get_log_writers(self.opts, save_location=self.save_location)

        self.adjust_norm_mom = None
        if getattr(opts, "model.normalization.adjust_bn_momentum.enable", False):
            from cvnets.layers import AdjustBatchNormMomentum

            self.adjust_norm_mom = AdjustBatchNormMomentum(opts=opts)
            if self.is_master_node:
                logger.log(
                    "Batch normalization momentum will be annealed during training."
                )
                print(self.adjust_norm_mom)

        # sample-efficient training
        self.cache_dict = None
        self.sample_efficient_training = getattr(
            opts, "dataset.sample_efficient_training.enable", False
        )
        self.sample_confidence = getattr(
            opts, "dataset.sample_efficient_training.sample_confidence", 0.5
        )
        self.find_easy_samples_every_k_epoch = getattr(
            opts,
            "dataset.sample_efficient_training.find_easy_samples_every_k_epochs",
            5,
        )
        self.min_sample_frequency = getattr(
            opts, "dataset.sample_efficient_training.min_sample_frequency", 5
        )
        if self.sample_efficient_training:
            self.train_loader_set = copy.deepcopy(self.train_loader)
            self.sample_ids_orig = self.train_loader_set.get_sample_indices()
            n_samples = len(self.sample_ids_orig)
            self.running_sum_tensor = torch.zeros(
                (n_samples,), device=self.device, dtype=torch.int
            )
            self.running_sum_tensor.requires_grad = False
            if self.is_master_node:
                logger.log("Configuring for sample efficient training")

        # Recent versions of PyTorch support setting gradients to None for better performance.
        # To be explored in the future.
        # self.optimizer.zero_grad(set_to_none=True)
        self.set_grad_to_none = False

        save_interval_freq = getattr(opts, "common.save_interval_freq", 0)
        # save interval checkpoints every `save_interval_freq` updates on the master node
        self.save_interval = self.is_master_node and save_interval_freq > 0
        self.save_interval_freq = save_interval_freq

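    # Descriptive note: `opts` is expected to be a flat namespace whose attribute names contain
    # dots (e.g. "common.log_freq"), which is why configuration values are read with
    # getattr(opts, "x.y.z", default). A minimal sketch of such a namespace, assuming nothing
    # beyond standard argparse:
    #
    #   from argparse import Namespace
    #   opts = Namespace(**{"common.log_freq": 100, "common.accum_freq": 2})
    #   getattr(opts, "common.accum_freq", 1)  # -> 2
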
    def compute_grad_norm(self):
        parameters = [p for p in self.model.parameters() if p.grad is not None]
        if len(parameters) == 0:
            return None

        norm_type = 2.0  # L2 norm
        inv_scale = 1.0 / self.gradient_scaler.get_scale()
        total_norm = torch.norm(
            torch.stack(
                [
                    torch.norm(p.grad.detach() * inv_scale, norm_type).to(self.device)
                    for p in parameters
                ]
            ),
            norm_type,
        )
        if total_norm.isnan() or total_norm.isinf():
            return None
        return total_norm

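    # Illustrative sketch (not part of this class): on an unscaled model, the computation above
    # reduces to the L2 norm of the per-parameter gradient L2 norms; `inv_scale` undoes the
    # multiplication applied by the gradient scaler so the logged value reflects the true
    # gradient magnitude. Assuming a toy model:
    #
    #   import torch
    #   model = torch.nn.Linear(4, 2)
    #   model(torch.randn(8, 4)).sum().backward()
    #   total_norm = torch.norm(
    #       torch.stack([p.grad.detach().norm(2.0) for p in model.parameters()]), 2.0
    #   )
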
    def _zero_grad(self):
        if self.set_grad_to_none:
            self.optimizer.zero_grad(set_to_none=True)
        else:
            self.optimizer.zero_grad()

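    # Descriptive note: `zero_grad(set_to_none=True)` releases the gradient tensors instead of
    # zero-filling them, which avoids a memset and reduces memory traffic; `_zero_grad` keeps
    # this behaviour behind the `set_grad_to_none` flag configured in __init__.
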
    def train_epoch(self, epoch):
        time.sleep(
            if_test_env(0.5, otherwise=2)
        )  # To prevent possible deadlock during epoch transition

        if self.is_master_node:
            logger.double_dash_line()
            logger.debug(
                "Training epoch {} with {} samples".format(
                    epoch, self.train_loader.samples_in_dataset()
                )
            )

        train_stats = Statistics(
            opts=self.opts,
            metric_names=self.train_metric_names,
            is_master_node=self.is_master_node,
            is_distributed=self.use_distributed,
            log_writers=self.log_writers,
        )

        self.model.train()
        # criteria is also an nn.Module, and some loss functions may need access to the
        # training property. To enable that, we set criteria to train/eval mode as well.
        self.criteria.train()

        accum_freq = self.accum_freq if epoch >= self.accum_after_epoch else 1
        max_norm = getattr(self.opts, "common.grad_clip", None)

        # set the gradients to zero or None
        self._zero_grad()

        epoch_start_time = time.time()
        batch_load_start = time.time()
        grad_norm = torch.tensor([0.0], dtype=torch.float, device=self.device)
        for batch_id, batch in enumerate(self.train_loader):
            if self.train_iterations > self.max_iterations:
                self.max_iterations_reached = True
                return -1, -1

            # move to device
            batch = move_to_device(opts=self.opts, x=batch, device=self.device)
            # apply mix-up transforms, if any
            batch = apply_mixing_transforms(opts=self.opts, data=batch)

            batch_load_toc = time.time() - batch_load_start

            samples, targets = batch["samples"], batch["targets"]
            batch_size = get_batch_size(samples)

            # update the learning rate
            self.optimizer = self.scheduler.update_lr(
                optimizer=self.optimizer, epoch=epoch, curr_iter=self.train_iterations
            )

            # adjust bn momentum
            if self.adjust_norm_mom is not None:
                self.adjust_norm_mom.adjust_momentum(
                    model=self.model, epoch=epoch, iteration=self.train_iterations
                )

            with autocast_fn(
                enabled=self.mixed_precision_training,
                amp_precision=self.mixed_precision_dtype,
            ):
                # prediction
                pred_label = self.model(samples)
                # compute loss
                loss_dict_or_tensor: Union[Dict, Tensor] = self.criteria(
                    input_sample=samples,
                    prediction=pred_label,
                    target=targets,
                    epoch=epoch,
                    iterations=self.train_iterations,
                )

                if isinstance(loss_dict_or_tensor, Dict):
                    if "total_loss" not in loss_dict_or_tensor.keys():
                        logger.error(
                            "total_loss key is required for loss functions that return outputs as dictionary."
                        )
                    loss = loss_dict_or_tensor["total_loss"]
                elif isinstance(loss_dict_or_tensor, Tensor):
                    loss = loss_dict_or_tensor
                else:
                    logger.error("Loss value should be an instance of Tensor or Dict")

                if isinstance(loss, torch.Tensor) and torch.isnan(loss):
                    logger.error("Nan encountered in the loss.")

            # perform the backward pass with gradient accumulation [optional]
            self.gradient_scaler.scale(loss).backward()

            if (batch_id + 1) % accum_freq == 0:
                if max_norm is not None:
                    # for gradient clipping, unscale the gradients and then clip them
                    self.gradient_scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), max_norm=max_norm
                    )

                if "grad_norm" in self.train_metric_names:
                    # compute grad_norm for logging purposes.
                    # we can't use the output of clip_grad_norm_ because it returns the total norm before clipping
                    grad_norm = self.compute_grad_norm()

                # optimizer step
                self.gradient_scaler.step(optimizer=self.optimizer)
                # update the scale for the next batch
                self.gradient_scaler.update()
                # set the gradients to zero or None
                self._zero_grad()
                self.train_iterations += 1

                if self.model_ema is not None:
                    self.model_ema.update_parameters(self.model)

            train_stats.update(
                pred_label=pred_label,
                target_label=targets,
                extras={"loss": loss_dict_or_tensor, "grad_norm": grad_norm},
                batch_time=batch_load_toc,
                batch_size=batch_size,
            )

            # save the checkpoint every N updates
            if (
                self.save_interval
                and (self.train_iterations % self.save_interval_freq) == 0
            ):
                save_interval_checkpoint(
                    iterations=self.train_iterations,
                    epoch=epoch,
                    model=self.model,
                    optimizer=self.optimizer,
                    best_metric=loss.item(),
                    save_dir=self.save_location,
                    gradient_scaler=self.gradient_scaler,
                    model_ema=self.model_ema,
                )
                logger.info(
                    "Checkpoints saved after {} updates at: {}".format(
                        self.train_iterations, self.save_location
                    ),
                    print_line=True,
                )

            if batch_id % self.log_freq == 0 and self.is_master_node:
                lr = self.scheduler.retrieve_lr(self.optimizer)
                train_stats.iter_summary(
                    epoch=epoch,
                    n_processed_samples=self.train_iterations,
                    total_samples=self.max_iterations,
                    learning_rate=lr,
                    elapsed_time=epoch_start_time,
                )

            batch_load_start = time.time()

        avg_loss = train_stats.avg_statistics(
            metric_name="loss", sub_metric_name="total_loss"
        )
        train_stats.epoch_summary(epoch=epoch, stage="training")
        avg_ckpt_metric = train_stats.avg_statistics(
            metric_name=self.ckpt_metric, sub_metric_name=self.ckpt_submetric
        )

        gc.collect()
        return avg_loss, avg_ckpt_metric

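    # Illustrative sketch (not part of this class) of the AMP + gradient-accumulation pattern
    # used in train_epoch above, assuming a plain optimizer, model, criterion, and loader:
    #
    #   scaler = torch.cuda.amp.GradScaler()
    #   accum_freq = 2
    #   optimizer.zero_grad()
    #   for batch_id, (x, y) in enumerate(loader):
    #       with torch.autocast(device_type="cuda", dtype=torch.float16):
    #           loss = criterion(model(x), y) / accum_freq  # divide if averaging over steps
    #       scaler.scale(loss).backward()                    # gradients accumulate across steps
    #       if (batch_id + 1) % accum_freq == 0:
    #           scaler.unscale_(optimizer)                   # needed before clip_grad_norm_
    #           torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    #           scaler.step(optimizer)
    #           scaler.update()
    #           optimizer.zero_grad()
    #
    # Note that train_epoch does not divide the loss by accum_freq; whether to average or sum
    # the accumulated gradients is a design choice.
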
    def val_epoch(self, epoch, model, extra_str=""):
        if self.val_loader is None:
            return 0.0, 0.0

        time.sleep(
            if_test_env(0.5, otherwise=2)
        )  # To prevent possible deadlock during epoch transition

        validation_stats = Statistics(
            opts=self.opts,
            metric_names=self.val_metric_names,
            is_master_node=self.is_master_node,
            is_distributed=self.use_distributed,
            log_writers=self.log_writers,
        )

        model.eval()
        # criteria is also an nn.Module, and some loss functions may need access to the
        # training property. To enable that, we set criteria to train/eval mode as well.
        self.criteria.eval()

        if model.training:
            if self.is_master_node:
                logger.warning(
                    "Model is in training mode. Switching to evaluation mode"
                )
            model.eval()

        if self.criteria.training:
            self.criteria.eval()

        with torch.no_grad():
            epoch_start_time = time.time()
            total_samples = len(self.val_loader)
            processed_samples = 0
            lr = self.scheduler.retrieve_lr(self.optimizer)

            for batch_id, batch in enumerate(self.val_loader):
                batch = move_to_device(opts=self.opts, x=batch, device=self.device)

                samples, targets = batch["samples"], batch["targets"]
                batch_size = get_batch_size(samples)

                with autocast_fn(
                    enabled=self.mixed_precision_training,
                    amp_precision=self.mixed_precision_dtype,
                ):
                    # prediction
                    pred_label = model(samples)
                    # compute loss
                    loss_dict_or_tensor = self.criteria(
                        input_sample=samples,
                        prediction=pred_label,
                        target=targets,
                    )

                processed_samples += batch_size

                validation_stats.update(
                    pred_label=pred_label,
                    target_label=targets,
                    extras={"loss": loss_dict_or_tensor},
                    batch_time=0.0,
                    batch_size=batch_size,
                )
                if batch_id % self.log_freq == 0 and self.is_master_node:
                    validation_stats.iter_summary(
                        epoch=epoch,
                        n_processed_samples=processed_samples,
                        total_samples=total_samples,
                        elapsed_time=epoch_start_time,
                        learning_rate=lr,
                    )

        validation_stats.epoch_summary(epoch=epoch, stage="validation" + extra_str)
        avg_loss = validation_stats.avg_statistics(
            metric_name="loss", sub_metric_name="total_loss"
        )
        avg_ckpt_metric = validation_stats.avg_statistics(
            metric_name=self.ckpt_metric, sub_metric_name=self.ckpt_submetric
        )

        if avg_ckpt_metric is None:
            avg_ckpt_metric = avg_loss

        gc.collect()
        return avg_loss, avg_ckpt_metric

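    # Descriptive note: validation runs under both torch.no_grad() and autocast_fn so that no
    # gradient buffers are allocated and the forward pass uses the same mixed-precision dtype
    # as training; batch_time is reported as 0.0 because per-batch data-loading time is not
    # tracked during validation.
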
    def find_easy_samples(self, epoch, model, *args, **kwargs):
        """
        This function identifies easy samples in the training set and removes them from training.

        .. note::
            Currently, this is implemented separately to avoid breaking the training and
            validation pipeline. In the future, this will be combined with the main training
            loop to reduce overhead.
        """
        time.sleep(
            if_test_env(0.5, otherwise=2)
        )  # To prevent possible deadlock during epoch transition

        model.eval()
        if model.training and self.is_master_node:
            logger.warning("Model is in training mode. Switching to evaluation mode")
            model.eval()

        if self.is_master_node:
            logger.log("Trying to find easy samples in epoch {}".format(epoch))

        with torch.no_grad():
            easy_sample_ids_tensor = torch.zeros_like(self.running_sum_tensor)

            for batch_id, batch in enumerate(self.train_loader_set):
                batch = move_to_device(opts=self.opts, x=batch, device=self.device)
                samples, targets = batch["samples"], batch["targets"]

                sample_ids = None
                if "sample_id" in batch:
                    sample_ids = batch["sample_id"]
                else:
                    self.sample_efficient_training = False
                    if self.is_master_node:
                        logger.log(
                            "Sample Ids are required in a batch for sample efficient training. "
                            "sample_id key not found in batch. Disabling sample efficient training."
                        )
                    break

                if sample_ids is None:
                    logger.log("Sample Ids can't be none")
                    break

                with autocast_fn(
                    enabled=self.mixed_precision_training,
                    amp_precision=self.mixed_precision_dtype,
                ):
                    # prediction
                    pred_label = model(samples)
                    pred_label = F.softmax(pred_label, dim=-1)
                    pred_conf, pred_indices = torch.max(pred_label, dim=-1)

                easy_samples = torch.logical_and(
                    # condition 1: predicted label == target label
                    pred_indices.eq(targets),
                    # condition 2: prediction confidence >= desired confidence
                    pred_conf >= self.sample_confidence,
                )

                if easy_samples.numel() > 0:
                    easy_sample_ids = sample_ids[easy_samples]
                    # find easy samples as per conditions 1 and 2 and set their values to 1
                    easy_sample_ids_tensor[easy_sample_ids] = 1

            # synchronize tensors
            if self.use_distributed:
                # sync across all GPUs.
                easy_sample_ids_tensor = reduce_tensor_sum(easy_sample_ids_tensor)

            # some samples that were classified as easy earlier may be classified as hard now.
            easy_sample_ids_tensor[easy_sample_ids_tensor == 0] = -1

            if self.is_master_node:
                logger.debug(
                    "Number of easy samples found during epoch {}: {}".format(
                        epoch,
                        easy_sample_ids_tensor[easy_sample_ids_tensor > 0].sum().item(),
                    )
                )

            self.running_sum_tensor = torch.clip(
                self.running_sum_tensor + easy_sample_ids_tensor,
                min=0,
                max=self.min_sample_frequency,
            )

            if self.running_sum_tensor.sum() > 0:
                skip_sample_ids = (
                    self.running_sum_tensor >= self.min_sample_frequency
                ).nonzero(as_tuple=True)[0]

                if skip_sample_ids.numel() > 0:
                    skip_samples = skip_sample_ids.cpu().numpy().tolist()
                    skip_sample_set = set(skip_samples)
                    new_sample_ids = [
                        s_id
                        for s_id in self.sample_ids_orig
                        if s_id not in skip_sample_set
                    ]

                    # update the train loader indices
                    self.train_loader.update_indices(new_sample_ids)

                    if self.is_master_node:
                        logger.debug(
                            "Number of samples to skip after epoch {}: {}".format(
                                epoch, len(skip_samples)
                            )
                        )

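    # Illustrative sketch (not part of this class) of the "easy sample" test above, assuming a
    # toy batch of logits and integer targets:
    #
    #   import torch
    #   from torch.nn import functional as F
    #   logits = torch.tensor([[2.5, 0.1, -1.0], [0.2, 0.3, 0.1]])
    #   targets = torch.tensor([0, 2])
    #   conf, pred = torch.max(F.softmax(logits, dim=-1), dim=-1)
    #   easy = torch.logical_and(pred.eq(targets), conf >= 0.5)
    #   # easy -> tensor([True, False]): sample 0 is confidently correct, sample 1 is not
    #
    # A sample is skipped only once its running count in `running_sum_tensor` (incremented when
    # it is easy, decremented otherwise, floored at 0) reaches `min_sample_frequency`.
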
    def run(self, train_sampler=None):
        if train_sampler is None and self.is_master_node:
            logger.error("Train sampler cannot be None")

        copy_at_epoch = getattr(self.opts, "ema.copy_at_epoch", -1)
        train_start_time = time.time()

        cfg_file = getattr(self.opts, "common.config_file", None)
        if cfg_file is not None and self.is_master_node:
            dst_cfg_file = "{}/config.yaml".format(self.save_location)
            shutil.copy(src=cfg_file, dst=dst_cfg_file)
            logger.info(
                "Configuration file is stored here: {}".format(
                    logger.color_text(dst_cfg_file)
                )
            )

        keep_k_best_ckpts = getattr(self.opts, "common.k_best_checkpoints", 5)
        ema_best_metric = self.best_metric
        is_ema_best = False

        try:
            max_epochs = getattr(self.opts, "scheduler.max_epochs", DEFAULT_EPOCHS)
            max_checkpoint_metric = getattr(
                self.opts, "stats.checkpoint_metric_max", False
            )

            for epoch in range(self.start_epoch, max_epochs):
                # Note that we are using our own implementations of data samplers,
                # and we have defined this function for both distributed and non-distributed cases
                train_sampler.set_epoch(epoch)
                train_sampler.update_scales(
                    epoch=epoch, is_master_node=self.is_master_node
                )

                train_loss, train_ckpt_metric = self.train_epoch(epoch)

                val_loss, val_ckpt_metric = self.val_epoch(
                    epoch=epoch, model=self.model
                )

                if epoch == copy_at_epoch and self.model_ema is not None:
                    if self.is_master_node:
                        logger.log("Copying EMA weights")
                    # copy model_src weights to model_tgt
                    self.model = copy_weights(
                        model_tgt=self.model, model_src=self.model_ema
                    )
                    if self.is_master_node:
                        logger.log("EMA weights copied")
                        logger.log("Running validation after copying EMA model weights")
                    self.val_epoch(epoch=epoch, model=self.model)

                if max_checkpoint_metric:
                    is_best = val_ckpt_metric >= self.best_metric
                    self.best_metric = max(val_ckpt_metric, self.best_metric)
                else:
                    is_best = val_ckpt_metric <= self.best_metric
                    self.best_metric = min(val_ckpt_metric, self.best_metric)

                val_ema_loss = None
                val_ema_ckpt_metric = None
                if self.model_ema is not None:
                    val_ema_loss, val_ema_ckpt_metric = self.val_epoch(
                        epoch=epoch, model=self.model_ema.ema_model, extra_str=" (EMA)"
                    )
                    if max_checkpoint_metric:
                        is_ema_best = val_ema_ckpt_metric >= ema_best_metric
                        ema_best_metric = max(val_ema_ckpt_metric, ema_best_metric)
                    else:
                        is_ema_best = val_ema_ckpt_metric <= ema_best_metric
                        ema_best_metric = min(val_ema_ckpt_metric, ema_best_metric)

                # sample-efficient training
                if (
                    self.sample_efficient_training
                    and (epoch + 1) % self.find_easy_samples_every_k_epoch == 0
                ):
                    self.find_easy_samples(
                        epoch=epoch,
                        model=self.model_ema.ema_model
                        if self.model_ema is not None
                        else self.model,
                    )

                gc.collect()

                if self.is_master_node:
                    save_checkpoint(
                        iterations=self.train_iterations,
                        epoch=epoch,
                        model=self.model,
                        optimizer=self.optimizer,
                        best_metric=self.best_metric,
                        is_best=is_best,
                        save_dir=self.save_location,
                        model_ema=self.model_ema,
                        is_ema_best=is_ema_best,
                        ema_best_metric=ema_best_metric,
                        gradient_scaler=self.gradient_scaler,
                        max_ckpt_metric=max_checkpoint_metric,
                        k_best_checkpoints=keep_k_best_ckpts,
                        save_all_checkpoints=self.save_all_checkpoints,
                    )
                    logger.info(
                        "Checkpoints saved at: {}".format(self.save_location),
                        print_line=True,
                    )

                if self.is_master_node:
                    lr_list = self.scheduler.retrieve_lr(self.optimizer)
                    for log_writer in self.log_writers:
                        log_metrics(
                            lrs=lr_list,
                            log_writer=log_writer,
                            train_loss=train_loss,
                            val_loss=val_loss,
                            epoch=epoch,
                            best_metric=self.best_metric,
                            val_ema_loss=val_ema_loss,
                            ckpt_metric_name=self.ckpt_metric,
                            train_ckpt_metric=train_ckpt_metric,
                            val_ckpt_metric=val_ckpt_metric,
                            val_ema_ckpt_metric=val_ema_ckpt_metric,
                        )

                if self.max_iterations_reached:
                    if self.use_distributed:
                        dist_barrier()
                    if self.is_master_node:
                        logger.info("Max. iterations for training reached")
                    break
        except KeyboardInterrupt as e:
            if self.is_master_node:
                logger.log("Keyboard interruption. Exiting from training early")
            raise e
        except Exception as e:
            if "out of memory" in str(e):
                logger.log("OOM exception occurred")
                n_gpus = getattr(self.opts, "dev.num_gpus", 1)
                for dev_id in range(n_gpus):
                    mem_summary = torch.cuda.memory_summary(
                        device=torch.device("cuda:{}".format(dev_id)), abbreviated=True
                    )
                    logger.log("Memory summary for device id: {}".format(dev_id))
                    print(mem_summary)

            logger.log(
                f"Exception occurred that interrupted the training:\n{traceback.format_exc()}"
            )
            raise e
        finally:
            use_distributed = getattr(self.opts, "ddp.use_distributed", False)
            if use_distributed:
                torch.distributed.destroy_process_group()

            torch.cuda.empty_cache()

            for log_writer in self.log_writers:
                log_writer.close()

            if self.is_master_node:
                train_end_time = time.time()
                hours, rem = divmod(train_end_time - train_start_time, 3600)
                minutes, seconds = divmod(rem, 60)
                train_time_str = "{:0>2}:{:0>2}:{:05.2f}".format(
                    int(hours), int(minutes), seconds
                )
                logger.log("Training took {}".format(train_time_str))

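    # Descriptive note on checkpoint bookkeeping in run(): `is_best` compares the current
    # validation checkpoint metric against the best value seen so far, with the comparison
    # direction controlled by `stats.checkpoint_metric_max` (higher-is-better metrics such as
    # top-1 accuracy vs. lower-is-better metrics such as loss). A minimal sketch of the same
    # logic, assuming nothing beyond the values themselves:
    #
    #   def update_best(current, best, maximize):
    #       is_best = current >= best if maximize else current <= best
    #       best = max(current, best) if maximize else min(current, best)
    #       return is_best, best
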
    def run_loss_landscape(self):
        # Loss landscape code is adapted from https://github.com/xxxnell/how-do-vits-work
        ll_start_time = time.time()
        try:
            n_points = getattr(self.opts, "loss_landscape.n_points", 32)
            min_x = getattr(self.opts, "loss_landscape.min_x", -1.0)
            max_x = getattr(self.opts, "loss_landscape.max_x", 1.0)
            min_y = getattr(self.opts, "loss_landscape.min_y", -1.0)
            max_y = getattr(self.opts, "loss_landscape.max_y", 1.0)

            if self.is_master_node:
                logger.log(
                    "Loss landscape coord space params: \n\tmin_x={}\n\tmax_x={}\n\tmin_y={}\n\tmax_y={}\n\tn_points={}".format(
                        min_x, max_x, min_y, max_y, n_points
                    )
                )

            ll_metrics = ["loss"]
            ll_stats = Statistics(
                opts=self.opts,
                metric_names=ll_metrics,
                is_master_node=self.is_master_node,
                is_distributed=self.use_distributed,
                log_writers=self.log_writers,
            )

            has_module = hasattr(self.model, "module")
            unwrapped_model = unwrap_model_fn(self.model)
            model_name = unwrapped_model.__class__.__name__

            # copy the model and create bases
            model = copy.deepcopy(self.model)
            weight_state_0 = unwrapped_model.state_dict()
            bases = ll_utils.create_bases(
                model=model, device=self.device, has_module=has_module
            )

            xs = np.linspace(min_x, max_x, n_points)
            ys = np.linspace(min_y, max_y, n_points)

            grid_a, grid_b = np.meshgrid(xs, ys, indexing="xy")
            loss_surface = np.empty_like(grid_a)

            epoch = -1
            for coord_a, coord_b in product(range(n_points), range(n_points)):
                epoch += 1
                coords_list = [grid_a[coord_a, coord_b], grid_b[coord_a, coord_b]]
                weight_state_1 = copy.deepcopy(weight_state_0)
                gs = [{k: r * bs[k] for k in bs} for r, bs in zip(coords_list, bases)]
                gs = {
                    k: torch.sum(torch.stack([g[k] for g in gs]), dim=0)
                    + weight_state_1[k]
                    for k in gs[0]
                }

                # load the weights
                unwrapped_model.load_state_dict(gs)
                model = model.to(device=self.device)
                model.eval()

                total_samples = len(self.val_loader)
                with torch.no_grad():
                    epoch_start_time = time.time()
                    processed_samples = 0

                    for batch_id, batch in enumerate(self.val_loader):
                        batch = move_to_device(
                            opts=self.opts, x=batch, device=self.device
                        )

                        samples, targets = batch["samples"], batch["targets"]
                        batch_size = get_batch_size(samples)
                        processed_samples += batch_size

                        # make the prediction and compute loss
                        pred_label = model(samples)
                        loss_dict_or_tensor: Union[Dict, Tensor] = self.criteria(
                            input_sample=samples,
                            prediction=pred_label,
                            target=targets,
                        )

                        if isinstance(loss_dict_or_tensor, Dict):
                            if "total_loss" not in loss_dict_or_tensor.keys():
                                logger.error(
                                    "total_loss key is required for loss functions that return outputs as dictionary."
                                )
                            loss = loss_dict_or_tensor["total_loss"]
                        elif isinstance(loss_dict_or_tensor, Tensor):
                            loss = loss_dict_or_tensor
                        else:
                            logger.error(
                                "Loss value should be an instance of Tensor or Dict"
                            )

                        if isinstance(loss, torch.Tensor) and torch.isnan(loss):
                            logger.error("Nan encountered in the loss.")

                        ll_stats.update(
                            pred_label=pred_label,
                            target_label=targets,
                            extras={"loss": loss_dict_or_tensor},
                            batch_time=0.0,
                            batch_size=batch_size,
                        )
                        if batch_id % self.log_freq == 0 and self.is_master_node:
                            ll_stats.iter_summary(
                                epoch=epoch,
                                n_processed_samples=processed_samples,
                                total_samples=total_samples,
                                elapsed_time=epoch_start_time,
                                learning_rate=0.0,
                            )

                avg_loss = ll_stats.avg_statistics(
                    metric_name="loss", sub_metric_name="total_loss"
                )
                loss_surface[coord_a, coord_b] = avg_loss
                if self.is_master_node:
                    print(
                        "x: {:.2f}, y: {:.2f}, loss: {:.2f}".format(
                            coords_list[0], coords_list[1], avg_loss
                        )
                    )

                if self.is_master_node:
                    lr_list = [0.0]
                    for log_writer in self.log_writers:
                        log_metrics(
                            lrs=lr_list,
                            log_writer=log_writer,
                            train_loss=0.0,
                            val_loss=avg_loss,
                            epoch=epoch,
                            best_metric=self.best_metric,
                            val_ema_loss=None,
                            ckpt_metric_name=None,
                            train_ckpt_metric=None,
                            val_ckpt_metric=None,
                            val_ema_ckpt_metric=None,
                        )

                gc.collect()

                # take a small nap
                time.sleep(if_test_env(0, otherwise=1))

            if self.is_master_node:
                ll_utils.plot_save_graphs(
                    save_dir=self.save_location,
                    model_name=model_name,
                    grid_a=grid_a,
                    grid_b=grid_b,
                    loss_surface=loss_surface,
                    resolution=n_points,
                )
        except KeyboardInterrupt as e:
            if self.is_master_node:
                logger.log("Keyboard interruption. Exiting from training early")
            raise e
        except Exception as e:
            if "out of memory" in str(e):
                logger.log("OOM exception occurred")
                n_gpus = getattr(self.opts, "dev.num_gpus", 1)
                for dev_id in range(n_gpus):
                    mem_summary = torch.cuda.memory_summary(
                        device=torch.device("cuda:{}".format(dev_id)), abbreviated=True
                    )
                    logger.log("Memory summary for device id: {}".format(dev_id))
                    print(mem_summary)
            else:
                logger.log(
                    "Exception occurred that interrupted the training. {}".format(
                        str(e)
                    )
                )
                print(e)
            raise e
        finally:
            if self.use_distributed:
                torch.distributed.destroy_process_group()

            torch.cuda.empty_cache()

            if self.is_master_node:
                ll_end_time = time.time()
                hours, rem = divmod(ll_end_time - ll_start_time, 3600)
                minutes, seconds = divmod(rem, 60)
                train_time_str = "{:0>2}:{:0>2}:{:05.2f}".format(
                    int(hours), int(minutes), seconds
                )
                logger.log("Loss landscape evaluation took {}".format(train_time_str))

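# Illustrative sketch (not part of the Trainer class): run_loss_landscape evaluates the model at
# W(x, y) = W0 + x * B1 + y * B2, where B1 and B2 are the direction "bases" returned by
# ll_utils.create_bases and (x, y) ranges over the grid defined by min_x/max_x/min_y/max_y.
# A minimal standalone version of that weight-perturbation step, assuming `bases` is a list of
# two dicts with the same keys as the state dict `w0` (`perturb_state_dict` is a hypothetical
# helper, not part of this module):


def perturb_state_dict(w0, bases, x, y):
    """Return a new state dict equal to w0 + x * bases[0] + y * bases[1]."""
    coords = (x, y)
    return {k: w0[k] + sum(c * b[k] for c, b in zip(coords, bases)) for k in w0}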