Source code for utils.download_utils_base

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
import copy
import os
import time
from typing import Optional

import requests

from common import TMP_CACHE_LOC, if_test_env
from utils import logger
from utils.ddp_utils import dist_barrier, is_start_rank_node


[docs]def get_basic_local_path( opts: argparse.Namespace, path: str, cache_loc: str = TMP_CACHE_LOC, force_delete: Optional[bool] = None, use_start_rank: bool = True, sync_ranks: bool = True, *args, **kwargs ) -> str: """ If File name is a URL, download to TMP_CACHE_LOC and then return the local path. Otherwise, don't do anything """ if force_delete is None: # An alternative approach is to move this logic to the argument's default value in the function signature: # def get_basic_local_path(..., force_delete = if_test_env(False, otherwise=True), ...): # But that won't always worker because pytest may set PYTEST_CURRENT_TEST env set loading this module. force_delete = if_test_env(False, otherwise=True) if ( path.find("s3://") > -1 or path.find("http://") > -1 or path.find("https://") > -1 ): url_path = copy.deepcopy(path) ckpt_name = path.split(os.sep)[-1] local_path = "{}/{}".format(cache_loc, ckpt_name) local_path = str(local_path).strip() if os.path.isfile(local_path) and force_delete: # If file exists, remove it and then download again # This is important because if we are downloading from bolt tasks, then checkpoint names are the same if use_start_rank: # remove files from start rank only if is_start_rank_node(opts): os.remove(local_path) else: while not os.path.isfile(local_path): time.sleep(if_test_env(0, otherwise=1)) continue else: # All ranks in DDP can remove the files os.remove(local_path) if not os.path.isfile(local_path): if not use_start_rank or is_start_rank_node(opts): _download_file(url_path, local_path) else: while os.path.isfile(local_path): # download file on start rank and let other ranks keep waiting till file is downloaded # in DDP, download file in all ranks time.sleep(if_test_env(0, otherwise=1)) continue if getattr(opts, "ddp.use_distributed", False) and sync_ranks: # synchronize between processes dist_barrier() return local_path return os.path.expanduser(path)
def _download_file(url_path: str, dest_loc: str) -> None: """ Helper function to download a file with proxy (used when file fails) """ response = requests.get(url_path, stream=True) if response.status_code == 403: # try with the HTTP/HTTPS proxy from ENV proxies = { "https": os.environ.get("HTTPS_PROXY", None), "http": os.environ.get("HTTP_PROXY", None), } response = requests.get(url_path, stream=True, proxies=proxies) if response.status_code == 200: with open(dest_loc, "wb") as f: f.write(response.raw.read()) else: logger.error("Unable to download file {}".format(url_path))