import torch
import gc
def clear_gradients(*args):
    """Detach accumulated gradients from every tensor argument.

    Non-tensor arguments are ignored, as are tensors that have no
    gradient yet. Setting ``.grad = None`` (rather than zeroing) lets
    the allocator reclaim the gradient storage.
    """
    tensors = (a for a in args if isinstance(a, torch.Tensor))
    for t in tensors:
        if t.grad is not None:
            t.grad = None
def clear_memory(device):
    """Reset CUDA allocator state so a fresh peak-memory measurement starts clean.

    The order matters: caches are dropped before the peak counter is
    reset, so the new peak baseline reflects only live allocations.

    Args:
        device: CUDA device (index, ``torch.device``, or string) whose
            peak-memory statistics are reset.
    """
    # Free cuBLAS workspace buffers (private API — may change across torch versions).
    torch._C._cuda_clearCublasWorkspaces()
    # Drop all torch.compile/dynamo compilation caches.
    torch._dynamo.reset()
    # Collect Python garbage first so empty_cache() can release the freed blocks.
    gc.collect()
    # Return cached allocator blocks to the driver.
    torch.cuda.empty_cache()
    # Start a fresh max_memory_allocated() window for this device.
    torch.cuda.reset_peak_memory_stats(device)
def peak_memory(f, *args, device, num_iters=3):
    """Return the peak CUDA memory (in bytes) allocated while running ``f(*args)``.

    Each iteration clears allocator caches, compilation caches, and any
    gradients held by tensor arguments before invoking ``f``, so the
    reported peak reflects a warmed-up, cache-free run. Only the final
    iteration's peak is returned.

    Args:
        f: Callable to measure; invoked as ``f(*args)``.
        *args: Positional arguments forwarded to ``f``; tensor arguments
            have their gradients cleared before each run.
        device: CUDA device whose memory statistics are measured
            (keyword-only).
        num_iters: Number of clear-and-run cycles (keyword-only;
            default 3, matching the previous hard-coded behavior).

    Returns:
        int: ``torch.cuda.max_memory_allocated(device)`` after the last run.
    """
    memory = 0
    for _ in range(num_iters):
        # Clean everything so this iteration starts from a fresh baseline.
        clear_memory(device)
        clear_gradients(*args)
        # Run once.
        f(*args)
        # Wait for all kernels to finish before reading the peak counter.
        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated(device)
    return memory
def current_memory(device):
    """Return the CUDA memory currently allocated on ``device``, in GiB."""
    allocated_bytes = torch.cuda.memory_allocated(device)
    bytes_per_gib = 1024 ** 3
    return allocated_bytes / bytes_per_gib
def memory_measure(f, device, num_iters=3):
    """Measure and print the peak CUDA memory (in GiB) used by ``f()``.

    Fix: the ``num_iters`` parameter was previously accepted but ignored
    (the iteration count was hard-coded to 3 elsewhere); it now controls
    the number of clear-and-run cycles. The default of 3 preserves the
    original behavior. Only the final cycle's peak is reported.

    Args:
        f: Zero-argument callable to measure.
        device: CUDA device whose memory statistics are measured.
        num_iters: Number of clear-and-run cycles (default 3).

    Returns:
        float: Peak memory allocated during the last run, in GiB.
    """
    # Clean everything
    clear_memory(device)
    # Run measurement
    print("Current memory: ", current_memory(device))
    for _ in range(num_iters):
        # Fresh baseline for each cycle so only the run itself is counted.
        clear_memory(device)
        f()
    # Wait for all kernels to finish before reading the peak counter.
    torch.cuda.synchronize()
    memory = torch.cuda.max_memory_allocated(device)
    print("Peak memory: ", memory / (1024**3))
    return memory / (1024**3)
def memory_measure_simple(f, device, *args, **kwargs):
    """Run ``f(*args, **kwargs)`` once and report its peak memory cost.

    Caches and argument gradients are cleared first; the currently
    allocated memory is then recorded as a baseline and subtracted from
    the peak observed during the call.

    Args:
        f: Callable to measure.
        device: CUDA device whose memory statistics are measured.
        *args: Positional arguments forwarded to ``f``; tensor arguments
            have their gradients cleared beforehand.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        tuple: ``(out, delta_gib)`` — the callable's return value and the
        peak memory above the baseline, in GiB.
    """
    bytes_per_gib = 1024 ** 3
    # Clean everything
    clear_memory(device)
    clear_gradients(*args)
    baseline_gib = current_memory(device)
    # Run once
    out = f(*args, **kwargs)
    # Measure peak memory
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated(device) / bytes_per_gib
    return out, peak_gib - baseline_gib