Hyperparameters#
- typing.ParameterType#
TypeVar for the value type wrapped by a HyperParam.
- class pfl.hyperparam.base.HyperParam#
Base class for defining adaptive hyperparameters. An adaptive hyperparameter can be used as a substitute for a static parameter where permitted (mostly in the configs and algorithms, see respective type signatures).
Make the subclass also a postprocessor or callback to gain access to hooks where the hyperparameter can be adapted.
- Example:
This is an example of an adaptive hyperparameter (cohort size), which increases by a factor of 2 each central iteration.
    class MyCohortSize(HyperParam, TrainingProcessCallback):

        def __init__(self, initial):
            self._value = initial

        def after_central_iteration(self, aggregate_metrics, model,
                                    central_iteration):
            self._value *= 2

        def value(self):
            return self._value
- abstract value()#
The current state (inner value) of the hyperparameter.
- Return type: TypeVar(ParameterType)
- pfl.hyperparam.base.get_param_value(parameter)#
If the input is a HyperParam, extract its current value; otherwise, return the input unchanged.
- Example:
    >>> get_param_value(1.0)
    1.0
    >>> get_param_value(MyCustomParam(initial_value=2.0))
    2.0
- class pfl.hyperparam.base.HyperParams#
Base class for dataclasses that store parameters for model/training.
- static_clone(**kwargs)#
Returns a static clone of the hyperparameters where each parameter has its current value (including adaptive parameters). This is used to access parameters in the algorithms (e.g. FederatedNNAlgorithm).
- Return type: TypeVar(HyperParamsType, bound=HyperParams)
- get(key)#
Get the current static value of a hyperparameter (which is a property of the dataclass). If the value is a HyperParam, its inner value state is returned.
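A rough sketch of how static_clone and get fit together, assuming (as the base class docstring suggests) that subclasses are plain dataclasses and that get takes the attribute name as a string; ConstantParam and MyTrainParams below are made up for illustration:

    from dataclasses import dataclass

    from pfl.hyperparam.base import HyperParam, HyperParams


    class ConstantParam(HyperParam):
        # Minimal HyperParam that always returns a fixed value.
        def __init__(self, value):
            self._value = value

        def value(self):
            return self._value


    @dataclass(frozen=True)
    class MyTrainParams(HyperParams):
        # Hypothetical dataclass mixing a static and an adaptive field.
        learning_rate: float
        cohort_size: HyperParam


    params = MyTrainParams(learning_rate=0.1, cohort_size=ConstantParam(100))

    # `get` resolves a field to its current static value, unwrapping any HyperParam.
    assert params.get('cohort_size') == 100

    # `static_clone` returns a copy of the dataclass in which every field holds a
    # plain value, which is what algorithms such as FederatedNNAlgorithm read.
    static_params = params.static_clone()
    assert static_params.cohort_size == 100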
- class pfl.hyperparam.base.AlgorithmHyperParams#
Base class for additional parameters to pass to algorithms. By default, this base class has no parameters, but subclasses purposed for certain federated algorithms will have additional parameters.
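A hypothetical subclass, just to illustrate the intent; the field names below are invented and do not correspond to any actual pfl algorithm:

    from dataclasses import dataclass

    from pfl.hyperparam.base import AlgorithmHyperParams


    @dataclass(frozen=True)
    class MyAlgorithmParams(AlgorithmHyperParams):
        # Invented algorithm-level knobs for a custom federated algorithm.
        central_num_iterations: int = 100
        cohort_size: int = 200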
- class pfl.hyperparam.base.ModelHyperParams#
A base configuration for training models. By default, this base class has no parameters, but subclasses purposed for certain models will have additional parameters.
- class pfl.hyperparam.base.NNEvalHyperParams(local_batch_size)#
Config to use for evaluating any neural network with an algorithm that involves SGD.
- Parameters:
- local_batch_size (Union[HyperParam[int], int, None]) – The batch size for evaluating locally on device. If None, defaults to no batching (full-batch evaluation).
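A brief usage sketch (the values are arbitrary):

    from pfl.hyperparam.base import NNEvalHyperParams

    # Evaluate locally in mini-batches of 512 data points.
    eval_params = NNEvalHyperParams(local_batch_size=512)

    # Full-batch evaluation (no batching).
    full_batch_eval = NNEvalHyperParams(local_batch_size=None)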
- class pfl.hyperparam.base.NNTrainHyperParams(local_batch_size, local_num_epochs, local_learning_rate, local_max_grad_norm=None, local_num_steps=None, grad_accumulation_steps=1)#
Config to use for training any neural network with an algorithm that involves SGD.
- Parameters:
- local_num_epochs (Union[HyperParam[int], int, None]) – The number of epochs of training applied on the device. If this is set, local_num_steps must be None.
- local_learning_rate (Union[HyperParam[float], float]) – The learning rate applied on the device.
- local_batch_size (Union[HyperParam[int], int, None]) – The batch size for training locally on device. If None, defaults to the entire dataset, which means one local iteration is one epoch.
- local_max_grad_norm (Union[HyperParam[float], float, None]) – Maximum L2 norm for the gradient update in each local optimization step. Local gradients on device are clipped if their L2 norm is larger than local_max_grad_norm. If None, no clipping is performed.
- local_num_steps (Union[HyperParam[int], int, None]) – Number of gradient steps during local training. If this is set, local_num_epochs must be None. Training stops before local_num_steps if the dataset has been fully iterated through. This can be useful if the user dataset is very large and training for less than an epoch is appropriate.
- grad_accumulation_steps (int) – Number of steps to accumulate gradients before applying a local optimizer update. The effective batch size is local_batch_size multiplied by this number. This is useful to simulate a larger local batch size when memory is limited. Currently only supported for PyTorch.
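A usage sketch with arbitrary values, showing the two mutually exclusive ways of bounding local training (epochs vs. steps); note that gradient accumulation is only supported for PyTorch, per the parameter description above:

    from pfl.hyperparam.base import NNTrainHyperParams

    # One local epoch, batch size 32, learning rate 0.1.
    train_params = NNTrainHyperParams(
        local_batch_size=32,
        local_num_epochs=1,
        local_learning_rate=0.1)

    # Cap local training at 50 steps instead of whole epochs; exactly one of
    # local_num_epochs / local_num_steps should be set. With batch size 32 and
    # grad_accumulation_steps=2, the effective local batch size is 64.
    step_limited_params = NNTrainHyperParams(
        local_batch_size=32,
        local_num_epochs=None,
        local_learning_rate=0.1,
        local_max_grad_norm=1.0,
        local_num_steps=50,
        grad_accumulation_steps=2)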