| |
| import warnings |
| from collections import abc |
| from contextlib import contextmanager |
| from functools import wraps |
|
|
| import torch |
| from mmengine.logging import MMLogger |
|
|
|
|
def cast_tensor_type(inputs, src_type=None, dst_type=None):
    """Recursively convert Tensor in inputs from ``src_type`` to ``dst_type``.

    Args:
        inputs: Inputs to be cast. May be a Tensor, a Mapping, an
            Iterable, or any other object (returned unchanged).
        src_type (torch.dtype | torch.device): Source type. If None,
            every Tensor is converted regardless of its current
            dtype/device.
        dst_type (torch.dtype | torch.device): Destination type.
            Required.

    Returns:
        The same container type as ``inputs``, but all contained Tensors
        have been cast to ``dst_type``.
    """
    assert dst_type is not None
    if isinstance(inputs, torch.Tensor):
        if isinstance(dst_type, torch.device):
            # Device conversion: only move tensors whose current device
            # matches src_type (or all tensors when src_type is None).
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'device') and \
                    (inputs.device == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
        else:
            # Dtype conversion: only cast tensors whose current dtype
            # matches src_type (or all tensors when src_type is None).
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'dtype') and \
                    (inputs.dtype == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
    elif isinstance(inputs, str):
        # str is an Iterable whose elements are 1-char strings that
        # iterate to themselves; without this guard a string input would
        # recurse forever (RecursionError). Strings are returned as-is.
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type=src_type, dst_type=dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type=src_type, dst_type=dst_type)
            for item in inputs)
    else:
        # Non-tensor, non-container objects pass through untouched.
        return inputs
|
|
|
|
| @contextmanager |
| def _ignore_torch_cuda_oom(): |
| """A context which ignores CUDA OOM exception from pytorch. |
| |
| Code is modified from |
| <https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py> # noqa: E501 |
| """ |
| try: |
| yield |
| except RuntimeError as e: |
| |
| if 'CUDA out of memory. ' in str(e): |
| pass |
| else: |
| raise |
|
|
|
|
class AvoidOOM:
    """Try to convert inputs to FP16 and CPU if got a PyTorch's CUDA Out of
    Memory error. It will do the following steps:

    1. First retry after calling `torch.cuda.empty_cache()`.
    2. If that still fails, it will then retry by converting inputs
       to FP16.
    3. If that still fails trying to convert inputs to CPUs.
       In this case, it expects the function to dispatch to
       CPU implementation.

    Args:
        to_cpu (bool): Whether to convert outputs to CPU if get an OOM
            error. This will slow down the code significantly.
            Defaults to True.
        test (bool): Skip the initial `_ignore_torch_cuda_oom` attempts so
            the FP16/CPU fallback paths can be exercised with lightweight
            data; only used in unit tests. Defaults to False.

    Examples:
        >>> from mmdet.utils.memory import AvoidOOM
        >>> AvoidCUDAOOM = AvoidOOM()
        >>> output = AvoidCUDAOOM.retry_if_cuda_oom(
        >>>     some_torch_function)(input1, input2)
        >>> # To use as a decorator
        >>> # from mmdet.utils import AvoidCUDAOOM
        >>> @AvoidCUDAOOM.retry_if_cuda_oom
        >>> def function(*args, **kwargs):
        >>>     return None

    Note:
        1. The output may be on CPU even if inputs are on GPU. Processing
           on CPU will slow down the code significantly.
        2. When converting inputs to CPU, each argument is cast via
           ``cast_tensor_type``, which recurses into Mappings and
           Iterables of tensors.
        3. Since the function might be called more than once, it has to be
           stateless.
    """

    def __init__(self, to_cpu=True, test=False):
        # Whether to fall back to a CPU run when FP16 also OOMs.
        self.to_cpu = to_cpu
        # Test-only flag: skips the plain retries so the fallback paths run.
        self.test = test

    def retry_if_cuda_oom(self, func):
        """Makes a function retry itself after encountering pytorch's CUDA OOM
        error.

        The implementation logic is referred to
        https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py

        Args:
            func: a stateless callable that takes tensor-like objects
                as arguments.
        Returns:
            func: a callable which retries `func` if OOM is encountered.
        """

        @wraps(func)
        def wrapped(*args, **kwargs):

            # Raw attempt: call func as-is; a CUDA OOM is swallowed by the
            # context manager and we fall through to the retry ladder.
            # Skipped entirely in test mode so the fallbacks are exercised.
            if not self.test:
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)

                # Release cached blocks and try once more before degrading.
                torch.cuda.empty_cache()
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)

            # The output dtype/device will match those of the first tensor
            # found among the inputs; without any tensor we cannot know
            # what to cast back to, so fail loudly.
            dtype, device = None, None
            values = args + tuple(kwargs.values())
            for value in values:
                if isinstance(value, torch.Tensor):
                    dtype = value.dtype
                    device = value.device
                    break
            if dtype is None or device is None:
                raise ValueError('There is no tensor in the inputs, '
                                 'cannot get dtype and device.')

            # Second fallback: cast all input tensors to FP16 to halve
            # memory, then cast the output back to the original dtype.
            fp16_args = cast_tensor_type(args, dst_type=torch.half)
            fp16_kwargs = cast_tensor_type(kwargs, dst_type=torch.half)
            logger = MMLogger.get_current_instance()
            logger.warning(f'Attempting to copy inputs of {str(func)} '
                           'to FP16 due to CUDA OOM')

            # If this attempt OOMs as well, the exception is swallowed and
            # execution falls through to the CPU fallback below. In test
            # mode the result is deliberately discarded so the CPU path
            # also gets exercised.
            with _ignore_torch_cuda_oom():
                output = func(*fp16_args, **fp16_kwargs)
                output = cast_tensor_type(
                    output, src_type=torch.half, dst_type=dtype)
                if not self.test:
                    return output
            logger.warning('Using FP16 still meet CUDA OOM')

            # Last resort: move every input tensor to CPU and rely on the
            # function dispatching to a CPU implementation. This is slow.
            if self.to_cpu:
                logger.warning(f'Attempting to copy inputs of {str(func)} '
                               'to CPU due to CUDA OOM')
                cpu_device = torch.empty(0).device
                cpu_args = cast_tensor_type(args, dst_type=cpu_device)
                cpu_kwargs = cast_tensor_type(kwargs, dst_type=cpu_device)

                # Try to move the CPU result back to the original GPU
                # device; if that itself OOMs, fall through and return the
                # CPU result instead.
                with _ignore_torch_cuda_oom():
                    logger.warning(f'Convert outputs to GPU (device={device})')
                    output = func(*cpu_args, **cpu_kwargs)
                    output = cast_tensor_type(
                        output, src_type=cpu_device, dst_type=device)
                    return output

                warnings.warn('Cannot convert output to GPU due to CUDA OOM, '
                              'the output is now on CPU, which might cause '
                              'errors if the output need to interact with GPU '
                              'data in subsequent operations')
                logger.warning('Cannot convert output to GPU due to '
                               'CUDA OOM, the output is on CPU now.')

                # Re-run on CPU and return the CPU-resident output.
                return func(*cpu_args, **cpu_kwargs)
            else:
                # CPU fallback disabled: one final raw attempt, which may
                # still raise CUDA OOM to the caller.
                return func(*args, **kwargs)

        return wrapped
|
|
|
|
| |
# Module-level singleton with default settings (to_cpu=True), ready to be
# imported and used directly as a decorator or wrapper.
AvoidCUDAOOM = AvoidOOM()
|
|