| | import torch |
| | import gc |
| |
|
| |
|
def clear_gradients(*args):
    """Drop the accumulated ``.grad`` of every tensor argument.

    Non-tensor arguments and tensors that have no gradient yet are
    silently ignored, so the call is safe on mixed argument lists.
    """
    tensors = (a for a in args if isinstance(a, torch.Tensor))
    for tensor in tensors:
        if tensor.grad is not None:
            tensor.grad = None
| |
|
| |
|
def clear_memory(device):
    """Reset CUDA allocator state so later measurements start from a clean baseline.

    Call order matters: Python-side references must be collected before the
    CUDA cache is emptied, and the peak-stat counter is reset last so it
    reflects only work done after this call.

    Args:
        device: CUDA device whose peak-memory statistics are reset.
    """
    torch._C._cuda_clearCublasWorkspaces()  # NOTE(review): private API — frees cuBLAS workspace buffers; confirm it exists on the pinned torch version
    torch._dynamo.reset()  # drop torch.compile caches/graphs that can pin GPU memory
    gc.collect()  # release Python references first, so empty_cache can actually return their blocks
    torch.cuda.empty_cache()  # hand cached allocator blocks back to the driver
    torch.cuda.reset_peak_memory_stats(device)  # so max_memory_allocated measures only subsequent work
| |
|
| |
|
def peak_memory(f, *args, device, num_iters=3):
    """Return the peak CUDA memory (in bytes) allocated by ``f(*args)``.

    ``f`` is invoked ``num_iters`` times; allocator statistics and argument
    gradients are reset before every run, so the value returned reflects only
    the final (warmed-up) invocation.

    Args:
        f: callable to profile, invoked as ``f(*args)``.
        *args: positional arguments forwarded to ``f``; tensor arguments also
            have their gradients cleared before each run.
        device: CUDA device whose allocator statistics are reset and read.
        num_iters: total number of runs; earlier runs serve as warm-up.
            Defaults to 3, preserving the previous hard-coded behavior.

    Returns:
        Peak bytes allocated on ``device`` during the last run.

    Raises:
        ValueError: if ``num_iters`` is less than 1 (the stats counter would
            never be exercised and a stale value would be returned).
    """
    if num_iters < 1:
        raise ValueError("num_iters must be >= 1")
    for _ in range(num_iters):
        # Reset allocator stats and gradients so every run starts from the
        # same baseline; only the final run's peak survives the loop.
        clear_memory(device)
        clear_gradients(*args)
        f(*args)
    # Wait for outstanding kernels before reading the counter.
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated(device)
| |
|
| |
|
def current_memory(device):
    """Return the CUDA memory currently allocated on ``device``, in GiB."""
    bytes_allocated = torch.cuda.memory_allocated(device)
    gib = 1024 ** 3
    return bytes_allocated / gib
| |
|
| |
|
def memory_measure(f, device, num_iters=3):
    """Profile ``f()`` with :func:`peak_memory` and report the peak in GiB.

    Prints the memory allocated before the run and the measured peak, then
    returns the peak.

    Args:
        f: zero-argument callable to profile.
        device: CUDA device to measure on.
        num_iters: accepted for interface compatibility but currently
            unused — ``peak_memory`` performs its own fixed number of
            warm-up runs.  TODO(review): thread this through once
            ``peak_memory`` exposes an iteration count.

    Returns:
        Peak memory allocated during the profiled run, in GiB.
    """
    clear_memory(device)

    print("Current memory: ", current_memory(device))
    memory = peak_memory(f, device=device)

    # Convert bytes -> GiB once (the original computed this twice).
    memory_gib = memory / (1024**3)
    print("Peak memory: ", memory_gib)
    return memory_gib
| |
|
| |
|
def memory_measure_simple(f, device, *args, **kwargs):
    """Run ``f(*args, **kwargs)`` once and report the extra peak memory it used.

    Allocator statistics and argument gradients are reset first, the
    pre-call allocation is recorded as a baseline, and the peak observed
    during the call is reduced by that baseline.

    Args:
        f: callable to profile, invoked as ``f(*args, **kwargs)``.
        device: CUDA device whose allocator statistics are reset and read.
        *args: positional arguments for ``f`` (tensor gradients are cleared).
        **kwargs: keyword arguments for ``f``.

    Returns:
        Tuple ``(out, memory)`` — ``f``'s return value and the additional
        peak memory the call required, in GiB.
    """
    clear_memory(device)
    clear_gradients(*args)

    baseline_gib = current_memory(device)

    out = f(*args, **kwargs)

    # Wait for outstanding kernels so the peak counter is final.
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated(device) / (1024**3)
    return out, peak_gib - baseline_gib
| |
|