Source code for colibri.optax_optimizer

"""
colibri.optax_optimizer.py

Module contains functions for optax gradient descent optimisation.

"""

import optax
from flax.training.early_stopping import EarlyStopping

import logging

log = logging.getLogger(__name__)


def optimizer_provider(
    optimizer_settings,
) -> optax.GradientTransformationExtraArgs:
    """
    Define the optimizer.

    Parameters
    ----------
    optimizer_settings : dict
        Dictionary containing the optimizer settings. Recognised keys:

        - ``"optimizer"`` (str, required): name of an optax optimizer
          factory, looked up as ``getattr(optax, name)`` (e.g. ``"adam"``).
        - ``"optimizer_hyperparams"`` (dict, optional): keyword arguments
          forwarded to the optimizer factory. A missing or ``None`` entry
          means the factory is called with no keyword arguments.
        - ``"clipnorm"`` (float or None, optional): when present and not
          ``None``, gradients are clipped by their global norm before the
          optimizer update is applied.

    Returns
    -------
    optax.GradientTransformationExtraArgs
        Optax optimizer; wrapped in ``optax.chain`` after
        ``optax.clip_by_global_norm`` when ``clipnorm`` is set.

    Raises
    ------
    KeyError
        If ``"optimizer"`` is missing from ``optimizer_settings``.
    AttributeError
        If ``optax`` has no attribute with the requested optimizer name.
    """
    optimizer = optimizer_settings["optimizer"]
    log.info(f"Using {optimizer} optimizer.")
    opt = getattr(optax, optimizer)

    # An omitted or null hyperparameter entry means "use the optimizer's
    # defaults" rather than crashing with KeyError or on ``**None``.
    optimizer_hyperparams = optimizer_settings.get("optimizer_hyperparams") or {}
    log.info(f"Optimizer hyperparameters: {optimizer_hyperparams}.")

    # Gradient clipping is optional: an absent key behaves like ``None``.
    clipnorm = optimizer_settings.get("clipnorm")
    if clipnorm is not None:
        log.info(f"Using gradient clipping with norm {clipnorm}.")
        # Clipping must come first in the chain so the optimizer sees the
        # already-clipped gradients.
        return optax.chain(
            optax.clip_by_global_norm(clipnorm),
            opt(**optimizer_hyperparams),
        )

    return opt(**optimizer_hyperparams)
def early_stopper(
    min_delta=1e-5, patience=20, max_epochs=1000, mc_validation_fraction=0.2
):
    """
    Build the early-stopping criterion used during training.

    Parameters
    ----------
    min_delta : float, default = 1e-5
        Minimum change that counts as an improvement.
    patience : int, default = 20
        Number of epochs without improvement tolerated before stopping;
        only used when validation data is available.
    max_epochs : int, default = 1000
        Used as the patience when there is no validation data.
    mc_validation_fraction : float, default = 0.2
        Fraction of data held out for validation. A falsy value (e.g. 0 or
        None) means no validation data, in which case the patience is set
        to ``max_epochs``.

    Returns
    -------
    EarlyStopping
        Flax early-stopping object configured with ``min_delta`` and the
        effective patience.
    """
    has_validation = bool(mc_validation_fraction)

    if not has_validation:
        log.warning(
            "No validation data provided, patience of early stopping set to max_epochs."
        )

    # Without validation data there is nothing to monitor, so patience is
    # widened to the full epoch budget.
    effective_patience = patience if has_validation else max_epochs
    return EarlyStopping(min_delta=min_delta, patience=effective_patience)