File size: 1,474 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import gc


def clear_gradients(*args):
    """Drop the ``.grad`` attribute of every tensor among *args*.

    Non-tensor arguments and tensors without an accumulated gradient are
    silently ignored, so callers can forward an arbitrary argument list.
    """
    tensors = (a for a in args if isinstance(a, torch.Tensor))
    for tensor in tensors:
        if tensor.grad is not None:
            tensor.grad = None


def clear_memory(device):
    """Aggressively release cached CUDA memory and reset peak-memory stats.

    Intended to be called before a measurement so that
    ``torch.cuda.max_memory_allocated(device)`` reflects only the work that
    follows this call.

    Args:
        device: CUDA device whose peak-memory statistics are reset.
    """
    # NOTE(review): the next two calls use private/internal torch APIs
    # (underscore-prefixed); they may change or disappear between versions.
    torch._C._cuda_clearCublasWorkspaces()
    torch._dynamo.reset()
    # Drop unreferenced Python objects first so their CUDA blocks become
    # cached-but-free, then return those blocks to the driver.
    gc.collect()
    torch.cuda.empty_cache()
    # Zero the high-water mark so the next max_memory_allocated() read
    # measures only subsequent allocations.
    torch.cuda.reset_peak_memory_stats(device)


def peak_memory(f, *args, device, num_iters=3):
    """Return the peak CUDA memory allocated (in bytes) while running ``f(*args)``.

    The workload is run ``num_iters`` times, each iteration starting from a
    clean allocator and gradient state; the peak from the final iteration is
    returned (earlier iterations warm up caches/compilation so the last run
    is representative).

    Args:
        f: Callable to measure; invoked as ``f(*args)``.
        *args: Positional arguments forwarded to ``f``. Tensor gradients
            among them are cleared before each run.
        device: CUDA device whose memory statistics are reset and read
            (keyword-only).
        num_iters: Number of repetitions (keyword-only; default 3 preserves
            the previously hard-coded behavior).

    Returns:
        int: ``torch.cuda.max_memory_allocated(device)`` after the final run,
        or 0 when ``num_iters <= 0``.
    """
    memory = 0  # guard: avoid NameError if num_iters <= 0
    for _ in range(num_iters):
        # Clean everything so the peak counter only sees this iteration
        clear_memory(device)
        clear_gradients(*args)

        # Run once
        f(*args)

        # Wait for queued kernels to finish before reading the counter
        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated(device)

    return memory


def current_memory(device):
    """Return the CUDA memory currently allocated on *device*, in GiB."""
    bytes_allocated = torch.cuda.memory_allocated(device)
    return bytes_allocated / 2**30


def memory_measure(f, device, num_iters=3):
    """Measure and print the peak CUDA memory (in GiB) used by ``f()``.

    Prints the currently allocated memory before measuring, then the peak
    observed by ``peak_memory``.

    Args:
        f: Zero-argument callable whose memory footprint is measured.
        device: CUDA device to measure on.
        num_iters: NOTE(review) — accepted but never used; ``peak_memory``
            is called without it and repeats the run a fixed 3 times
            internally. Confirm whether it should be threaded through.

    Returns:
        float: Peak memory allocated during the measured runs, in GiB.
    """
    # Clean everything
    clear_memory(device)

    # Run measurement
    print("Current memory: ", current_memory(device))
    memory = peak_memory(f, device=device)

    print("Peak memory: ", memory / (1024**3))
    return memory / (1024**3)


def memory_measure_simple(f, device, *args, **kwargs):
    """Run ``f(*args, **kwargs)`` once and report its peak memory cost.

    The allocator state and any stale gradients on tensor arguments are
    cleared first, then the peak-memory counter is read after the run.

    Args:
        f: Callable to measure; invoked as ``f(*args, **kwargs)``.
        device: CUDA device whose memory statistics are reset and read.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        tuple: ``(result, delta_gib)`` where ``result`` is ``f``'s return
        value and ``delta_gib`` is the peak allocated memory during the run
        minus the memory already allocated beforehand, in GiB.
    """
    # Reset allocator stats and stale gradients so the peak reading is fresh.
    clear_memory(device)
    clear_gradients(*args)

    baseline_gib = current_memory(device)

    # Single measured run.
    result = f(*args, **kwargs)

    # Ensure all queued kernels complete before reading the peak counter.
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated(device) / (1024**3)

    return result, peak_gib - baseline_gib