| import time, torch |
| from collections import defaultdict |
| from contextlib import contextmanager |
|
|
class StepTimer:
    """Accumulate wall-clock durations for named code sections.

    Each ``with timer.section(name):`` block appends one duration in seconds
    (measured with ``time.perf_counter``) to ``self.times[name]``. When the
    timer is bound to a CUDA device, that device is synchronized before and
    after each section so asynchronously queued GPU work is included in the
    measurement.

    Args:
        device: Optional device the timed work runs on. A ``torch.device``
            whose type is ``"cuda"``, or a string containing ``"cuda"``
            (e.g. ``"cuda"``, ``"cuda:1"``), enables synchronization. Any
            other value (``None``, ``"cpu"``, ...) times host-side only.
    """

    def __init__(self, device=None):
        # Section name -> list of durations (seconds), in recording order.
        self.times = defaultdict(list)
        self.device = device
        self._use_cuda_sync = (
            isinstance(device, torch.device) and device.type == "cuda"
        ) or (isinstance(device, str) and "cuda" in device)

    def _sync(self):
        """Wait for pending work on ``self.device`` when it is a CUDA device."""
        if self._use_cuda_sync:
            # Pass the configured device explicitly: with no argument,
            # torch.cuda.synchronize() waits on the *current* CUDA device,
            # which is not necessarily the one this timer measures.
            torch.cuda.synchronize(self.device)

    @contextmanager
    def section(self, name):
        """Time the enclosed block and record its duration under ``name``.

        The duration is recorded even if the block raises; the exception
        still propagates to the caller.
        """
        self._sync()
        t0 = time.perf_counter()
        try:
            yield
        finally:
            self._sync()
            self.times[name].append(time.perf_counter() - t0)

    def summary(self, top_k=None):
        """Return per-section statistics, sorted by total time (descending).

        Args:
            top_k: If not ``None``, return at most this many rows. ``0``
                yields an empty list.

        Returns:
            A list of ``(name, count, total, mean, median, p95)`` tuples,
            with all times in seconds. Sections with no recorded samples are
            omitted.
        """
        import numpy as np

        rows = []
        for name, samples in self.times.items():
            if not samples:
                # Empty entries appear when a caller indexes the defaultdict
                # directly; percentile/median stats are undefined for them.
                continue
            a = np.array(samples, dtype=float)
            rows.append(
                (name, len(a), a.sum(), a.mean(), np.median(a), np.percentile(a, 95))
            )
        rows.sort(key=lambda r: r[2], reverse=True)
        return rows if top_k is None else rows[:top_k]
|
|