| | |
| | """ Useful functions for writing test code. """ |
| |
|
| | import torch |
| | import torch.utils.benchmark as benchmark |
| |
|
| |
|
def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Time the forward pass of an arbitrary function with ``torch.utils.benchmark``.

    Returns the ``(Timer, Measurement)`` pair so callers can inspect timing stats.
    """
    if verbose:
        print(desc, "- Forward pass")

    # Run fn under (possibly disabled) CUDA autocast so AMP and non-AMP
    # timings go through the same code path.
    def wrapped(*fwd_args, **fwd_kwargs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*fwd_args, **fwd_kwargs)

    timer = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": wrapped, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
| |
|
| |
|
def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Time only the backward pass of an arbitrary function.

    Runs the forward pass once (outside the timed region) to build the graph,
    then benchmarks ``y.backward(grad)``. Returns ``(Timer, Measurement)``.
    """
    if verbose:
        print(desc, "- Backward pass")
    # Single un-timed forward to obtain the output the backward will flow from.
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        # Some fns return (output, aux); benchmark against the first element.
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def run_backward(*bwd_inputs, y, grad):
        # Reset stale gradients so every timed iteration does the full
        # accumulation work, not just an in-place add.
        for tensor in bwd_inputs:
            if isinstance(tensor, torch.Tensor):
                tensor.grad = None
        y.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": run_backward, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
| |
|
| |
|
def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Time the forward+backward pass of an arbitrary function.

    Each timed iteration re-runs the forward under (possibly disabled) CUDA
    autocast and then backpropagates ``grad``. Returns ``(Timer, Measurement)``.
    """
    if verbose:
        print(desc, "- Forward + Backward pass")
    # Un-timed forward just to size the upstream gradient / validate its shape.
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def run_fwd_bwd(grad, *step_inputs, **step_kwinputs):
        # Clear accumulated gradients so each iteration measures full work.
        for tensor in step_inputs:
            if isinstance(tensor, torch.Tensor):
                tensor.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*step_inputs, **step_kwinputs)
            if type(out) is tuple:
                out = out[0]
        out.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": run_fwd_bwd, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
| |
|
| |
|
def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Benchmark the forward and backward passes of ``fn`` separately.

    Returns a pair of ``(Timer, Measurement)`` tuples: one from
    ``benchmark_forward`` and one from ``benchmark_backward``.
    """
    # Options shared by both sub-benchmarks.
    common = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    fwd = benchmark_forward(fn, *inputs, **common, **kwinputs)
    bwd = benchmark_backward(fn, *inputs, grad=grad, **common, **kwinputs)
    return fwd, bwd
| |
|
| |
|
def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Benchmark ``fn`` three ways: forward-only, backward-only, and combined.

    Returns a 3-tuple of ``(Timer, Measurement)`` pairs, in that order.
    """
    # Options shared by all three sub-benchmarks.
    common = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    fwd = benchmark_forward(fn, *inputs, **common, **kwinputs)
    bwd = benchmark_backward(fn, *inputs, grad=grad, **common, **kwinputs)
    combined = benchmark_combined(fn, *inputs, grad=grad, **common, **kwinputs)
    return fwd, bwd, combined
| |
|
| |
|
def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Run ``fn`` under the Pytorch profiler to see CUDA kernel information.

    Does 30 warmup iterations first, then profiles a single iteration. When
    ``backward`` is set, each iteration also backpropagates a random gradient.
    Optionally prints the summary table and exports a Chrome trace.
    """
    if backward:
        # One un-timed forward to get an output we can synthesize a gradient for.
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        g = torch.randn_like(out)

    def _step():
        # One forward (+ optional backward) iteration, gradients cleared first.
        if backward:
            for tensor in inputs:
                if isinstance(tensor, torch.Tensor):
                    tensor.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)

    # Warmup so cudnn autotuning / lazy init don't pollute the profile.
    for _ in range(30):
        _step()

    # CPU activity is optional; CUDA is always profiled.
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        with_stack=True,
    ) as prof:
        _step()
    if verbose:
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
| |
|
| |
|
def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    """Measure the peak CUDA memory allocated by one call to ``fn``.

    Requires a CUDA device. Returns the peak allocation in GiB.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    # Bug fix: previously divided by (2**20) * 1000, a MiB/GB hybrid unit
    # (~1.049e9) despite printing "GB". Use a true GiB divisor.
    mem = torch.cuda.max_memory_allocated() / (2**30)
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
| |
|