#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from cvnets.models import MODEL_REGISTRY, BaseAnyNNModel
@MODEL_REGISTRY.register(name="__base__", type="audio_classification")
class BaseAudioClassification(BaseAnyNNModel):
    """Base class for audio classification models.

    Registered in ``MODEL_REGISTRY`` under the name ``"__base__"`` for the
    ``"audio_classification"`` model type.

    Args:
        opts: Command-line arguments
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        # Construction is fully delegated to ``BaseAnyNNModel``.
        super().__init__(opts, *args, **kwargs)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the audio-classification command-line options.

        Adds an argument group titled with the class name that holds
        ``--model.audio-classification.name`` and
        ``--model.audio-classification.pretrained``. Both options are
        ``str`` and default to ``None``. The options are only added when
        called on ``BaseAudioClassification`` itself.

        Args:
            parser: The argument parser to extend.

        Returns:
            The same ``parser`` instance that was passed in.
        """
        # A subclass that inherits this method without overriding it must not
        # register the same option strings a second time (argparse raises on
        # conflicting options with its default conflict handler).
        if cls != BaseAudioClassification:
            return parser

        group = parser.add_argument_group(title=cls.__name__)
        option_specs = (
            (
                "--model.audio-classification.name",
                "Name of the audio classification model. Defaults to None.",
            ),
            (
                "--model.audio-classification.pretrained",
                "Path of the pretrained backbone. Defaults to None.",
            ),
        )
        for flag, description in option_specs:
            group.add_argument(flag, type=str, default=None, help=description)
        return parser