# Uploaded via huggingface_hub by nikraf (commit 714cf46, verified).
import torch
import gc
def clear_gradients(*args):
    """Drop the ``.grad`` attribute of every ``torch.Tensor`` among ``args``.

    Non-tensor arguments and tensors that have no accumulated gradient are
    left untouched.
    """
    tensors = (a for a in args if isinstance(a, torch.Tensor))
    for tensor in tensors:
        if tensor.grad is not None:
            # Setting .grad to None frees the gradient buffer immediately.
            tensor.grad = None
def clear_memory(device):
torch._C._cuda_clearCublasWorkspaces()
torch._dynamo.reset()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
def peak_memory(f, *args, device):
    """Return the peak CUDA memory (in bytes) allocated on ``device``
    while executing ``f(*args)``.

    The workload is executed three times; caches, gradients, and the
    peak-memory counter are reset before every run, so the returned value
    reflects only the final (warmed-up) execution.
    """
    warmup_runs = 3
    for _ in range(warmup_runs):
        # Start each run from a clean allocator state.
        clear_memory(device)
        clear_gradients(*args)
        # Execute the workload under measurement.
        f(*args)
    # Wait for all queued kernels to finish before reading the counter.
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated(device)
def current_memory(device):
    """Return the CUDA memory currently allocated on ``device``, in GiB."""
    bytes_allocated = torch.cuda.memory_allocated(device)
    gib = 1024 ** 3
    return bytes_allocated / gib
def memory_measure(f, device, num_iters=3):
    """Print and return the peak GPU memory (in GiB) used while running ``f()``.

    NOTE(review): ``num_iters`` is accepted but never used — ``peak_memory``
    always performs three runs internally; confirm the intended iteration
    count with the caller.
    """
    # Reset allocator state before taking any readings.
    clear_memory(device)
    # Report the baseline, then the peak during the measured run.
    print("Current memory: ", current_memory(device))
    peak_bytes = peak_memory(f, device=device)
    peak_gib = peak_bytes / (1024 ** 3)
    print("Peak memory: ", peak_gib)
    return peak_gib
def memory_measure_simple(f, device, *args, **kwargs):
    """Run ``f(*args, **kwargs)`` once and measure its extra peak memory.

    Returns a ``(output, memory_gib)`` tuple, where ``memory_gib`` is the
    peak CUDA allocation on ``device`` during the call minus the memory
    already allocated before it started, in GiB.
    """
    # Start from a clean allocator state.
    clear_memory(device)
    clear_gradients(*args)
    baseline_gib = current_memory(device)
    # Single measured execution.
    out = f(*args, **kwargs)
    # Ensure all queued kernels have completed before reading the counter.
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    return out, peak_gib - baseline_gib