| |
| |
| |
| |
| |
|
|
| import os |
| from collections import namedtuple |
| from datetime import datetime |
| from typing import Any, Dict, Optional |
|
|
| import torch |
|
|
| |
# Snapshot of peak GPU memory usage: active/reserved peaks both in GiB and
# as a percentage of total device capacity, plus CUDA allocator health
# counters (allocation retries and OOM errors).
GPUMemStats = namedtuple(
    "GPUMemStats",
    "max_active_gib max_active_pct "
    "max_reserved_gib max_reserved_pct "
    "num_alloc_retries num_ooms",
)
|
|
|
|
| class GPUMemoryMonitor: |
| def __init__(self, logger, device: str = "cuda:0"): |
| self.device = torch.device(device) |
| self.device_name = torch.cuda.get_device_name(self.device) |
| self.device_index = torch.cuda.current_device() |
| self.device_capacity = torch.cuda.get_device_properties( |
| self.device |
| ).total_memory |
| self.device_capacity_gib = self._to_gib(self.device_capacity) |
| |
| self.logger = logger |
|
|
| torch.cuda.reset_peak_memory_stats() |
| torch.cuda.empty_cache() |
|
|
| def _to_gib(self, memory_in_bytes): |
| |
| _gib_in_bytes = 1024 * 1024 * 1024 |
| memory_in_gib = memory_in_bytes / _gib_in_bytes |
| return memory_in_gib |
|
|
| def _to_pct(self, memory): |
| return 100 * memory / self.device_capacity |
|
|
| def get_peak_stats(self): |
| cuda_info = torch.cuda.memory_stats(self.device) |
|
|
| max_active = cuda_info["active_bytes.all.peak"] |
| max_active_gib = self._to_gib(max_active) |
| max_active_pct = self._to_pct(max_active) |
|
|
| max_reserved = cuda_info["reserved_bytes.all.peak"] |
| max_reserved_gib = self._to_gib(max_reserved) |
| max_reserved_pct = self._to_pct(max_reserved) |
|
|
| num_retries = cuda_info["num_alloc_retries"] |
| num_ooms = cuda_info["num_ooms"] |
|
|
| if num_retries > 0: |
| self.logger.warning(f"{num_retries} CUDA memory allocation retries.") |
| if num_ooms > 0: |
| self.logger.warning(f"{num_ooms} CUDA OOM errors thrown.") |
|
|
| return GPUMemStats( |
| max_active_gib, |
| max_active_pct, |
| max_reserved_gib, |
| max_reserved_pct, |
| num_retries, |
| num_ooms, |
| ) |
|
|
| def reset_peak_stats(self): |
| torch.cuda.reset_peak_memory_stats() |
|
|
|
|
def build_gpu_memory_monitor(logger):
    """Construct a GPUMemoryMonitor on the default CUDA device.

    Logs the device name, index, and total capacity once at build time,
    then returns the monitor.
    """
    monitor = GPUMemoryMonitor(logger, "cuda")
    logger.info(
        f"GPU capacity: {monitor.device_name} ({monitor.device_index}) "
        f"with {monitor.device_capacity_gib:.2f}GiB memory"
    )
    return monitor