| import os |
| import time |
| import shutil |
| import logging |
| import subprocess |
| import os.path as op |
| from typing import List |
| from collections import OrderedDict |
|
|
| import torch.distributed as distributed |
|
|
| logger = logging.getLogger(__name__) |
|
|
# Fallback location of the azcopy binary, relative to the working directory,
# used when BlobStorage is constructed without an explicit azcopy_path.
DEFAULT_AZCOPY_PATH = 'azcopy/azcopy'
|
|
|
|
def disk_usage(path: str) -> float:
    """Return the fraction (0.0-1.0) of the filesystem at *path* in use."""
    usage = shutil.disk_usage(path)
    return usage.used / usage.total
|
|
|
|
def is_download_successful(stdout: str) -> bool:
    """Return True if azcopy's captured stdout reports zero failed transfers.

    Args:
        stdout (str): stdout text of a completed azcopy run.

    Returns:
        bool: True when the summary line "Number of Transfers Failed: 0" is
        present; otherwise the full azcopy output is logged at INFO level
        and False is returned.
    """
    if any(line == "Number of Transfers Failed: 0" for line in stdout.split('\n')):
        return True
    # Lazy %-style logging args: stdout is only interpolated when INFO is
    # enabled, instead of eagerly formatting with "%" before the call.
    logger.info("Azcopy message:\n %s", stdout)
    return False
|
|
|
|
def ensure_directory(path):
    """Make sure *path* exists as a directory, creating it when needed.

    Empty strings, '.', and None are treated as "nothing to do".

    Args:
        path (str): path of a given directory.
    """
    if path in ('', '.'):
        return
    if path is not None and len(path) > 0:
        # A regular file at this path can never become a directory.
        assert not op.isfile(path), '{} is a file'.format(path)
        if not (op.exists(path) or op.islink(path)):
            os.makedirs(path, exist_ok=True)
        assert op.isdir(op.abspath(path)), path
|
|
|
|
class LRU(OrderedDict):
    """A size-bounded LRU cache mapping keys to open file-like objects.

    Lookups mark entries as most recently used; when the cache grows past
    ``maxsize``, the least recently used entry is closed and evicted.
    Values may be None or any object exposing a ``close()`` method.
    """

    def __init__(self, maxsize=3):
        """Create an empty cache holding at most *maxsize* entries."""
        # Fix: explicitly initialize the OrderedDict base instead of relying
        # on implementation-defined behavior of the uninitialized base class.
        super().__init__()
        self.maxsize = maxsize

    def __getitem__(self, key):
        """Return the cached value and mark *key* as most recently used."""
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        """Insert or refresh *key*, closing replaced and evicted handles."""
        if key in self:
            # Close the handle being replaced so file descriptors don't leak.
            if self[key] is not None:
                self[key].close()
            self.move_to_end(key)

        logger.debug('=> Cache {}'.format(key))
        super().__setitem__(key, value)

        # Over capacity: close and drop the least recently used entry.
        if len(self) > self.maxsize:
            oldest = next(iter(self))
            if self[oldest] is not None:
                self[oldest].close()
            logger.debug('=> Purged {}'.format(oldest))
            del self[oldest]
|
|
|
|
class BlobStorage(OrderedDict):
    """ Pseudo Blob Storage manager

    The registered blobs are maintained in a LRU cache.
    Limit size, evicting the least recently looked-up key when full.
    https://docs.python.org/3/library/collections.html#collections.OrderedDict

    Input argument:
        sas_token (str): path to SAS token.
    """
    def __init__(self,
                 is_train: bool,
                 sas_token_path: str = None,
                 azcopy_path: str = None,
                 *args, **kwds):
        super().__init__(*args, **kwds)
        # Training keeps a smaller registered-blob working set than eval.
        self.maxsize = 2 if is_train else 10
        self.is_train = is_train

        if sas_token_path:
            # A SAS URI has the form "<base_url>?<query_string>"; keep both
            # halves so remote blob URLs can be rebuilt from local paths.
            self.sas_token = BlobStorage.read_sas_token(sas_token_path)
            self.base_url = self.sas_token[:self.sas_token.index("?")]
            self.query_string = self.sas_token[self.sas_token.index("?"):]
            self.container = BlobStorage.extract_container(self.sas_token)
        else:
            # No token: blob access disabled; register_local_tsv_paths and
            # open() fall through to plain local-path behavior.
            self.sas_token = None
            self.base_url = None
            self.query_string = None
            self.container = None

        # NOTE(review): these f-strings are concatenated with no separator,
        # so the three "=>" segments render as one run-on debug line.
        logger.debug(
            f"=> [BlobStorage] Base url: {self.base_url}"
            f"=> [BlobStorage] Query string: {self.query_string}"
            f"=> [BlobStorage] Container name: {self.container}"
        )

        self.azcopy_path = azcopy_path if azcopy_path else DEFAULT_AZCOPY_PATH
        # Open file handles of downloaded blobs; LRU closes evicted handles,
        # so at most 3 stay open at once.
        self._cached_files = LRU(3)

    def __getitem__(self, key):
        # LRU read: a successful lookup marks the key most recently used.
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        # LRU write: refresh existing keys, then evict the least recently
        # used entry once the dict exceeds maxsize. Unlike the LRU class,
        # evicted values are NOT closed here.
        if key in self:
            self.move_to_end(key)
        super().__setitem__(key, value)

        if len(self) > self.maxsize:
            oldest = next(iter(self))
            del self[oldest]

    @staticmethod
    def read_sas_token(path: str) -> str:
        """Return the SAS URI stored on the first line of the file *path*."""
        with open(path, 'r') as f:
            token = f.readline().strip()
        return token

    @staticmethod
    def extract_container(token: str) -> str:
        """
        Input argument:
            token (str): the full URI of Shared Access Signature (SAS) in the following format.
                https://[storage_account].blob.core.windows.net/[container_name][SAS_token]
        """
        # Drop the query string, then take the last path component: the
        # container name.
        return os.path.basename(token.split('?')[0])

    def _convert_to_blob_url(self, local_path: str) -> str:
        # Rebuild the remote URL from the part of the local path after the
        # literal "azcopy" segment. Assumes "azcopy" occurs in local_path.
        return self.base_url + local_path.split("azcopy")[1] + self.query_string

    def _convert_to_blob_folder_url(self, local_path: str) -> str:
        # Like _convert_to_blob_url, but with a trailing wildcard to match
        # every blob under the folder.
        return self.base_url + local_path.split("azcopy")[1] + "/*" + self.query_string

    def fetch_blob(self, local_path: str) -> None:
        """Download the blob behind *local_path* with azcopy, retrying until
        azcopy reports zero failed transfers. An already-present file is
        opened and cached instead of re-downloaded.
        """
        if op.exists(local_path):
            logger.info('=> Try to open {}'.format(local_path))
            fp = open(local_path, 'r')
            # Cache the handle; the LRU closes it when evicted.
            self._cached_files[local_path] = fp
            logger.debug("=> %s downloaded. Skip." % local_path)
            return
        blob_url = self._convert_to_blob_url(local_path)
        # Download to a rank-suffixed temp name so concurrent distributed
        # workers do not write to the same file.
        rank = '0' if 'RANK' not in os.environ else os.environ['RANK']
        cmd = [self.azcopy_path, "copy", blob_url, local_path + rank]
        curr_usage = disk_usage('/')
        logger.info(
            "=> Downloading %s with azcopy ... (disk usage: %.2f%%)"
            % (local_path, curr_usage * 100)
        )
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        # NOTE(review): retries forever if the download can never succeed
        # (bad token, missing blob) — there is no retry limit or backoff.
        while not is_download_successful(proc.stdout.decode()):
            logger.info("=> Azcopy failed to download {}. Retrying ...".format(blob_url))
            proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        if not op.exists(local_path):
            # First worker to finish publishes its copy under the final name.
            os.rename(local_path + rank, local_path)
        else:
            # Another worker already published; drop the duplicate download.
            os.remove(local_path + rank)
        logger.info(
            "=> Downloaded %s with azcopy ... (disk usage: %.2f%% => %.2f%%)" %
            (local_path, curr_usage * 100, disk_usage('/') * 100)
        )

    def fetch_blob_folder(self, local_path: str, azcopy_args: list=[]) -> None:
        """Download every blob under the folder *local_path* with azcopy,
        retrying until azcopy reports zero failed transfers.

        Args:
            local_path (str): local folder path containing "azcopy".
            azcopy_args (list): extra azcopy CLI flags, e.g.
                ['--include-pattern', '*.lineidx'].
                NOTE(review): mutable default argument — harmless only
                because it is never mutated here.

        NOTE(review): unlike fetch_blob, no per-rank temp name is used, so
        concurrent workers write directly into *local_path*.
        """
        blob_url = self._convert_to_blob_folder_url(local_path)
        cmd = [self.azcopy_path, "copy", blob_url, local_path] + azcopy_args
        curr_usage = disk_usage('/')
        logger.info(
            "=> Downloading %s with azcopy args %s ... (disk usage: %.2f%%)"
            % (local_path, ' '.join(azcopy_args), curr_usage * 100)
        )
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        while not is_download_successful(proc.stdout.decode()):
            logger.info("=> Azcopy failed to download {} with args {}. Retrying ...".format(blob_url, ' '.join(azcopy_args)))
            proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        logger.info(
            "=> Downloaded %s with azcopy args %s ... (disk usage: %.2f%% => %.2f%%)" %
            (local_path, ' '.join(azcopy_args), curr_usage * 100, disk_usage('/') * 100)
        )

    def register_local_tsv_paths(self, local_paths: List[str]) -> List[str]:
        """Map container-rooted TSV paths to local "azcopy"-rooted paths and
        prefetch their companion index files.

        For evaluation (is_train=False), each TSV and its missing .lineidx /
        .linelist are downloaded eagerly per file. For training, only the
        parent folders are collected and the index files are batch-fetched
        per folder; the TSVs themselves are not fetched here (presumably
        downloaded lazily via open() — confirm against callers).

        Returns:
            List[str]: rewritten paths, or *local_paths* unchanged when no
            SAS token is configured.
        """
        if self.sas_token:
            tsv_paths_new = []
            lineidx_paths = set()
            linelist_paths = set()
            for path in local_paths:
                # Re-root the path from the container name to "azcopy".
                tsv_path_az = path.replace(self.container, 'azcopy')
                tsv_paths_new.append(tsv_path_az)
                logger.debug("=> Registering {}".format(tsv_path_az))

                if not self.is_train:
                    logger.info('=> Downloading {}...'.format(tsv_path_az))
                    self.fetch_blob(tsv_path_az)
                    logger.info('=> Downloaded {}'.format(tsv_path_az))

                lineidx = op.splitext(path)[0] + '.lineidx'
                lineidx_ = lineidx.replace(self.container, 'azcopy')
                if self.is_train:
                    # Defer: record the folder once for a batched fetch below.
                    if not op.isfile(lineidx_) and op.dirname(lineidx_) not in lineidx_paths:
                        lineidx_paths.add(op.dirname(lineidx_))
                else:
                    if not op.isfile(lineidx_):
                        ensure_directory(op.dirname(lineidx_))
                        self.fetch_blob(lineidx_)

                linelist = op.splitext(path)[0] + '.linelist'
                linelist_ = linelist.replace(self.container, 'azcopy')

                # .linelist is optional: only handled when a copy exists at
                # the original container-rooted path (presumably a local
                # mount of the container — verify against deployment).
                if self.is_train:
                    if op.isfile(linelist) and not op.isfile(linelist_) and op.dirname(linelist_) not in linelist_paths:
                        linelist_paths.add(op.dirname(linelist_))
                else:
                    if op.isfile(linelist) and not op.isfile(linelist_):
                        ensure_directory(op.dirname(linelist_))
                        self.fetch_blob(linelist_)

            if self.is_train:
                # Batched per-folder downloads of the collected index files.
                for path in lineidx_paths:
                    self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.lineidx'])

                for path in linelist_paths:
                    self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.linelist'])

            return tsv_paths_new
        else:
            return local_paths

    def open(self, local_path: str):
        """Open *local_path* for reading, first waiting for it to appear when
        it is a blob-backed path another worker may still be downloading.

        NOTE(review): the wait loop has no timeout — if the file never
        materializes this blocks forever. The returned handle is owned by
        the caller, who is responsible for closing it.
        """
        if self.sas_token and 'azcopy' in local_path:
            while not op.exists(local_path):
                time.sleep(1)
        fid = open(local_path, 'r')
        return fid
|