# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
import argparse
from loss_fn import LOSS_REGISTRY, BaseCriteria
@LOSS_REGISTRY.register(name="__base__", type="segmentation")
class BaseSegmentationCriteria(BaseCriteria):
    """Base class for defining segmentation loss functions.

    Sub-classes must implement the ``forward`` function.

    Args:
        opts: command line arguments
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        # No segmentation-specific state here; all arguments are forwarded
        # unchanged to BaseCriteria.
        super().__init__(opts, *args, **kwargs)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add segmentation-loss command-line arguments to ``parser``.

        Registers ``--loss.segmentation.name`` (a string, default ``None``) in an
        argument group named after this class. The option is only added when
        this method is invoked on ``BaseSegmentationCriteria`` itself. When it is
        reached through a subclass, ``parser`` is returned unchanged. That covers
        both a subclass that inherits the method and one that calls it via
        ``super()``.

        Args:
            parser: the argument parser to extend.

        Returns:
            The same ``parser`` instance that was passed in.
        """
        # When a subclass inherits this classmethod (or reaches it through
        # super()), ``cls`` is bound to that subclass. Returning early there
        # keeps the option from being registered more than once. With argparse's
        # default conflict handler, a duplicate option string would raise.
        if cls is not BaseSegmentationCriteria:
            return parser

        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--loss.segmentation.name",
            type=str,
            default=None,
            help="Name of the loss function. Defaults to None.",
        )
        return parser