#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from functools import partial
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from utils import logger
from utils.import_utils import import_modules_from_folder
# Any callable (class or function) that can be stored in a Registry.
RegistryItem = TypeVar("RegistryItem", bound=Callable)


class Registry:
    """
    A key/object registry class.

    This class is used in CVNets to do Dependency Injection in configs,
    so when you write "resnet" in a config, it knows which module to load.
    You can optionally provide a `base_class` to ensure that all items
    in the registry are subclasses of `base_class`.

    Registry also allows for passing arguments to a registered item:
    For example: "top1" -> "top1(pred=logits)"

    Usage:
    >>> my_registry = Registry("registry_name")
    >>> @my_registry.register("awesome_class_or_func")
    ... def my_awesome_class_or_func():
    ...     pass
    >>> assert "awesome_class_or_func" in my_registry

    It allows for vanilla key/object definition as well as functional argument injection:
    >>> reg = Registry("registry_name")
    >>> reg.register("awesome_dict")(dict)
    <class 'dict'>
    >>> reg["awesome_dict(name=hello, type=fifo)"]()
    {'name': 'hello', 'type': 'fifo'}
    """

    def __init__(
        self,
        registry_name: str,
        base_class: Optional[type] = None,
        separator: Optional[str] = ":",
        lazy_load_dirs: Optional[List[str]] = None,
        internal_dirs: Sequence[str] = (),
    ) -> None:
        """
        Args:
            registry_name: registry name, used for debugging and error messages.
            base_class: If provided, every registered item must be a subclass
                of `base_class`.
            separator: Separator placed between `type` and `name` when an item
                is registered or looked up with a type prefix.
            lazy_load_dirs: If provided, all modules under these directories are
                imported the first time the registry is read from.
            internal_dirs: Forwarded as `extra_roots` to
                `import_modules_from_folder` during lazy loading.
        """
        self.registry_name = registry_name
        self.base_class = base_class
        self.registry: Dict[str, RegistryItem] = {}
        # Set by `all_arguments`. `register` reports an error for items added
        # afterwards, because that `all_arguments` call did not see them.
        self.arguments_accessed = False
        self.separator = separator
        # Lazy loading to get rid of possible circular dependencies.
        self._modules_loaded = False
        self._lazy_load_dirs: List[str] = (
            lazy_load_dirs if lazy_load_dirs is not None else []
        )
        self.internal_dirs = internal_dirs

    def _load_all(self) -> None:
        """
        Lazily import modules from the pre-specified directories.

        The main reason for its existence is to prevent circular imports.
        Modules are imported on the first read instead of at construction time.
        This function should be called before any "pull/get" kind of action
        from the Registry, so that all registered items are loaded. That covers
        pretty much every operation except "register".

        Only the first call does any work; later calls return immediately.
        """
        if self._modules_loaded:
            return
        # Set the flag before importing. A read of this registry triggered
        # while the modules below are being imported then returns early here
        # instead of starting a second import pass.
        self._modules_loaded = True
        for dir_name in self._lazy_load_dirs:
            import_modules_from_folder(dir_name, extra_roots=self.internal_dirs)

    def items(self) -> Iterable[Tuple[str, RegistryItem]]:
        """Return a view of (name, item) pairs, after lazy loading."""
        self._load_all()
        return self.registry.items()

    def keys(self) -> Iterable[str]:
        """Return a view of registered names, after lazy loading."""
        self._load_all()
        return self.registry.keys()

    def __iter__(self) -> Iterable[str]:
        self._load_all()
        return iter(self.registry)

    def __getitem__(self, key: Union[Tuple[str, str], str]) -> RegistryItem:
        """
        Look up a registered item.

        Args:
            key: Either a string `name` / `name(arg1=v1, ...)`, or a 2-tuple
                `(name, type)`. In the tuple form the lookup key becomes
                f"{type}{separator}{name}".

        Returns:
            The registered item. If `key` carries a non-empty parameter list,
            the item is wrapped in `functools.partial` with those parsed
            (string-valued) keyword arguments.
        """
        self._load_all()
        type_ = None
        if isinstance(key, tuple) and len(key) == 2:
            key, type_ = key
        assert isinstance(
            key, str
        ), f"Key should be an instance of string. Got {type(key)}"

        name, params = self.parse_key(key)
        if type_:
            name = f"{type_}{self.separator}{name}"

        if name not in self.registry:
            supported = "".join(
                f"\n\t {i}: {supp_val}" for i, supp_val in enumerate(self.registry)
            )
            logger.error(
                f"\n{name} not yet supported in {self.registry_name} registry."
                f"\nSupported values are:{supported}\n"
            )

        reg_item = self.registry[name]
        if params:
            reg_item = partial(reg_item, **params)
        return reg_item

    def __contains__(self, key: str) -> bool:
        """Return True if the base name of `key` (parameters ignored) is registered."""
        self._load_all()
        name, _ = self.parse_key(key)
        return name in self.registry

    def register(self, name: str, type: str = "") -> Callable:
        """
        Return a decorator that stores the decorated item under `name`.

        Args:
            name: Registry key for the item.
            type: Optional prefix. When non-empty, the key becomes
                f"{type}{separator}{name}".

        Returns:
            A decorator that registers its argument and returns it unchanged.

        The returned decorator raises ValueError if the key is already
        registered. It also raises ValueError if `base_class` was given and the
        item is not a subclass of it, including items that are not classes.
        """
        if type:
            name = f"{type}{self.separator}{name}"

        if self.arguments_accessed:
            # TODO: do we really want an error here?
            logger.error(
                f"Found item `{name}` being registered after all_item_arguments"
                f" was called for `{self.registry_name}` registry."
            )

        def register_with_name(item: RegistryItem) -> RegistryItem:
            if name in self.registry:
                raise ValueError(
                    "Cannot register duplicate {} ({})".format(self.registry_name, name)
                )
            if self.base_class is not None:
                try:
                    is_subclass = issubclass(item, self.base_class)
                except TypeError:
                    # `item` is not a class (e.g. a plain function). It cannot
                    # extend `base_class`, so treat it like any other mismatch.
                    is_subclass = False
                if not is_subclass:
                    raise ValueError(
                        "{} class ({}: {}) must extend {}".format(
                            self.registry_name,
                            name,
                            getattr(item, "__name__", repr(item)),
                            self.base_class,
                        )
                    )
            self.registry[name] = item
            return item

        return register_with_name

    def all_arguments(self, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Iterate through all items and let each one add its arguments to `parser`.

        Every registered item must expose `add_arguments(parser)` returning the
        (possibly new) parser.
        Note: make sure that all items are already registered before calling this method.
        """
        self._load_all()
        self.arguments_accessed = True
        for item in self.registry.values():
            parser = item.add_arguments(parser)
        return parser

    def parse_key(self, key: str) -> Tuple[str, Dict[str, str]]:
        """
        Parses `key`, which can contain arguments in the form of:
            <key_name>(arg1=value1, arg2=value2, ...)

        An empty argument list ("name()") and empty segments (e.g. from a
        trailing comma) are ignored. Only the first "=" of each argument
        separates its name from its value, so values may contain "=".
        Names and values are whitespace-stripped and kept as strings.

        Returns:
            Tuple: (base_name: str, parameters: dict)

        Raises:
            ValueError: if a non-empty argument has no "=" (an error is logged
                first).
        """
        name = key.split("(")[0]
        params: Dict[str, str] = {}
        if "(" not in key:
            return name, params

        params_str = key.split("(")[1].split(")")[0]
        for arg in params_str.split(","):
            if not arg.strip():
                # "name()" or a trailing/double comma: nothing to parse here.
                continue
            arg_name, sep, value = arg.partition("=")
            if not sep:
                logger.error(
                    "Could not correctly parse key parameters `{}` for registry {}."
                    " Please make sure to key parameters have the format:"
                    " <key_name>(arg1=value1, arg2=value2, ...)".format(
                        key, self.registry_name
                    )
                )
                raise ValueError(
                    f"Argument `{arg.strip()}` of key `{key}` is missing '='."
                )
            # Duplicate argument names: the last occurrence wins.
            params[arg_name.strip()] = value.strip()
        return name, params