Source code for loss_fn

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import Optional

from loss_fn.base_criteria import BaseCriteria
from utils import logger
from utils.registry import Registry

# Registry for loss functions.
LOSS_REGISTRY = Registry(
    registry_name="loss_functions",
    base_class=BaseCriteria,
    lazy_load_dirs=["loss_fn"],
    internal_dirs=["internal", "internal/projects/*"],
)


[docs]def build_loss_fn( opts: argparse.Namespace, category: Optional[str] = "", *args, **kwargs ) -> BaseCriteria: """Helper function to build loss function from command-line arguments. Args: opts: command-line arguments category: Optional task category (e.g., classification). Specifying category may be useful for building composite loss functions. See `loss_fns.composite_loss.CompositeLoss.build_composite_loss_fn` function for an example Returns: Loss function module """ if not category: # If category is not specified, then read it from command-line arguments category = getattr(opts, "loss.category") if category is None: logger.error( "Please specify loss name using --loss.category. For composite loss function, see configuration" "example in `loss_fns.composite_loss.CompositeLoss`. Got None" ) # Get the name of loss function for a given category. # Note that loss functions (e.g., NeuralAugmentation) that are not task-specific does not have this # argument defined. In such case, we set the loss function name the same as category if hasattr(opts, f"loss.{category}.name"): loss_fn_name = getattr(opts, f"loss.{category}.name") else: loss_fn_name = category # We registered the base criterion classes for different categories using a special `name` (i.e., `__base__`) # in order to access the arguments defined inside those classes. However, these classes are not supposed to # be used as a loss function. Therefore, we raise an error for such cases if loss_fn_name == "__base__": logger.error("__base__ can't be used as a loss function name. Please check.") loss_fn = LOSS_REGISTRY[loss_fn_name, category](opts, *args, **kwargs) return loss_fn
[docs]def add_loss_fn_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """This method gets a parser object, and for every loss that is registered in the LOSS_REGISTRY adds its arguments to it.""" parser = BaseCriteria.add_arguments(parser=parser) parser = LOSS_REGISTRY.all_arguments(parser) return parser