# Source code for data.datasets.utils.common

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import collections
import os
import random
from fractions import Fraction
from typing import Any, List, Optional, Tuple, Union

# Lower-case file suffixes (leading dot included) that are treated as image
# files by ``file_has_valid_image_extension``.
IMG_EXTENSIONS = (
    ".jpg", ".jpeg",
    ".png", ".ppm", ".bmp", ".pgm",
    ".tif", ".tiff",
    ".webp",
)


def file_has_valid_image_extension(filename: str) -> bool:
    """Checks whether a file looks like an image, judging by its extension.

    Args:
        filename: Path to a file.

    Returns:
        True if ``filename`` ends with one of the suffixes listed in
        ``IMG_EXTENSIONS`` (matching is delegated to
        ``file_has_allowed_extension``), else False.
    """
    return file_has_allowed_extension(filename=filename, extensions=IMG_EXTENSIONS)
def file_has_allowed_extension(
    filename: str, extensions: Union[str, Tuple[str, ...]]
) -> bool:
    """Checks if a file has an allowed extension.

    The comparison is case-insensitive on both sides: ``"photo.JPG"`` matches
    ``".jpg"`` and ``"photo.jpg"`` matches ``".JPG"``.

    Args:
        filename: Path to a file.
        extensions: A string or a tuple of strings specifying the file
            extensions. Any other iterable of strings (e.g. a list) is
            accepted as well.

    Returns:
        True if the filename ends with one of given extensions, else False.
    """
    if isinstance(extensions, str):
        candidates = (extensions,)
    else:
        # ``str.endswith`` only accepts a str or a tuple, so normalise lists
        # and other iterables into a tuple.
        candidates = tuple(extensions)
    # The filename is lower-cased, so the extensions must be lower-cased too;
    # otherwise an upper-case extension such as ".JPG" could never match.
    lowered = tuple(ext.lower() for ext in candidates)
    return filename.lower().endswith(lowered)
def get_image_paths(directory: str) -> List[str]:
    """Returns a list of paths to all image files under ``directory``.

    The directory is searched recursively via ``os.walk``. A file counts as
    an image when ``file_has_valid_image_extension`` accepts its joined path.
    The output order is deterministic: directories are visited in sorted path
    order and, within each directory, files are taken in sorted name order.
    """
    walk_entries = sorted(os.walk(directory, topdown=False))
    return [
        candidate
        for dir_path, _subdirs, file_names in walk_entries
        for candidate in (os.path.join(dir_path, name) for name in sorted(file_names))
        if file_has_valid_image_extension(candidate)
    ]
def select_random_subset(
    random_seed: int,
    num_total_samples: int,
    num_samples_to_select: Optional[int] = None,
    percentage_of_samples_to_select: Optional[float] = None,
) -> List[int]:
    """Randomly selects a subset of samples.

    Only one of `num_samples_to_select` and `percentage_of_samples_to_select`
    should be provided. Selects all the samples (in shuffled order) if neither
    of them are provided.

    Args:
        random_seed: An integer seed to use for random selection.
        num_total_samples: Total number of samples in the set that is being
            subsampled.
        num_samples_to_select: An optional integer indicating the number of
            samples to select. Values larger than `num_total_samples` are
            clipped to `num_total_samples`.
        percentage_of_samples_to_select: An optional float in the range
            (0, 100] indicating the percentage of samples to select. The
            resulting count is rounded down.

    Returns:
        A list of (integer) indices of the selected samples. The same seed and
        arguments always produce the same list.

    Raises:
        ValueError: if both `num_samples_to_select` and
            `percentage_of_samples_to_select` are provided, if
            `num_samples_to_select` is less than 1, or if
            `percentage_of_samples_to_select` is outside (0, 100].
    """
    if (
        num_samples_to_select is not None
        and percentage_of_samples_to_select is not None
    ):
        raise ValueError(
            "Only one of `num_samples_to_select` and `percentage_of_samples_to_select` should be provided."
        )
    if num_samples_to_select is not None and num_samples_to_select < 1:
        raise ValueError("`num_samples_to_select` should be greater than 0.")
    if percentage_of_samples_to_select is not None:
        if not 0 < percentage_of_samples_to_select <= 100:
            raise ValueError(
                "`percentage_of_samples_to_select` should be in the range (0, 100]."
            )

    sample_indices = list(range(num_total_samples))
    rng = random.Random(random_seed)
    rng.shuffle(sample_indices)

    if num_samples_to_select is None and percentage_of_samples_to_select is None:
        return sample_indices

    if num_samples_to_select is None:
        # Use exact rational arithmetic on the percentage's decimal form.
        # Plain float math can land just below an integer, e.g.
        # ``0.57 * 10000 / 100`` evaluates to slightly less than 57, which
        # ``int()`` would truncate to 56.
        num_samples_to_select = int(
            Fraction(str(percentage_of_samples_to_select)) * num_total_samples / 100
        )
    num_samples_to_select = min(num_samples_to_select, num_total_samples)
    return sample_indices[:num_samples_to_select]
def select_samples_by_category(
    sample_category_labels: List[Any],
    random_seed: int,
    num_samples_per_category: Optional[int] = None,
    percentage_of_samples_per_category: Optional[float] = None,
) -> List[int]:
    """Randomly selects a specified number/percentage of samples from each category.

    Only one of `num_samples_per_category` and
    `percentage_of_samples_per_category` should be provided. Selects all the
    samples (as ``0..len(sample_category_labels) - 1`` in order, unshuffled)
    if neither of them are provided.

    Args:
        sample_category_labels: A list of category labels, one per sample.
            Labels are used as dictionary keys, so they must be hashable.
        random_seed: An integer seed to use for random selection.
        num_samples_per_category: An optional integer indicating the number of
            samples to select from each category. Categories with fewer
            samples contribute all of their samples.
        percentage_of_samples_per_category: An optional float in the range
            (0, 100] indicating the percentage of samples to select from each
            category. The per-category count is rounded down, so small
            categories may contribute no samples.

    Returns:
        A list of (integer) indices of the selected samples, grouped by
        category in order of each category's first appearance.

    Raises:
        ValueError: if both `num_samples_per_category` and
            `percentage_of_samples_per_category` are provided, if
            `num_samples_per_category` is less than 1, or if
            `percentage_of_samples_per_category` is outside (0, 100].
    """
    if (
        num_samples_per_category is not None
        and percentage_of_samples_per_category is not None
    ):
        raise ValueError(
            "Only one of `num_samples_per_category` and `percentage_of_samples_per_category` should be provided."
        )
    if num_samples_per_category is None and percentage_of_samples_per_category is None:
        return list(range(len(sample_category_labels)))
    if num_samples_per_category is not None and num_samples_per_category < 1:
        raise ValueError("`num_samples_per_category` should be greater than 0.")
    if percentage_of_samples_per_category is not None:
        if not 0 < percentage_of_samples_per_category <= 100:
            raise ValueError(
                "`percentage_of_samples_per_category` should be in the range (0, 100]."
            )

    # Group sample indices by label; dict insertion order keeps categories in
    # the order in which they first appear.
    category_specific_samples = collections.defaultdict(list)
    for ind, label in enumerate(sample_category_labels):
        category_specific_samples[label].append(ind)

    rng = random.Random(random_seed)
    selected_sample_indices = []
    for sample_indices in category_specific_samples.values():
        # Shuffling every category in a fixed order keeps results
        # reproducible for a given seed.
        rng.shuffle(sample_indices)
        if num_samples_per_category is not None:
            num_samples = num_samples_per_category
        else:
            # Exact rational arithmetic avoids float results that land just
            # below an integer (e.g. 0.57% of 10000 truncating to 56).
            num_samples = int(
                Fraction(str(percentage_of_samples_per_category))
                * len(sample_indices)
                / 100
            )
        num_samples = min(num_samples, len(sample_indices))
        selected_sample_indices += sample_indices[:num_samples]
    return selected_sample_indices