Source code for data.loader.dataloader

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from typing import List, Optional, Union

from torch.utils.data import DataLoader

from data.datasets.dataset_base import BaseDataset
from data.sampler import Sampler


[docs]class CVNetsDataLoader(DataLoader): """This class extends PyTorch's Dataloader"""
[docs] def __init__( self, dataset: BaseDataset, batch_size: int, batch_sampler: Union[Sampler], num_workers: Optional[int] = 1, pin_memory: Optional[bool] = False, persistent_workers: Optional[bool] = False, collate_fn: Optional = None, prefetch_factor: Optional[int] = 2, *args, **kwargs ): super(CVNetsDataLoader, self).__init__( dataset=dataset, batch_size=batch_size, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, collate_fn=collate_fn, prefetch_factor=prefetch_factor, *args, **kwargs )
[docs] def update_indices(self, new_indices: List, *args, **kwargs): """Update indices in the dataset class""" if hasattr(self.batch_sampler, "img_indices") and hasattr( self.batch_sampler, "update_indices" ): self.batch_sampler.update_indices(new_indices)
[docs] def samples_in_dataset(self): """Number of samples in the dataset""" return len(self)
[docs] def get_sample_indices(self) -> List: """Sample IDs""" return self.batch_sampler.img_indices