| | |
| | import logging |
| | from contextlib import contextmanager |
| | from typing import Optional |
| |
|
| | import torch |
| |
|
| | from mmengine.device import (get_device, is_cuda_available, is_mlu_available, |
| | is_npu_available) |
| | from mmengine.logging import print_log |
| | from mmengine.utils import digit_version |
| | from mmengine.utils.dl_utils import TORCH_VERSION |
| |
|
| |
|
@contextmanager
def autocast(device_type: Optional[str] = None,
             dtype: Optional[torch.dtype] = None,
             enabled: bool = True,
             cache_enabled: Optional[bool] = None):
    """A wrapper of ``torch.autocast`` and ``torch.cuda.amp.autocast``.

    Pytorch 1.5.0 provide ``torch.cuda.amp.autocast`` for running in
    mixed precision, and update it to ``torch.autocast`` in 1.10.0.
    Both interfaces have different arguments, and ``torch.autocast``
    support running with cpu additionally.

    This function provides a unified interface by wrapping
    ``torch.autocast`` and ``torch.cuda.amp.autocast``, which resolves the
    compatibility issues that ``torch.cuda.amp.autocast`` does not support
    running mixed precision with cpu, and both contexts have different
    arguments. We suggest users using this function in the code
    to achieve maximized compatibility of different PyTorch versions.

    Note:
        ``autocast`` requires pytorch version >= 1.5.0. If pytorch version
        < 1.10.0 and none of npu, mlu or cuda is available, it will raise
        an error with ``enabled=True``, since the legacy autocast context
        is only entered on those devices.

    Examples:
        >>> # case1: 1.10 > Pytorch version >= 1.5.0
        >>> with autocast():
        >>>    # run in mixed precision context
        >>>    pass
        >>> with autocast(device_type='cpu'):
        >>>    # raise error, torch.cuda.amp.autocast only support cuda mode.
        >>>    pass
        >>> # case2: Pytorch version >= 1.10.0
        >>> with autocast():
        >>>    # default cuda mixed precision context
        >>>    pass
        >>> with autocast(device_type='cpu'):
        >>>    # cpu mixed precision context
        >>>    pass
        >>> with autocast(
        >>>     device_type='cuda', enabled=True, cache_enabled=True):
        >>>    # enable precision context with more specific arguments.
        >>>    pass

    Args:
        device_type (str, optional): Device on which autocast runs, e.g.
            'cuda' or 'cpu'. If None and pytorch version >= 1.10.0, the
            value returned by ``get_device()`` is used. Defaults to None.
        dtype (torch.dtype, optional): Whether to use ``torch.float16`` or
            ``torch.bfloat16``. If None, ``torch.get_autocast_gpu_dtype()``
            is used for 'cuda' and 'musa', and ``torch.bfloat16`` for
            'cpu'. Ignored (with a warning) if pytorch version < 1.10.0.
            Defaults to None.
        enabled (bool): Whether autocasting should be enabled in the region.
            Defaults to True.
        cache_enabled (bool, optional): Whether the weight cache inside
            autocast should be enabled. If None,
            ``torch.is_autocast_cache_enabled()`` is used. Ignored (with a
            warning) if pytorch version < 1.10.0. Defaults to None.

    Raises:
        AssertionError: If pytorch version < 1.5.0; if pytorch version
            < 1.10.0 and ``device_type`` is not 'cuda', 'mlu' or None; or
            if ``device_type`` is 'cpu' and ``dtype`` is not
            ``torch.bfloat16``.
        RuntimeError: If pytorch version < 1.10.0, ``enabled`` is True and
            no npu, mlu or cuda device is available; or if ``torch.bfloat16``
            is requested on a CUDA device that does not support it.
        ValueError: If ``device_type`` is not recognized and ``enabled`` is
            not False.
    """
    # Parse the version once; every check below compares against it.
    torch_version = digit_version(TORCH_VERSION)

    # If pytorch version is lower than 1.5.0, raise an error.
    assert torch_version >= digit_version('1.5.0'), (
        'The minimum pytorch version requirements of mmengine is 1.5.0, but '
        f'got {TORCH_VERSION}')

    # The assert above already guarantees the lower bound, so only the upper
    # bound has to be tested to select the legacy code path.
    if torch_version < digit_version('1.10.0'):
        # Legacy path: only the device-specific ``amp.autocast`` contexts
        # exist, and they accept nothing but ``enabled``.
        assert (
            device_type == 'cuda' or device_type == 'mlu'
            or device_type is None), (
                'Pytorch version under 1.10.0 only supports running automatic '
                'mixed training with cuda or mlu')
        if dtype is not None or cache_enabled is not None:
            # Name the arguments that are actually ignored on this path.
            print_log(
                '`dtype` and `cache_enabled` will not work for '
                '`autocast` since your Pytorch version: '
                f'{TORCH_VERSION} < 1.10.0',
                logger='current',
                level=logging.WARNING)

        # The context is picked by device availability, not by
        # ``device_type``.
        if is_npu_available():
            with torch.npu.amp.autocast(enabled=enabled):
                yield
        elif is_mlu_available():
            with torch.mlu.amp.autocast(enabled=enabled):
                yield
        elif is_cuda_available():
            with torch.cuda.amp.autocast(enabled=enabled):
                yield
        else:
            if not enabled:
                # Autocast is disabled, so a plain no-op context is enough.
                yield
            else:
                raise RuntimeError(
                    'If pytorch versions is between 1.5.0 and 1.10, '
                    '`autocast` is only available in gpu mode')

    else:
        # Modern path (pytorch >= 1.10.0): fill in the unspecified arguments
        # and validate them per device before entering ``torch.autocast``.
        if cache_enabled is None:
            cache_enabled = torch.is_autocast_cache_enabled()
        device = get_device()
        device_type = device if device_type is None else device_type

        if device_type == 'cuda':
            if dtype is None:
                dtype = torch.get_autocast_gpu_dtype()

            # Fail early with a clear message instead of inside autocast.
            if dtype == torch.bfloat16 and not \
                    torch.cuda.is_bf16_supported():
                raise RuntimeError(
                    'Current CUDA Device does not support bfloat16. Please '
                    'switch dtype to float16.')

        elif device_type == 'cpu':
            if dtype is None:
                dtype = torch.bfloat16
            assert dtype == torch.bfloat16, (
                'In CPU autocast, only support `torch.bfloat16` dtype')

        elif device_type == 'mlu':
            # No extra validation; arguments are passed to ``torch.autocast``
            # as they are.
            pass

        elif device_type == 'npu':
            # No extra validation; arguments are passed to ``torch.autocast``
            # as they are.
            pass
        elif device_type == 'musa':
            if dtype is None:
                dtype = torch.get_autocast_gpu_dtype()
            # musa uses its own autocast context rather than
            # ``torch.autocast``, hence the early return.
            with torch.musa.amp.autocast(
                    enabled=enabled, dtype=dtype, cache_enabled=cache_enabled):
                yield
                return
        else:
            # Unrecognized device type: tolerated only when autocast is
            # explicitly disabled, in which case nothing has to be entered.
            if enabled is False:
                yield
                return
            else:
                raise ValueError('User specified autocast device_type must be '
                                 f'cuda or cpu, but got {device_type}')

        with torch.autocast(
                device_type=device_type,
                enabled=enabled,
                dtype=dtype,
                cache_enabled=cache_enabled):
            yield
| |
|