Postprocessor

class pfl.postprocessor.base.Postprocessor

A postprocessor defines an interface through which features can transform statistics after local training (on each user) and after central aggregation (on the server). Many categories of features operate this way, e.g. weighting, adaptive hyperparameter methods, sparsification and privacy mechanisms.

postprocess_one_user(*, stats, user_context)

Do any postprocessing of a client’s statistics before they are communicated back to the server.

Parameters:
  • stats (TypeVar(StatisticsType, bound=TrainingStatistics)) – Statistics returned from the local training procedure of this user.

  • user_context (UserContext) – Additional information about the current user.

Return type:

Tuple[TypeVar(StatisticsType, bound=TrainingStatistics), Metrics]

Returns:

A tuple (transformed_stats, metrics), where transformed_stats is stats after it has been processed by the postprocessor, and metrics contains any new metrics to track. The default implementation does nothing.
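To make the postprocess_one_user contract concrete, here is a minimal, self-contained sketch of a user-side postprocessor in the spirit of the privacy-mechanism category mentioned above: it clips the L2 norm of a user's update and reports the original norm as a new metric. It does not import pfl; `Stats`, `Metrics` and `ClipUserUpdate` are hypothetical stand-ins that only mirror the documented keyword-only signature and the (transformed_stats, metrics) return shape.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical stand-ins for pfl's types: statistics as named lists of floats,
# metrics as a plain name -> value dict.
Stats = Dict[str, List[float]]
Metrics = Dict[str, float]


class ClipUserUpdate:
    """Toy user-side postprocessor that scales a user's update so that its
    overall L2 norm is at most `clipping_bound`. Not pfl's implementation."""

    def __init__(self, clipping_bound: float):
        self.clipping_bound = clipping_bound

    def postprocess_one_user(self, *, stats: Stats,
                             user_context: object) -> Tuple[Stats, Metrics]:
        # L2 norm over every value of every named update.
        norm = math.sqrt(sum(v * v for values in stats.values() for v in values))
        # Only shrink updates that exceed the bound; leave small ones untouched.
        scale = min(1.0, self.clipping_bound / norm) if norm > 0 else 1.0
        clipped = {name: [v * scale for v in values]
                   for name, values in stats.items()}
        # Return the transformed statistics plus a new metric to track.
        return clipped, {"norm before clipping": norm}


clipper = ClipUserUpdate(clipping_bound=1.0)
new_stats, metrics = clipper.postprocess_one_user(stats={"w": [3.0, 4.0]},
                                                  user_context=None)
# The update [3.0, 4.0] has norm 5.0, so every value is scaled by 1.0 / 5.0.
```

Returning the pre-clipping norm as a metric shows how a postprocessor can both transform statistics and expose new information for tracking in a single call.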

postprocess_server(*, stats, central_context, aggregate_metrics)

Do any postprocessing of the aggregated statistics object after central aggregation.

Parameters:
  • stats (TypeVar(StatisticsType, bound=TrainingStatistics)) – The aggregated statistics.

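  • central_context (CentralContext) – Information about aggregation and other useful server-side properties.

  • aggregate_metrics (Metrics) – Metrics aggregated across users, including any returned by postprocess_one_user.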

Return type:

Tuple[TypeVar(StatisticsType, bound=TrainingStatistics), Metrics]

Returns:

A tuple (transformed_stats, metrics), where transformed_stats is stats after it has been processed by the postprocessor, and metrics contains any new metrics to track. The default implementation does nothing.
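A server-side counterpart, again as a self-contained sketch with hypothetical types (`AggregatedStats`, `NormalizeByWeight`) rather than pfl's classes: a weighting-style postprocessor that turns the aggregated weighted sum into a weighted mean and reports the total weight as a metric. central_context and aggregate_metrics are accepted only to match the documented signature and are unused here.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

Metrics = Dict[str, float]


@dataclass
class AggregatedStats:
    # Hypothetical stand-in: element-wise sum of weighted user updates
    # and the total weight that went into that sum.
    total: Dict[str, List[float]]
    weight: float


class NormalizeByWeight:
    """Toy server-side postprocessor: weighted sum -> weighted mean."""

    def postprocess_server(self, *, stats: AggregatedStats,
                           central_context: object,
                           aggregate_metrics: Metrics
                           ) -> Tuple[AggregatedStats, Metrics]:
        if stats.weight <= 0:
            # Nothing to normalize; pass the statistics through unchanged.
            return stats, {}
        mean = {name: [v / stats.weight for v in values]
                for name, values in stats.total.items()}
        # The normalized statistics now carry unit weight.
        return AggregatedStats(total=mean, weight=1.0), {"total weight": stats.weight}


aggregated = AggregatedStats(total={"w": [2.0, 6.0]}, weight=4.0)
normalized, server_metrics = NormalizeByWeight().postprocess_server(
    stats=aggregated, central_context=None, aggregate_metrics={})
# normalized.total == {"w": [0.5, 1.5]}, server_metrics == {"total weight": 4.0}
```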

postprocess_server_live(*, stats, central_context, aggregate_metrics)

Same as postprocess_server, but for live training. The default implementation calls postprocess_server. Override this only when live training needs different behaviour, e.g. central DP.

Return type:

Tuple[TypeVar(StatisticsType, bound=TrainingStatistics), Metrics]
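The default delegation from postprocess_server_live to postprocess_server, and an override of only the live path, can be sketched as follows. `BasePostprocessor` and `LiveOnlyNoise` are hypothetical stand-ins, not pfl classes; the Gaussian noise in the override only illustrates the override pattern and is neither a calibrated DP mechanism nor how pfl implements central DP.

```python
import random
from typing import Dict, List, Tuple

Stats = Dict[str, List[float]]
Metrics = Dict[str, float]


class BasePostprocessor:
    """Stand-in mirroring the documented defaults (not pfl's code)."""

    def postprocess_server(self, *, stats: Stats, central_context: object,
                           aggregate_metrics: Metrics) -> Tuple[Stats, Metrics]:
        # Default: do nothing.
        return stats, {}

    def postprocess_server_live(self, *, stats: Stats, central_context: object,
                                aggregate_metrics: Metrics) -> Tuple[Stats, Metrics]:
        # Default: live training reuses the simulation code path.
        return self.postprocess_server(stats=stats,
                                       central_context=central_context,
                                       aggregate_metrics=aggregate_metrics)


class LiveOnlyNoise(BasePostprocessor):
    """Hypothetical subclass that changes behaviour only for live training."""

    def __init__(self, stddev: float, seed: int = 0):
        self.stddev = stddev
        self.rng = random.Random(seed)

    def postprocess_server_live(self, *, stats: Stats, central_context: object,
                                aggregate_metrics: Metrics) -> Tuple[Stats, Metrics]:
        # Placeholder perturbation of the aggregate; illustrative only.
        noisy = {name: [v + self.rng.gauss(0.0, self.stddev) for v in values]
                 for name, values in stats.items()}
        return noisy, {"noise stddev": self.stddev}
```

Because postprocess_server is inherited unchanged, simulations keep the default pass-through while only the live path is affected.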

class pfl.postprocessor.metrics.SummaryMetrics(metric_name, quantiles, min_bound, max_bound, num_bins=100, frequency=1, stddev=False)

Given the name of an existing metric, accumulate a histogram of the per-user values of that metric, then compute aggregate statistics from the histogram.

This is useful for metrics that cannot be meaningfully summarised by a (weighted) average, and for providing more detail about a metric's distribution. For example, this class can compute the accuracy below which 1% of users fall (i.e. the 1st percentile).

Parameters:
  • metric_name (MetricName) – The name of the metric to accumulate a histogram for. The metric must already have been generated, e.g. by model evaluation.

  • quantiles (List[float]) – A list of quantiles to calculate from the histogram. For example, quantile 0.1 returns the value below which 10% of users fall according to the aggregated histogram, estimated using linear interpolation. Because of this interpolation, computing quantiles accurately requires sufficiently many histogram bins (typically well over 100).

  • min_bound (float) – Minimum bound of the histogram. Values less than min_bound are ignored.

  • max_bound (float) – Maximum bound of the histogram. Values larger than max_bound are ignored.

  • num_bins (int) –

    Number of bins for the histogram. The number of bins determines the fidelity of the quantiles and other summary statistics computed on the source metric. There is little reason to choose fewer than 100 bins, which lowers fidelity, unless the histogram is only used for visualization (the histograms generated by this postprocessor can be visualized with e.g. TensorBoardCallback).

    Bins are currently equally spaced; support for non-equally-spaced bins can be added upon request.


  • frequency (int) – Compute the summary every frequency central iterations. This should match the frequency at which the summarised metric is generated.

  • stddev (bool) – Whether to also calculate the standard deviation of the histogram.
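The sketch below shows one way quantiles and the standard deviation can be estimated from equally spaced histogram bins, and why the bin count matters: inside a bin, values are assumed to be spread uniformly, so the interpolation error shrinks as bins get narrower. The function names are hypothetical and this is not necessarily the exact computation SummaryMetrics performs; the standard deviation here is approximated by placing each user at the midpoint of its bin.

```python
import math
from typing import List


def histogram_quantile(counts: List[float], min_bound: float,
                       max_bound: float, q: float) -> float:
    """Value below which a fraction q of users fall, linearly interpolated
    inside the bin that contains the q-th user."""
    total = sum(counts)
    if total == 0:
        raise ValueError("empty histogram")
    width = (max_bound - min_bound) / len(counts)
    target = q * total
    cumulative = 0.0
    for i, count in enumerate(counts):
        if count > 0 and cumulative + count >= target:
            lower = min_bound + i * width
            # Assume users are spread uniformly within this bin.
            return lower + (target - cumulative) / count * width
        cumulative += count
    return max_bound


def histogram_stddev(counts: List[float], min_bound: float,
                     max_bound: float) -> float:
    """Standard deviation, treating every user as sitting at its bin midpoint."""
    total = sum(counts)
    width = (max_bound - min_bound) / len(counts)
    mids = [min_bound + (i + 0.5) * width for i in range(len(counts))]
    mean = sum(c * m for c, m in zip(counts, mids)) / total
    var = sum(c * (m - mean) ** 2 for c, m in zip(counts, mids)) / total
    return math.sqrt(var)
```

For example, with counts [1, 3, 4, 2] over [0, 1] (bin width 0.25), the median target is 5 users; 4 users lie in the first two bins, so the median falls one user into the third bin: 0.5 + (1 / 4) * 0.25 = 0.5625.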

Example:

This example shows summary metrics on accuracy. Since accuracy is bounded in [0, 1], min_bound and max_bound are set to those values. With num_bins=1000, each bin spans 0.1% (a width of 0.001) of the total range specified by the bounds.

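```python
from pfl.common_types import Population
from pfl.metrics import TrainMetricName
from pfl.postprocessor.metrics import SummaryMetrics

SummaryMetrics(
    metric_name=TrainMetricName('accuracy',
                                Population.TRAIN,
                                after_training=False),
    quantiles=[0.01, 0.02, 0.98, 0.99],
    min_bound=0.0,
    max_bound=1.0,
    num_bins=1000)  # bin width: (1.0 - 0.0) / 1000 = 0.001
```
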
Example:

This example shows summary metrics on loss. Cross-entropy loss is bounded in [0, inf), so max_bound must be picked to suit the current experiment; values above it are ignored. The loss at iteration 0 on the Reddit dataset was 9.3, so a value slightly above 9.3 (here 10.0) is a good pick.

```python
SummaryMetrics(
    metric_name=TrainMetricName('loss',
                                Population.TRAIN,
                                after_training=False),
    quantiles=[0.01, 0.02, 0.98, 0.99],
    min_bound=0.0,
    max_bound=10.0,
    num_bins=1000)
```

postprocess_one_user(*, stats, user_context)

Creates a histogram metric from a single user’s metrics.

If the metric to build the histogram from is not among the user metrics passed via user_context, the method takes no action.

Parameters:
  • stats (TrainingStatistics) – Statistics returned from the local training procedure of this user. They are passed through to the output unchanged.

  • user_context (UserContext) – Additional information about the current user. This includes the values of metrics to use in summary statistics.

Return type:

Tuple[TrainingStatistics, Metrics]

Returns:

A tuple (stats, metrics), where stats is the same as the input, and metrics contains the histogram metric for this user.
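A sketch of this per-user step under stated assumptions: the user's value of the source metric is looked up by name, skipped if missing or outside [min_bound, max_bound], and otherwise turned into a histogram with a single count in the bin that contains it. `one_user_histogram` and the plain-dict representation of user metrics are hypothetical; the documented return type uses pfl's Metrics instead.

```python
from typing import Dict, List, Optional


def one_user_histogram(user_metrics: Dict[str, float], metric_name: str,
                       min_bound: float, max_bound: float,
                       num_bins: int) -> Optional[List[int]]:
    """Hypothetical helper: a one-count histogram for this user's value,
    or None if the user is skipped."""
    if metric_name not in user_metrics:
        # The metric was not generated for this user: take no action.
        return None
    value = user_metrics[metric_name]
    if value < min_bound or value > max_bound:
        # Values outside the bounds are ignored.
        return None
    width = (max_bound - min_bound) / num_bins
    # Clamp so that a value exactly equal to max_bound lands in the last bin.
    index = min(int((value - min_bound) / width), num_bins - 1)
    counts = [0] * num_bins
    counts[index] = 1
    return counts


user_hist = one_user_histogram({"accuracy": 0.55}, "accuracy", 0.0, 1.0, 10)
# user_hist == [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
```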

postprocess_server(*, stats, central_context, aggregate_metrics)

Derive aggregate statistics from the aggregated histogram.

Runs only every frequency central iterations, and only if the metric to be summarised is present in aggregate_metrics.

Parameters:
  • stats (TrainingStatistics) – The aggregated statistics. Not used.

  • central_context (CentralContext) – Information about aggregation and other useful server-side properties.

  • aggregate_metrics (Metrics) – Metrics aggregated across users. Should contain the aggregated histogram metric.

Return type:

Tuple[TrainingStatistics, Metrics]

Returns:

A tuple (stats, metrics), where stats is the same as the input, and metrics contains aggregate statistics about the histogram of interest.
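On the server side, per-user histograms only need to be summed bin by bin to obtain the aggregated histogram, and the summary is computed only on qualifying iterations. The sketch below assumes the frequency check is `central_iteration % frequency == 0`; `aggregate_histograms`, `should_summarize` and the metric name "accuracy histogram" are hypothetical.

```python
from typing import Dict, List, Optional


def aggregate_histograms(
        user_histograms: List[Optional[List[int]]]) -> Optional[List[int]]:
    """Bin-wise sum of per-user histograms; skipped users (None) are ignored."""
    present = [h for h in user_histograms if h is not None]
    if not present:
        return None
    return [sum(column) for column in zip(*present)]


def should_summarize(central_iteration: int, frequency: int,
                     aggregate_metrics: Dict[str, List[int]],
                     histogram_name: str) -> bool:
    """Summarise only every `frequency` iterations, and only if the
    histogram was actually aggregated in this iteration."""
    return (central_iteration % frequency == 0
            and histogram_name in aggregate_metrics)


aggregated = aggregate_histograms([[1, 0, 0], None, [0, 0, 1], [1, 0, 0]])
# aggregated == [2, 0, 1]
run_now = should_summarize(10, 5, {"accuracy histogram": aggregated},
                           "accuracy histogram")
# run_now is True
```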