| | import time, torch |
| | from collections import defaultdict |
| | from contextlib import contextmanager |
| |
|
class StepTimer:
    """Collect wall-clock timings for named sections of code.

    Each use of :meth:`section` appends one duration (in seconds, measured
    with ``time.perf_counter``) to ``self.times[name]``. When the timer was
    constructed with a CUDA device, the device is synchronized before the
    clock starts and again before it stops, so that asynchronously launched
    GPU work is included in the measured interval.

    Attributes:
        times: ``defaultdict(list)`` mapping section name to the list of
            recorded durations in seconds.
        device: The ``device`` argument exactly as passed to the constructor.
    """

    def __init__(self, device=None):
        """Create an empty timer.

        Args:
            device: Optional ``torch.device`` or device string. CUDA
                synchronization is enabled when this is a ``torch.device``
                of type ``"cuda"`` or a string containing ``"cuda"``;
                any other value (including ``None``) disables it.
        """
        self.times = defaultdict(list)
        self.device = device
        self._use_cuda_sync = (
            isinstance(device, torch.device) and device.type == "cuda"
        ) or (isinstance(device, str) and "cuda" in device)

    def _sync(self):
        """Wait for pending work on the configured CUDA device, if any."""
        if self._use_cuda_sync:
            # Pass the device explicitly: a bare synchronize() only waits on
            # the *current* device, which may differ from self.device
            # (e.g. "cuda:1" while the current device is still cuda:0).
            torch.cuda.synchronize(self.device)

    @contextmanager
    def section(self, name):
        """Time the body of a ``with`` block and record it under ``name``.

        The duration is recorded even if the body raises; the exception
        then propagates to the caller unchanged.

        Args:
            name: Key under which the elapsed time is appended in
                ``self.times``.
        """
        self._sync()
        t0 = time.perf_counter()
        try:
            yield
        finally:
            self._sync()
            self.times[name].append(time.perf_counter() - t0)

    def summary(self, top_k=None):
        """Return per-section statistics, sorted by total time descending.

        Args:
            top_k: If not ``None``, return only the first ``top_k`` rows
                (``top_k=0`` yields an empty list).

        Returns:
            A list of tuples
            ``(name, count, total, mean, median, p95)`` with all times in
            seconds. Sections that have no recorded samples are omitted.
        """
        # Imported lazily so numpy is only needed when a summary is requested.
        import numpy as np

        rows = []
        for name, samples in self.times.items():
            # self.times is a defaultdict, so merely reading a missing key
            # creates an empty entry; statistics over it are undefined.
            if not samples:
                continue
            a = np.array(samples, dtype=float)
            rows.append(
                (name, len(a), a.sum(), a.mean(), np.median(a), np.percentile(a, 95))
            )
        rows.sort(key=lambda r: r[2], reverse=True)
        return rows if top_k is None else rows[:top_k]
| |
|