|
|
|
|
| import logging
|
| from contextlib import contextmanager
|
| from functools import wraps
|
| import torch
|
|
|
| __all__ = ["retry_if_cuda_oom"]
|
|
|
|
|
| @contextmanager
|
| def _ignore_torch_cuda_oom():
|
| """
|
| A context which ignores CUDA OOM exception from pytorch.
|
| """
|
| try:
|
| yield
|
| except RuntimeError as e:
|
|
|
| if "CUDA out of memory. " in str(e):
|
| pass
|
| else:
|
| raise
|
|
|
|
|
def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.

    The retry strategy is:
      1. Call `func` as-is; if it raises CUDA OOM, continue.
      2. Call `torch.cuda.empty_cache()` and try again.
      3. Copy every CUDA-tensor-like argument to CPU and call `func` one
         final time, expecting it to dispatch to a CPU implementation.
         Returned values may then be CPU tensors; converting them back to
         CUDA is the caller's responsibility.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def _transfer_to_cpu(obj):
        # Move only objects that look like CUDA tensors (expose a `.device`
        # whose type is "cuda" and a `.to` method); pass everything else
        # through untouched.
        try:
            on_cuda = obj.device.type == "cuda" and hasattr(obj, "to")
        except AttributeError:
            on_cuda = False
        if not on_cuda:
            return obj
        return obj.to(device="cpu")

    @wraps(func)
    def wrapped(*args, **kwargs):
        # Attempt 1: run unchanged.
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 2: release cached GPU memory, then try again.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 3 (final, not guarded): retry with CPU inputs.
        logger = logging.getLogger(__name__)
        logger.info("Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func)))
        cpu_args = tuple(_transfer_to_cpu(a) for a in args)
        cpu_kwargs = {key: _transfer_to_cpu(val) for key, val in kwargs.items()}
        return func(*cpu_args, **cpu_kwargs)

    return wrapped
|
|
|