| | |
| | import glob |
| | import os |
| | import os.path as osp |
| | import urllib |
| | import warnings |
| | from typing import Union |
| |
|
| | import torch |
| | from mmengine.config import Config, ConfigDict |
| | from mmengine.logging import print_log |
| | from mmengine.utils import scandir |
| |
|
# Lower-case file extensions (with leading dot) recognised as images.
# Callers lower-case a path's extension before testing membership.
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
    '.tiff',
    '.webp',
)
| |
|
| |
|
def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    If ``latest.{suffix}`` exists in ``path`` it is returned directly.
    Otherwise every ``*.{suffix}`` file in ``path`` is inspected, and the one
    whose name ends with the largest integer after the last underscore
    (e.g. ``epoch_12.pth`` or ``iter_5000.pth``) is returned. Files whose
    name does not end in such an integer (e.g. ``best_coco_bbox_mAP.pth``)
    are ignored rather than aborting the search with ``ValueError``.

    Args:
        path (str): The path to find checkpoints.
        suffix (str): File extension. Defaults to 'pth'.

    Returns:
        str | None: File path of the latest checkpoint, or ``None`` when
        ``path`` does not exist or holds no numbered checkpoint.

    References:
        .. [1] https://github.com/microsoft/SoftTeacher
                  /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None

    latest_ckpt = osp.join(path, f'latest.{suffix}')
    if osp.exists(latest_ckpt):
        return latest_ckpt

    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
    if not checkpoints:
        warnings.warn('There are no checkpoints in the path.')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # 'epoch_12.pth' -> '12'. Names without a numeric tail (such as
        # 'best_coco_bbox_mAP.pth') are skipped instead of crashing.
        stem = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(stem)
        except ValueError:
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint
    return latest_path
| |
|
| |
|
def update_data_root(cfg, logger=None):
    """Update data root according to env MMDET_DATASETS.

    If the environment variable ``MMDET_DATASETS`` is set, every string value
    inside ``cfg.data`` (searched recursively through nested ``ConfigDict``
    objects) that contains the old ``cfg.data_root`` has that substring
    replaced by the new root, and ``cfg.data_root`` is set to the new root.
    If the variable is not set, ``cfg`` is left unchanged.

    Args:
        cfg (:obj:`Config`): The model config need to modify. Modified
            in place.
        logger (logging.Logger | str | None): The way to print msg. It is
            forwarded to :func:`print_log`. Defaults to None.
    """
    assert isinstance(cfg, Config), \
        f'cfg got wrong type: {type(cfg)}, expected mmengine.Config'

    if 'MMDET_DATASETS' not in os.environ:
        return

    dst_root = os.environ['MMDET_DATASETS']
    print_log(f'MMDET_DATASETS has been set to be {dst_root}. '
              f'Using {dst_root} as data root.', logger=logger)

    def update(cfg, src_str, dst_str):
        """Recursively replace ``src_str`` with ``dst_str`` in string values."""
        for k, v in cfg.items():
            if isinstance(v, ConfigDict):
                update(cfg[k], src_str, dst_str)
            elif isinstance(v, str) and src_str in v:
                # Re-assigning an existing key does not resize the mapping,
                # so mutating while iterating ``items()`` is safe here.
                cfg[k] = v.replace(src_str, dst_str)

    # NOTE(review): assumes the config has a ``data`` section (older
    # MMDetection layout); confirm this still holds for dataloader-style
    # configs.
    update(cfg.data, cfg.data_root, dst_root)
    cfg.data_root = dst_root
| |
|
| |
|
def get_test_pipeline_cfg(cfg: Union[str, ConfigDict]) -> ConfigDict:
    """Extract the test-time data pipeline from a full config.

    The search starts at ``cfg.test_dataloader.dataset`` and unwraps wrapper
    datasets until it reaches a node that carries a ``pipeline`` key. At
    each level, a ``dataset`` key is followed directly; failing that, only
    the first entry of a ``datasets`` list is followed.

    Args:
        cfg (str or :obj:`ConfigDict`): The entire config, given either as a
            config file path or as an already-loaded ``ConfigDict``.

    Returns:
        :obj:`ConfigDict`: The ``pipeline`` of the first dataset reached.

    Raises:
        RuntimeError: If no ``pipeline`` can be located.
    """
    if isinstance(cfg, str):
        cfg = Config.fromfile(cfg)

    node = cfg.test_dataloader.dataset
    while 'pipeline' not in node:
        if 'dataset' in node:
            node = node.dataset
        elif 'datasets' in node:
            node = node.datasets[0]
        else:
            raise RuntimeError('Cannot find `pipeline` in `test_dataloader`')
    return node.pipeline
| |
|
| |
|
def get_file_list(source_root: str) -> tuple:
    """Get file list.

    Args:
        source_root (str): Image or video source path. It may be:
            - a directory, scanned recursively for files whose extension
              is in ``IMG_EXTENSIONS``;
            - an ``http``/``https`` URL, downloaded into the current
              working directory;
            - a single file path with an extension in ``IMG_EXTENSIONS``.

    Returns:
        tuple[list, dict]:
            - source_file_path_list (list): A list for all source files.
              It is empty when ``source_root`` matches none of the cases
              above.
            - source_type (dict): Flags ``is_dir``, ``is_url`` and
              ``is_file`` describing how ``source_root`` was classified.
    """
    # A bare ``import urllib`` does not guarantee the ``parse`` submodule
    # is loaded, so import it explicitly.
    import urllib.parse

    is_dir = os.path.isdir(source_root)
    is_url = source_root.startswith(('http:/', 'https:/'))
    is_file = os.path.splitext(source_root)[-1].lower() in IMG_EXTENSIONS

    source_file_path_list = []
    if is_dir:
        # Entries from ``scandir`` are joined onto ``source_root``, i.e.
        # they are treated as relative paths.
        source_file_path_list = [
            os.path.join(source_root, file)
            for file in scandir(source_root, IMG_EXTENSIONS, recursive=True)
        ]
    elif is_url:
        # Decode percent-escapes and drop any query string to obtain a
        # local filename.
        filename = os.path.basename(
            urllib.parse.unquote(source_root).split('?')[0])
        file_save_path = os.path.join(os.getcwd(), filename)
        print(f'Downloading source file to {file_save_path}')
        torch.hub.download_url_to_file(source_root, file_save_path)
        source_file_path_list = [file_save_path]
    elif is_file:
        source_file_path_list = [source_root]
    else:
        print('Cannot find image file.')

    source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file)
    return source_file_path_list, source_type
| |
|