Callbacks#

A callback provides hooks into the training process. Different methods provide hooks into different stages of the central training loop.

class pfl.callback.TrainingProcessCallback#

Base class for callbacks.

on_train_begin(*, model)#

Called before the first central iteration.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (TypeVar(ModelType, bound= Model)) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!

on_train_end(*, model)#

Called at the end of training.

Return type:

None
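
A custom callback only needs to override the hooks it cares about. Below is a minimal sketch of a callback that reports the current iteration number as a new metric; it assumes Metrics can be built from a list of (name, value) pairs:

from pfl.callback import TrainingProcessCallback
from pfl.metrics import Metrics, StringMetricName


class IterationLoggerCallback(TrainingProcessCallback):
    """Toy callback that reports the iteration number as a metric."""

    def after_central_iteration(self, aggregate_metrics, model, *,
                                central_iteration):
        # First element False: never interrupt training.
        # Second element: only metrics that are new.
        return False, Metrics([(StringMetricName('iteration number'),
                                central_iteration)])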

class pfl.callback.RestoreTrainingCallback(saveables, checkpoint_dir, checkpoint_frequency=1)#

Add fault tolerance to your training. If the training run fails and you restart it, this callback will restore the latest checkpoints of the saveables before training starts again. Be careful if you've implemented any stateful component: it will only be restored if you've properly implemented the Saveable interface on the component and passed it to this callback. When restoring a checkpoint, it is assumed that all saveables were successfully stored in the last attempt.

Parameters:
  • saveables (List[Saveable]) – The objects that need to save their states so that they can be loaded if training is interrupted and then resumed.

  • checkpoint_dir (Union[str, List[str]]) – Root dir for where to store the saveables’ states. Let this be a list of directory paths to specify a unique checkpoint directory for each saveable. Location will be relative to root dir on current platform.

  • checkpoint_frequency (int) – Save checkpoints of saveables every this many iterations.

on_train_begin(*, model)#

Restore from the previous run's checkpoints, if they exist.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (TypeVar(ModelType, bound= Model)) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!
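
A sketch of wiring up fault tolerance, assuming model implements the Saveable interface:

from pfl.callback import RestoreTrainingCallback

# `model` is assumed to be defined elsewhere and to implement Saveable.
restore_cb = RestoreTrainingCallback(
    saveables=[model],
    checkpoint_dir='checkpoints/run0',
    checkpoint_frequency=10)  # snapshot the state every 10 iterations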

class pfl.callback.CentralEvaluationCallback(dataset, model_eval_params=None, frequency=1, distribute_evaluation=True, format_fn=None)#

Callback for performing evaluation on a centrally held dataset in between central iterations. The first evaluation is done before training begins.

Parameters:
  • dataset (TypeVar(AbstractDatasetType, bound= AbstractDataset)) – A Dataset that represents a central dataset. It has nothing to do with any particular user; the Dataset class is used here solely to plug in to pfl properly.

  • model_eval_params (Optional[ModelHyperParams]) – The model parameters to use when evaluating the model. Can be None if the model doesn’t require hyperparameters for evaluation.

  • frequency (int) – Perform central evaluation every frequency central iterations.

  • distribute_evaluation (bool) – Evaluate by distributing the computation across all workers used. If set to False, each worker runs the full evaluation independently, which takes longer than distributed evaluation. However, it may be necessary to disable distributed evaluation for models and features that do not support this mode.

  • format_fn (Optional[Callable[[str], MetricName]]) –

    A callable (metric_name) -> MetricName that formats the metric string name metric_name into a pfl metric name representation. The default value is

    lambda n: StringMetricName(f'Central val | {n}')
    

    It can be necessary to override the default when you are using multiple instances of this class; otherwise the metric names might conflict with each other.

on_train_begin(*, model)#

Called before the first central iteration.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (TypeVar(EvaluatableModelType, bound= EvaluatableModel)) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!
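
A sketch with two instances evaluating separate validation and test sets, overriding format_fn on the second so the metric names do not collide (val_data, test_data and eval_params are assumed to be defined elsewhere):

from pfl.callback import CentralEvaluationCallback
from pfl.metrics import StringMetricName

central_val_cb = CentralEvaluationCallback(
    val_data, model_eval_params=eval_params, frequency=5)
central_test_cb = CentralEvaluationCallback(
    test_data, model_eval_params=eval_params, frequency=5,
    format_fn=lambda n: StringMetricName(f'Central test | {n}'))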

class pfl.callback.CentralEvaluationWithEMACallback(dataset, ema, model_eval_params=None, frequency=1, distribute_evaluation=True, format_fn=None)#

Callback for performing evaluation with the exponential moving average of trained model on a centrally held dataset in between central iterations. The callback will update the EMA parameters after each central iteration, and will assign the EMA parameters to the model for evaluation.

Parameters:
  • dataset (TypeVar(AbstractDatasetType, bound= AbstractDataset)) – A Dataset that represents a central dataset. It has nothing to do with any particular user; the Dataset class is used here solely to plug in to pfl properly.

  • ema (CentralExponentialMovingAverage) – A CentralExponentialMovingAverage that holds the EMA variables for the model to be evaluated. See CentralExponentialMovingAverage for more details.

  • model_eval_params (Optional[ModelHyperParams]) – The model parameters to use when evaluating the model.

  • frequency (int) – Perform central evaluation every frequency central iterations.

  • distribute_evaluation (bool) – Evaluate by distributing the computation across all workers used. If set to False, each worker runs the full evaluation independently, which takes longer than distributed evaluation. However, it may be necessary to disable distributed evaluation for models and features that do not support this mode.

  • format_fn (Optional[Callable[[str], MetricName]]) –

    A callable (metric_name) -> MetricName that formats the metric string name metric_name into a pfl metric name representation. The default value is

    lambda n: StringMetricName(f'Central val EMA | {n}')
    

    It can be necessary to override the default when you are using multiple instances of this class; otherwise the metric names might conflict with each other.

on_train_begin(*, model)#

Called before the first central iteration.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (StatefulModel) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!
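
A sketch pairing the callback with an EMA object. The import path and constructor arguments of CentralExponentialMovingAverage are assumptions here; check them against your version of pfl:

from pfl.callback import CentralEvaluationWithEMACallback
from pfl.model.ema import CentralExponentialMovingAverage  # path assumed

# `model`, `val_data` and `eval_params` are assumed to be defined elsewhere.
ema = CentralExponentialMovingAverage(model, decay=0.999)
ema_eval_cb = CentralEvaluationWithEMACallback(
    val_data, ema, model_eval_params=eval_params, frequency=5)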

class pfl.callback.ConvergenceCallback(metric_name, patience, performance_threshold, performance_is_better)#

Track convergence using a performance measure and stop training when converged.

Convergence is defined as the performance becoming better than a threshold and staying that way for patience iterations afterwards. If the run is terminated, a new metric is added that stores the number of data points processed until convergence was achieved (i.e. when the metric reached the threshold for the first time).

Parameters:
  • metric_name (Union[str, StringMetricName]) – The name of the metric to track for convergence.

  • patience (int) – The run will be terminated when the metric metric_name has been better than performance_threshold for at least patience iterations.

  • performance_threshold (float) – The performance required to start considering whether training has converged.

  • performance_is_better (Callable[[Any, Any], bool]) – A binary function that returns true if the first argument, indicating a performance level, is “better” than the second argument. For accuracy metrics, this is normally operator.gt, since higher is better. For loss or error metrics, lower is better, and this should be set to operator.lt.

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (TypeVar(ModelType, bound= Model)) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!
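
A sketch for an accuracy metric, where higher is better (the metric name is illustrative):

import operator

from pfl.callback import ConvergenceCallback

# Stop once accuracy has stayed above 0.9 for 5 consecutive iterations.
convergence_cb = ConvergenceCallback(
    metric_name='Central val | accuracy',
    patience=5,
    performance_threshold=0.9,
    performance_is_better=operator.gt)  # higher accuracy is better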

class pfl.callback.EarlyStoppingCallback(metric_name, patience, performance_is_better=operator.lt)#

Implements early stopping as a callback to use in the training process. The criterion for this callback to stop training is that the metric, given by metric_name, has not reached a new best value for patience consecutive central iterations. An improvement is defined by performance_is_better.

Parameters:
  • metric_name (Union[str, StringMetricName]) – The name of the metric to track for early stopping, usually in the form of a pfl.metrics.MetricName.

  • patience (int) – Number of central iterations to wait for an improvement in the tracked metric before interrupting the training process.

  • performance_is_better (Callable[[Any, Any], bool]) – A binary function that returns true if the first argument, indicating a performance level, is “better” than the second argument. For accuracy metrics, this is normally operator.gt, since higher is better. For loss or error metrics, lower is better, and this should be set to operator.lt. It is set to operator.lt by default because you would normally perform early stopping on a loss or error metric.

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (TypeVar(ModelType, bound= Model)) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!
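
A sketch tracking a loss metric with the default operator.lt (the metric name is illustrative):

from pfl.callback import EarlyStoppingCallback

# Interrupt if the loss has not reached a new best value for 10 iterations.
early_stopping_cb = EarlyStoppingCallback('Central val | loss', patience=10)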

class pfl.callback.StopwatchCallback(decimal_points=2, measure_round_in_minutes=False)#

Records wall-clock time: total time spent training, time per central iteration, and the overall average time per central iteration.

Parameters:
  • decimal_points (int) – Number of decimal points to round the wall-clock time metrics.

  • measure_round_in_minutes (bool) – If True, measure time for central iteration in minutes, not seconds. If you want this, it means your training is very slow!

on_train_begin(*, model)#

Starts the stopwatch.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (TypeVar(ModelType, bound= Model)) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!
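
For example, to report times in minutes with one decimal point:

from pfl.callback import StopwatchCallback

stopwatch_cb = StopwatchCallback(decimal_points=1,
                                 measure_round_in_minutes=True)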

class pfl.callback.TensorBoardCallback(log_dir, write_weights=False, write_graph=True, tensorboard_port=None)#

Log events for TensorBoard: metrics, graph visualization and weight histograms. Launch TensorBoard with the command:

tensorboard --logdir=<path to log_dir>

Note

Only supported with TF (pfl.model.tensorflow.TFModel) right now.

Parameters:
  • log_dir (str) – Directory path in which to store the TensorBoard log files. This path should be unique for every run if you run multiple trainings on the same machine.

  • write_weights (Union[bool, int]) –

    Save weight histograms and distributions for the layers of the model. There are 3 different modes:

    • False - disable this feature.

    • True - save histograms every time the algorithm performs an evaluation iteration (evaluation_frequency in ModelHyperParams).

    • An integer - save weight histograms every this many central iterations.

  • write_graph (bool) – Visualize the model graph in TensorBoard. Disable this to keep the size of the TensorBoard data small.

  • tensorboard_port (Optional[int]) – Port to use when hosting TensorBoard.

on_train_begin(*, model)#

Called before the first central iteration.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (TypeVar(ModelType, bound= Model)) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!

on_train_end(*, model)#

Called at the end of training.

Return type:

None
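
A sketch of a typical configuration (the path and values are illustrative):

from pfl.callback import TensorBoardCallback

tensorboard_cb = TensorBoardCallback(
    log_dir='tb_logs/run0',  # should be unique for every run
    write_weights=25,        # weight histograms every 25 central iterations
    write_graph=False,       # keep the TensorBoard data small
    tensorboard_port=6006)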

class pfl.callback.ModelCheckpointingCallback(model_checkpoint_dir, *, checkpoint_frequency=0)#

Callback to save model checkpoints. Note that model checkpoints can also be saved as part of RestoreTrainingCallback, as long as the model is Saveable and is included in the list of saveables when that callback is initialized.

Parameters:
  • model_checkpoint_dir (str) – A path to disk for saving the trained model. Location will be relative to root dir on current platform.

  • checkpoint_frequency (int) – The number of central iterations after which to save a model. When zero (the default), the model is saved once after training is complete.

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (StatefulModel) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!

on_train_end(*, model)#

Called at the end of training.

Return type:

None
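
A sketch (the directory name is illustrative):

from pfl.callback import ModelCheckpointingCallback

# Save a checkpoint every 100 central iterations. With the default
# checkpoint_frequency=0, the model is instead saved once after training.
checkpointing_cb = ModelCheckpointingCallback(
    'model_checkpoints', checkpoint_frequency=100)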

class pfl.callback.ProfilerCallback(frequency=None, warmup_iterations=0, dir_name='profile')#

Profiles the code using Python’s profiler, cProfile.

A profile is a set of statistics that describes how often and for how long various parts of a program are executed.

This callback can be used to independently profile iterations of an algorithm, or to profile all iterations of an algorithm together.

The profile statistics will be saved as an artifact during training. These statistics can be read and analysed using pstats:

import pstats

# Placeholder path to a saved profile artifact.
stats = pstats.Stats('profile-stats-filename')
stats.sort_stats('cumulative')  # e.g. sort by cumulative time
stats.print_stats(10)           # e.g. print the 10 most expensive entries

Alternatively, SnakeViz can be used to produce a graphical view of the profile in the browser.

Parameters:
  • frequency (Optional[int]) – Controls frequency and duration of profiling. If frequency is an integer > 0, profiling is performed per-iteration every frequency central training iterations. If frequency is None, a single profile is produced covering all central training iterations.

  • warmup_iterations (int) – Commence profiling after this number of central training iterations. If warmup_iterations > total number of central iterations, no profiling will take place.

  • dir_name (str) – Name of directory in which profiles will be saved. Location will be relative to root dir on current platform.

on_train_begin(*, model)#

Called before the first central iteration.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Commence profiling of one iteration at the end of the previous central iteration.

central_iteration is zero-indexed.

warmup_iterations is the number of central iterations that must have completed before profiling begins. This means profiling begins at the end of the iteration where central_iteration == warmup_iterations - 1.

Return type:

Tuple[bool, Metrics]

on_train_end(*, model)#

Called at the end of training.

Return type:

None
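
A sketch producing a single profile that excludes the first few iterations (the values are illustrative):

from pfl.callback import ProfilerCallback

# One profile covering all central iterations after 5 warmup iterations.
profiler_cb = ProfilerCallback(frequency=None, warmup_iterations=5)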

class pfl.callback.AggregateMetricsToDisk(output_path, frequency=1, decimal_points=6, check_existing_file=False)#

Callback to write aggregated metrics to disk at a given frequency of central iterations.

Parameters:
  • output_path (str) – Path to where the csv file of aggregated metrics should be written relative to the root dir on current platform.

  • frequency (int) – Write aggregated metrics to file every frequency central iterations. This can be useful for skipping iterations where no evaluation is done, if evaluation also runs at a set frequency.

  • decimal_points (int) – Number of decimal points to round the metric values to when writing to disk.

  • check_existing_file (bool) – If True, throw an error if output_path already exists, instead of overwriting it.

on_train_begin(*, model)#

Called before the first central iteration.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (TypeVar(ModelType, bound= Model)) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!

on_train_end(*, model)#

Called at the end of training.

Return type:

None
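
A sketch (the file name and values are illustrative):

from pfl.callback import AggregateMetricsToDisk

metrics_to_disk_cb = AggregateMetricsToDisk(
    'metrics.csv', frequency=5, check_existing_file=True)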

class pfl.callback.TrackBestOverallMetrics(lower_is_better_metric_names=None, higher_is_better_metric_names=None, assert_metrics_found_within_frequency=25)#

Track the best value of given metrics over all iterations. If the specified metric names are not found for a particular central iteration, nothing will happen. Use parameter assert_metrics_found_within_frequency to assert that they must eventually be found, e.g. if you are doing central evaluation only every nth iteration.

Parameters:
  • lower_is_better_metric_names (Optional[List[Union[str, StringMetricName]]]) – A list of metric names to track. Whenever a metric with a name in this list is encountered, the lowest value of that metric seen through the history of all central iterations is returned.

  • higher_is_better_metric_names (Optional[List[Union[str, StringMetricName]]]) – Same as lower_is_better_metric_names, but for metrics where a higher value is better.

  • assert_metrics_found_within_frequency (int) – As a precaution, assert that all metrics referenced in lower_is_better_metric_names and higher_is_better_metric_names are found within this many iterations. If you have e.g. misspelled a metric name, or placed this callback before the callback that generates the metric, you will be notified.

on_train_begin(*, model)#

Called before the first central iteration.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Finalize any computations after each central iteration.

Parameters:
  • aggregate_metrics (Metrics) – A Metrics object with aggregated metrics accumulated from local training on users and central updates of the model.

  • model (TypeVar(ModelType, bound= Model)) – A reference to the Model that is trained.

  • central_iteration (int) – The current central iteration number.

Return type:

Tuple[bool, Metrics]

Returns:

A tuple. The first value is a boolean that, if True, signals that training should be interrupted. This can be useful for implementing features such as early stopping or convergence criteria. The second value is any new metrics to report. Do not include any of the metrics from aggregate_metrics!
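
A sketch using metric names like those produced by the central evaluation callbacks above (the names are illustrative):

from pfl.callback import TrackBestOverallMetrics

track_best_cb = TrackBestOverallMetrics(
    lower_is_better_metric_names=['Central val | loss'],
    higher_is_better_metric_names=['Central val | accuracy'])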

class pfl.callback.WandbCallback(wandb_project_id, wandb_experiment_name=None, wandb_config=None, **wandb_kwargs)#

Callback for reporting metrics to the Weights&Biases dashboard for comparing different PFL runs. This callback has basic support for logging metrics. If you need more advanced features from the Wandb API, you should make your own callback.

See https://wandb.ai/ and https://docs.wandb.ai/ for more information on Weights&Biases.

Parameters:
  • wandb_project_id (str) – The name of the project where you’re sending the new run. If the project is not specified, the run is put in an “Uncategorized” project.

  • wandb_experiment_name (Optional[str]) – A short display name for this run. Generates a random two-word name by default.

  • wandb_config – Optional dictionary (or argparse.Namespace) of parameters (e.g. hyperparameter choices) used to tag this run in the Wandb dashboard.

  • wandb_kwargs – Additional keyword arguments, other than project, name and config, that are passed through to wandb.init; see https://docs.wandb.ai/ref/python/init for reference.

on_train_begin(*, model)#

Called before the first central iteration.

Return type:

Metrics

after_central_iteration(aggregate_metrics, model, *, central_iteration)#

Submits metrics of this central iteration to Wandb experiment.

Return type:

Tuple[bool, Metrics]

on_train_end(*, model)#

Called at the end of training.

Return type:

None
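
A sketch (the project id, run name and config values are illustrative):

from pfl.callback import WandbCallback

wandb_cb = WandbCallback(
    wandb_project_id='my-pfl-project',
    wandb_experiment_name='fedavg-baseline',
    wandb_config={'cohort_size': 100, 'learning_rate': 0.1})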