Source code for options.parse_args

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import json
import typing
from typing import Any


[docs]def parse_validation_metric_names(opts): """ This function contains common command-line parsing logic for validation metrics """ metric_names = getattr(opts, "stats.val", ["loss"]) if isinstance(metric_names, str): metric_names = [metric_names] assert isinstance( metric_names, list ), "Type of metric names should be list. Got: {}".format(type(metric_names)) if "loss" not in metric_names: metric_names.append("loss") ckpt_metric_str = getattr(opts, "stats.checkpoint_metric", "loss") ckpt_metric_arr = ckpt_metric_str.split(".") ckpt_metric = ckpt_metric_arr[0] if len(ckpt_metric_arr) == 1: ckpt_submetric_name = None else: ckpt_submetric_name = ckpt_metric_arr[-1] ckpt_metric = ckpt_metric ckpt_submetric = ckpt_submetric_name if ckpt_metric is None: # if checkpoint metric is not specified, then use loss ckpt_metric = "loss" assert ( ckpt_metric in metric_names ), "Checkpoint metric should be part of metric names. Metric names: {}, Checkpoint metric: {}".format( metric_names, ckpt_metric ) ckpt_metric = ckpt_metric.lower() return metric_names, ckpt_metric, ckpt_submetric
[docs]class JsonValidator:
[docs] def __init__(self, expected_type: type): """ JsonValidator(T) is function (s)->x that parses json string s into python value x, where x is of type T. Example Usage: >>> from typing import Union, List >>> import argparse >>> parser = argparse.ArgumentParser() >>> parser.add_argument("--x", type=JsonValidator(Union[int, List[float]])) >>> assert parser.parse_args(["--x=123"]).x == 123 >>> assert parser.parse_args(["--x=[1, 2]"]).x == [1., 2.] """ self.expected_type = expected_type
@classmethod def _validate_and_cast(cls, json_value: Any, expected_type: Any): type_cls = ( typing.get_origin(expected_type) or expected_type ) # typing.get_origin() returns None for non-generic types like `Any` and `int` type_args = typing.get_args(expected_type) if type_cls is typing.Any: return json_value if type_cls is float and isinstance(json_value, (int, float)): return float(json_value) elif type_cls in (int, str, bool) and isinstance(json_value, type_cls): return json_value elif type_cls is None and json_value is None: return None elif type_cls is typing.Union: for arg in type_args: try: return cls._validate_and_cast(json_value, arg) except TypeError: continue elif type_cls is dict and isinstance(json_value, dict): if not type_args: type_args = (Any, Any) type_key, type_value = type_args return { cls._validate_and_cast(key, type_key): cls._validate_and_cast( value, type_value ) for key, value in json_value.items() } elif type_cls is list and isinstance(json_value, list): if not type_args: type_args = [Any] return [cls._validate_and_cast(x, type_args[0]) for x in json_value] elif ( type_cls is tuple and isinstance(json_value, list) and (type_args is None or len(type_args) == len(json_value)) ): if type_args is None: type_args = [Any] * len(json_value) return tuple( type_cls( cls._validate_and_cast(item, type_arg) for item, type_arg in zip(json_value, type_args) ) ) raise TypeError( f"Cannot cast {json_value} with type {type(json_value)} to {expected_type}" ) def __call__(self, str_value: str) -> Any: try: value = json.loads(str_value) except json.JSONDecodeError: raise TypeError(f"Cannot parse json value '{str_value}' for {self}") return self._validate_and_cast(value, self.expected_type) def __repr__(self): return f"JSON[{self.expected_type}]"