"""CPU memory peak vs offloaded tensor size verification.
Compares CPU memory usage with turn_on_cpu_offload() vs no offload to isolate
the actual CPU cost of offloading, separating it from CUDA runtime,
NCCL, and DTensor overhead.
Run with:
torchrun --nproc-per-node=8 --local-ranks-filter=0 test/test_cpu_memory_peak.py
"""
import gc
import logging
import os

import torch
import torch.distributed as dist
from torch.distributed.tensor import Shard, distribute_tensor

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

def _setup():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    return rank, dist.get_world_size()

def _make_mesh(world_size):
    """Build a 1-D "dp" device mesh spanning all ranks."""
    return dist.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))

def get_cpu_rss_bytes():
    """Get current process RSS in bytes from /proc/self/statm."""
    with open("/proc/self/statm") as f:
        pages = int(f.read().split()[1])
    return pages * os.sysconf("SC_PAGE_SIZE")
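
# Optional cross-check, not used by the tests below: a sketch of a peak-RSS
# reader via resource.getrusage. Unlike the instantaneous RSS above, ru_maxrss
# is a high-water mark; on Linux it is reported in KiB.
def get_cpu_peak_rss_bytes():
    """Peak RSS in bytes via getrusage (Linux: ru_maxrss is in KiB)."""
    import resource
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1024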

def get_pinned_pool_bytes(pool):
    """Get total pinned CPU buffer size from CPUOffloadPool."""
    total = 0
    for grp in pool._groups.values():
        cpu_flat = grp["cpu_flat"]
        total += cpu_flat.numel() * cpu_flat.element_size()
    return total
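
# NOTE: get_pinned_pool_bytes (above) reaches into private CPUOffloadPool
# internals (the _groups dict and its "cpu_flat" entries); if the pool's
# layout changes, this helper must be updated to match.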

def _run_muon_steps(mesh, dim0, dim1, num_params, num_steps, cpu_offload):
    """Run Muon optimizer steps; return final CPU RSS and pinned-pool bytes."""
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    gc.collect()
    torch.cuda.empty_cache()

    params, names = [], []
    for i in range(num_params):
        full = torch.randn(dim0, dim1, device="cuda")
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        params.append(p)
        names.append(f"layer.{i}.weight")

    param_groups = [{
        "params": params,
        "names": names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
    if cpu_offload:
        optim.turn_on_cpu_offload()

    for _ in range(num_steps):
        for p in params:
            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                       mesh, [Shard(0)])
        optim.step()
        torch.cuda.synchronize()

    gc.collect()
    cpu_rss = get_cpu_rss_bytes()
    pinned_bytes = 0
    if cpu_offload and optim._cpu_offload_pool is not None:
        pool = optim._cpu_offload_pool
        pinned_bytes = get_pinned_pool_bytes(pool)

    # Cleanup.
    del optim, params, param_groups
    gc.collect()
    torch.cuda.empty_cache()
    set_ns_compile(True)
    return cpu_rss, pinned_bytes

def test_offload_cpu_cost_isolation(rank, world_size):
    """A/B test: measure CPU cost of offload by comparing ON vs OFF."""
    mesh = _make_mesh(world_size)
    dim0, dim1 = 2048, 4096
    num_params = 8
    num_steps = 3
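    # Back-of-envelope (assuming 8 ranks): each rank's Shard(0) slice of a
    # (2048, 4096) fp32 param is (256, 4096) = 4 MiB, so 8 params hold 32 MiB
    # of local shards. The pinned pool should be on that order; the exact
    # multiple depends on which tensors (e.g. momentum) the pool offloads.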

    if rank == 0:
        logger.info("=" * 70)
        logger.info("A/B TEST: CPU MEMORY COST OF OFFLOAD (ON vs OFF)")
        logger.info("=" * 70)
        logger.info("Config: %d params of shape (%d, %d), %d ranks, %d steps",
                    num_params, dim0, dim1, world_size, num_steps)
        logger.info("Local param shape per rank: (%d, %d)", dim0 // world_size,
                    dim1)
        logger.info("-" * 70)

    # Run WITHOUT offload first (baseline).
    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_no_offload = get_cpu_rss_bytes()
    cpu_after_no_offload, _ = _run_muon_steps(mesh,
                                              dim0,
                                              dim1,
                                              num_params,
                                              num_steps,
                                              cpu_offload=False)
    cpu_growth_no_offload = cpu_after_no_offload - cpu_before_no_offload

    # Run WITH offload.
    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_offload = get_cpu_rss_bytes()
    cpu_after_offload, pinned_bytes = _run_muon_steps(mesh,
                                                      dim0,
                                                      dim1,
                                                      num_params,
                                                      num_steps,
                                                      cpu_offload=True)
    cpu_growth_offload = cpu_after_offload - cpu_before_offload

    # Delta = additional CPU cost from offloading.
    offload_delta = cpu_growth_offload - cpu_growth_no_offload

    if rank == 0:
        logger.info("CPU growth WITHOUT offload: %.2f MB",
                    cpu_growth_no_offload / 1024**2)
        logger.info("CPU growth WITH offload:    %.2f MB",
                    cpu_growth_offload / 1024**2)
        logger.info("-" * 70)
        logger.info("Pinned buffer size (expected): %.2f MB",
                    pinned_bytes / 1024**2)
        logger.info("Offload delta (WITH - WITHOUT): %.2f MB",
                    offload_delta / 1024**2)
        if pinned_bytes > 0:
            ratio = offload_delta / pinned_bytes
            logger.info("Ratio (delta / pinned buffer): %.2fx", ratio)
            if ratio > 1.5:
                logger.warning(
                    "Offload adds %.2f MB CPU memory but pinned buffer is "
                    "only %.2f MB (%.1f%% overhead beyond expected)",
                    offload_delta / 1024**2,
                    pinned_bytes / 1024**2,
                    (offload_delta - pinned_bytes) / pinned_bytes * 100,
                )
            else:
                logger.info("Offload CPU cost is within expected range.")

    # Only assert on rank 0 to avoid multi-rank assertion mismatches.
    if rank == 0 and pinned_bytes > 0:
        ratio = offload_delta / pinned_bytes
        assert ratio < 3.0, (
            f"Offload CPU cost ({offload_delta / 1024**2:.2f} MB) is "
            f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB). "
            f"Expected < 3.0x.")

    if rank == 0:
        logger.info("PASSED: test_offload_cpu_cost_isolation")

def test_cpu_memory_peak_detailed(rank, world_size):
    """Detailed per-phase CPU memory tracking for offload."""
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)
    dim0, dim1 = 2048, 4096
    num_params = 8

    gc.collect()
    torch.cuda.empty_cache()
    if rank == 0:
        logger.info("=" * 70)
        logger.info("DETAILED PER-PHASE CPU MEMORY TRACKING")
        logger.info("=" * 70)

    cpu_0 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 0] Baseline RSS: %.2f MB", cpu_0 / 1024**2)

    # Create params.
    params, names = [], []
    for i in range(num_params):
        full = torch.randn(dim0, dim1, device="cuda")
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        params.append(p)
        names.append(f"layer.{i}.weight")
    gc.collect()
    cpu_1 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 1] After param creation: %.2f MB (+%.2f MB)",
                    cpu_1 / 1024**2, (cpu_1 - cpu_0) / 1024**2)

    # Create optimizer.
    param_groups = [{
        "params": params,
        "names": names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
    optim.turn_on_cpu_offload()
    gc.collect()
    cpu_2 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 2] After optimizer creation: %.2f MB (+%.2f MB)",
                    cpu_2 / 1024**2, (cpu_2 - cpu_1) / 1024**2)

    # Set grads.
    for p in params:
        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                   mesh, [Shard(0)])
    gc.collect()
    cpu_3 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 3] After grad creation: %.2f MB (+%.2f MB)",
                    cpu_3 / 1024**2, (cpu_3 - cpu_2) / 1024**2)

    # Step 1 (creates states + first offload).
    optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_4 = get_cpu_rss_bytes()
    pool = optim._cpu_offload_pool
    pinned_bytes = get_pinned_pool_bytes(pool)
    if rank == 0:
        logger.info(
            "[Phase 4] After step 1 (init+offload): %.2f MB (+%.2f MB)",
            cpu_4 / 1024**2, (cpu_4 - cpu_3) / 1024**2)
        logger.info("  Pinned buffer size: %.2f MB", pinned_bytes / 1024**2)
        logger.info("  Step 1 growth vs pinned: %.2f MB extra",
                    (cpu_4 - cpu_3 - pinned_bytes) / 1024**2)

    # Step 2 (reload + compute + offload).
    for p in params:
        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                   mesh, [Shard(0)])
    optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_5 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 5] After step 2: %.2f MB (+%.2f MB)",
                    cpu_5 / 1024**2, (cpu_5 - cpu_4) / 1024**2)

    # Step 3.
    for p in params:
        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                   mesh, [Shard(0)])
    optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_6 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 6] After step 3: %.2f MB (+%.2f MB)",
                    cpu_6 / 1024**2, (cpu_6 - cpu_5) / 1024**2)

    # Summary.
    total_growth = cpu_6 - cpu_0
    if rank == 0:
        logger.info("-" * 70)
        logger.info("SUMMARY:")
        logger.info("  Total CPU growth: %.2f MB", total_growth / 1024**2)
        logger.info("  Pinned buffer:    %.2f MB", pinned_bytes / 1024**2)
        logger.info("  Overhead:         %.2f MB",
                    (total_growth - pinned_bytes) / 1024**2)
        if pinned_bytes > 0:
            logger.info("  Ratio: %.2fx", total_growth / pinned_bytes)
        logger.info("")
        logger.info("  NOTE: Overhead includes CUDA runtime, NCCL buffers,")
        logger.info("  DTensor metadata, and optimizer internals — NOT just")
        logger.info("  offload cost. Use A/B test for isolated measurement.")

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_cpu_memory_peak_detailed")

def test_offload_cpu_cost_mixed(rank, world_size):
    """A/B test for mixed Muon + AdamW offload CPU cost."""
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    mesh = _make_mesh(world_size)
    muon_dim0, muon_dim1 = 2048, 4096
    num_muon = 8
    adamw_dim = 4096
    num_adamw = 8
    num_steps = 3

    def run_mixed(cpu_offload):
        set_ns_compile(False)
        torch.manual_seed(42)
        gc.collect()
        torch.cuda.empty_cache()

        muon_params, muon_names = [], []
        for i in range(num_muon):
            full = torch.randn(muon_dim0, muon_dim1, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            muon_params.append(p)
            muon_names.append(f"layer.{i}.weight")

        adamw_params = []
        for i in range(num_adamw):
            full = torch.randn(adamw_dim, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            adamw_params.append(p)

        param_groups = [
            {
                "params": muon_params,
                "names": muon_names,
                "use_muon": True,
                "lr": 0.02,
                "weight_decay": 0.01,
                "momentum": 0.95,
                "nesterov": True,
                "ns_steps": 5,
                "none_grad": False,
                "adamw_betas": (0.9, 0.95),
                "adamw_eps": 1e-8,
            },
            {
                "params": adamw_params,
                "use_muon": False,
                "lr": 1e-3,
                "weight_decay": 0.01,
                "adamw_betas": (0.9, 0.95),
                "adamw_eps": 1e-8,
            },
        ]
        optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
        if cpu_offload:
            optim.turn_on_cpu_offload()

        for _ in range(num_steps):
            for p in muon_params:
                p.grad = distribute_tensor(
                    torch.randn(muon_dim0, muon_dim1, device="cuda"), mesh,
                    [Shard(0)])
            for p in adamw_params:
                p.grad = distribute_tensor(
                    torch.randn(adamw_dim, device="cuda"), mesh, [Shard(0)])
            optim.step()
            torch.cuda.synchronize()

        gc.collect()
        cpu_rss = get_cpu_rss_bytes()
        pinned_bytes = 0
        if cpu_offload and optim._cpu_offload_pool is not None:
            pinned_bytes = get_pinned_pool_bytes(optim._cpu_offload_pool)

        del optim, muon_params, adamw_params, param_groups
        gc.collect()
        torch.cuda.empty_cache()
        set_ns_compile(True)
        return cpu_rss, pinned_bytes

    if rank == 0:
        logger.info("=" * 70)
        logger.info("A/B TEST: CPU COST OF MIXED OFFLOAD (Muon + AdamW)")
        logger.info("=" * 70)

    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_no = get_cpu_rss_bytes()
    cpu_after_no, _ = run_mixed(False)
    growth_no = cpu_after_no - cpu_before_no

    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_yes = get_cpu_rss_bytes()
    cpu_after_yes, pinned_bytes = run_mixed(True)
    growth_yes = cpu_after_yes - cpu_before_yes

    delta = growth_yes - growth_no

    if rank == 0:
        logger.info("CPU growth WITHOUT offload: %.2f MB", growth_no / 1024**2)
        logger.info("CPU growth WITH offload:    %.2f MB",
                    growth_yes / 1024**2)
        logger.info("Pinned buffer size:         %.2f MB",
                    pinned_bytes / 1024**2)
        logger.info("Offload delta:              %.2f MB", delta / 1024**2)
        if pinned_bytes > 0:
            logger.info("Ratio (delta / pinned): %.2fx", delta / pinned_bytes)

    if rank == 0 and pinned_bytes > 0:
        ratio = delta / pinned_bytes
        assert ratio < 3.0, (
            f"Mixed offload CPU cost ({delta / 1024**2:.2f} MB) is "
            f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB)."
        )

    if rank == 0:
        logger.info("PASSED: test_offload_cpu_cost_mixed")

def test_pinned_memory_rss_overhead(rank, world_size):
    """Isolate: does cudaHostAlloc itself cause 2x RSS overhead?"""
    sizes_mb = [8, 16, 32, 64, 128]
    if rank == 0:
        logger.info("=" * 70)
        logger.info("ISOLATED TEST: PINNED MEMORY RSS OVERHEAD")
        logger.info("=" * 70)

    for size_mb in sizes_mb:
        numel = size_mb * 1024 * 1024 // 4  # float32

        # Test 1: pin_memory=True (direct allocation).
        gc.collect()
        torch.cuda.empty_cache()
        rss_before = get_cpu_rss_bytes()
        t1 = torch.empty(numel,
                         dtype=torch.float32,
                         device="cpu",
                         pin_memory=True)
        rss_after = get_cpu_rss_bytes()
        rss_growth_direct = rss_after - rss_before
        del t1
        gc.collect()
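        # Context for Tests 1 vs 2: Tensor.pin_memory() returns a *copy* in
        # page-locked memory, so the pageable source and pinned destination
        # briefly coexist, while direct allocation with pin_memory=True skips
        # the pageable staging buffer. That difference is what the two
        # measurements are meant to expose.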

        # Test 2: .pin_memory() (copy-based).
        gc.collect()
        torch.cuda.empty_cache()
        rss_before2 = get_cpu_rss_bytes()
        t2 = torch.empty(numel, dtype=torch.float32, device="cpu").pin_memory()
        rss_after2 = get_cpu_rss_bytes()
        rss_growth_copy = rss_after2 - rss_before2
        del t2
        gc.collect()

        # Test 3: regular (non-pinned) CPU allocation.
        gc.collect()
        torch.cuda.empty_cache()
        rss_before3 = get_cpu_rss_bytes()
        t3 = torch.empty(numel, dtype=torch.float32, device="cpu")
        # Touch all pages to ensure RSS reflects actual allocation.
        t3.fill_(1.0)
        rss_after3 = get_cpu_rss_bytes()
        rss_growth_regular = rss_after3 - rss_before3
        del t3
        gc.collect()

        if rank == 0:
            logger.info(
                "%3d MB: pin_memory=True → RSS +%.1f MB (%.2fx) | "
                ".pin_memory() → RSS +%.1f MB (%.2fx) | "
                "regular → RSS +%.1f MB (%.2fx)",
                size_mb,
                rss_growth_direct / 1024**2,
                rss_growth_direct / (size_mb * 1024**2) if size_mb > 0 else 0,
                rss_growth_copy / 1024**2,
                rss_growth_copy / (size_mb * 1024**2) if size_mb > 0 else 0,
                rss_growth_regular / 1024**2,
                rss_growth_regular / (size_mb * 1024**2) if size_mb > 0 else 0,
            )

    if rank == 0:
        logger.info("PASSED: test_pinned_memory_rss_overhead")

def main():
    rank, world_size = _setup()
    try:
        test_pinned_memory_rss_overhead(rank, world_size)
        test_cpu_memory_peak_detailed(rank, world_size)
        test_offload_cpu_cost_isolation(rank, world_size)
        test_offload_cpu_cost_mixed(rank, world_size)
        if rank == 0:
            logger.info("=" * 50)
            logger.info("ALL CPU MEMORY PEAK TESTS PASSED")
            logger.info("=" * 50)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()