| | |
| | |
| | |
| | |
| | |
| |
|
| | from ..path import mkdir_or_exist |
| | from ..version_utils import digit_version |
| | from .parrots_wrapper import TORCH_VERSION |
| |
|
| | if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version( |
| | '1.7.0'): |
| | |
| | import os |
| | import sys |
| | import warnings |
| | import zipfile |
| | from urllib.parse import urlparse |
| |
|
| | import torch |
| | from torch.hub import HASH_REGEX, _get_torch_home, download_url_to_file |
| |
|
| | |
| | |
| | |
| | |
def _is_legacy_zip_format(filename):
    """Return whether ``filename`` is a zip archive holding one regular file.

    This is the layout handled by ``_legacy_zip_load``: a manually zipped
    checkpoint whose archive contains exactly one entry, and that entry is
    not a directory.

    Args:
        filename (str): Path of the file to inspect.

    Returns:
        bool: True only if ``filename`` is a zip archive with exactly one
        member and that member is not a directory; False otherwise,
        including for files that are not zip archives at all.
    """
    if not zipfile.is_zipfile(filename):
        return False
    # Close the archive deterministically instead of leaking the file
    # handle until the ZipFile object happens to be garbage collected.
    with zipfile.ZipFile(filename) as archive:
        members = archive.infolist()
    return len(members) == 1 and not members[0].is_dir()
| |
|
def _legacy_zip_load(filename, model_dir, map_location):
    """Extract the single file of a legacy zip archive and ``torch.load`` it.

    Always emits a ``DeprecationWarning`` asking the user to re-save the
    checkpoint with ``torch.save()`` in the zipfile format.

    Args:
        filename (str): Path of the zip archive to unpack.
        model_dir (str): Directory into which the archive member is
            extracted.
        map_location (optional): Forwarded unchanged to ``torch.load``.

    Returns:
        The object deserialized by ``torch.load`` from the extracted file.

    Raises:
        RuntimeError: If the archive does not contain exactly one member.
    """
    warnings.warn(
        'Falling back to the old format < 1.6. This support will'
        ' be deprecated in favor of default zipfile format '
        'introduced in 1.6. Please redo torch.save() to save it '
        'in the new zipfile format.', DeprecationWarning)
    with zipfile.ZipFile(filename) as archive:
        members = archive.infolist()
        if len(members) != 1:
            raise RuntimeError(
                'Only one file(not dir) is allowed in the zipfile')
        # ``extract`` sanitizes the member name (drops absolute prefixes and
        # ``..`` components) and returns the path it actually wrote to.
        # Rebuilding the path from the raw member name could point outside
        # ``model_dir`` and load a different (or missing) file.
        extracted_file = archive.extract(members[0], model_dir)
    return torch.load(extracted_file, map_location=map_location)
| |
|
def load_url(url,
             model_dir=None,
             map_location=None,
             progress=True,
             check_hash=False,
             file_name=None):
    r"""Loads the Torch serialized object at the given URL.

    If the downloaded file is a legacy zip archive holding a single file,
    that file is extracted and loaded. If the object is already present in
    ``model_dir``, it is deserialized and returned without downloading it
    again.

    The default value of ``model_dir`` is ``<torch_home>/checkpoints``, where
    ``torch_home`` is the directory returned by
    ``torch.hub._get_torch_home()``.

    Args:
        url (str): URL of the object to download.
        model_dir (str, optional): Directory in which to save the object.
            Created if it does not exist.
        map_location (optional): A function or a dict specifying how to
            remap storage locations (see ``torch.load``).
        progress (bool, optional): Whether or not to display a progress bar
            to stderr. Defaults to True.
        check_hash (bool, optional): If True, the filename should follow the
            naming convention ``filename-<sha256>.ext``, where ``<sha256>`` is
            the first eight or more digits of the SHA256 hash of the
            contents of the file. The hash is used to verify the contents of
            the downloaded file. The hash is parsed from ``file_name`` when
            it is given, otherwise from the filename part of ``url``. Only
            applies when the file is actually downloaded. Defaults to False.
        file_name (str, optional): Name for the downloaded file. The filename
            from ``url`` will be used if not set. Defaults to None.

    Returns:
        The object deserialized from the cached file.

    Raises:
        RuntimeError: Re-raised from ``torch.load`` when deserialization
            fails. On torch < 1.5.0, a hint about the zipfile checkpoint
            format is warned first.

    Example:
        >>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106'
        ...        'cde.pth')
        >>> state_dict = load_url(url)
    """
    # The legacy env var is no longer honored; only warn about it.
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn(
            'TORCH_MODEL_ZOO is deprecated, please use env '
            'TORCH_HOME instead', DeprecationWarning)

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')

    mkdir_or_exist(model_dir)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            # No match means no hash prefix, so the download is not verified.
            r = HASH_REGEX.search(filename)
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if _is_legacy_zip_format(cached_file):
        return _legacy_zip_load(cached_file, model_dir, map_location)

    try:
        return torch.load(cached_file, map_location=map_location)
    except RuntimeError:
        # Older torch cannot read checkpoints written in the zipfile format
        # by newer torch; point the user at the likely cause.
        if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
            warnings.warn(
                f'If the error is the same as "{cached_file} is a zip '
                'archive (did you mean to use torch.jit.load()?)", you can'
                ' upgrade your torch to 1.5.0 or higher (current torch '
                f'version is {TORCH_VERSION}). The error was raised '
                'because the checkpoint was saved in torch>=1.6.0 but '
                'loaded in torch<1.5.')
        raise
| | else: |
| | from torch.utils.model_zoo import load_url |
| |
|