| """CPU memory peak vs offloaded tensor size verification. |
| |
| Compares CPU memory usage with turn_on_cpu_offload() vs no offload to isolate |
| the actual CPU cost of offloading, separating it from CUDA runtime, |
| NCCL, and DTensor overhead. |
| |
| Run with: |
| torchrun --nproc-per-node=8 --local-ranks-filter=0 test/test_cpu_memory_peak.py |
| """ |

import gc
import logging
import os

import torch
import torch.distributed as dist
from torch.distributed.tensor import Shard, distribute_tensor

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")


def _setup():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    return rank, dist.get_world_size()


def _make_mesh(world_size):
    return dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))


def get_cpu_rss_bytes():
    """Get current process RSS in bytes from /proc/self/statm."""
    with open("/proc/self/statm") as f:
        pages = int(f.read().split()[1])
    return pages * os.sysconf("SC_PAGE_SIZE")


def get_pinned_pool_bytes(pool):
    """Get total pinned CPU buffer size from CPUOffloadPool."""
    total = 0
    for grp in pool._groups.values():
        cpu_flat = grp["cpu_flat"]
        total += cpu_flat.numel() * cpu_flat.element_size()
    return total


def _run_muon_steps(mesh, dim0, dim1, num_params, num_steps, cpu_offload):
    """Run Muon optimizer steps; return final CPU RSS and pinned buffer bytes."""
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    gc.collect()
    torch.cuda.empty_cache()

    params, names = [], []
    for i in range(num_params):
        full = torch.randn(dim0, dim1, device="cuda")
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        params.append(p)
        names.append(f"layer.{i}.weight")

    param_groups = [{
        "params": params,
        "names": names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]

    optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
    if cpu_offload:
        optim.turn_on_cpu_offload()

    for _ in range(num_steps):
        for p in params:
            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                       mesh, [Shard(0)])
        optim.step()
        torch.cuda.synchronize()

    gc.collect()
    cpu_rss = get_cpu_rss_bytes()

    pinned_bytes = 0
    if cpu_offload and optim._cpu_offload_pool is not None:
        pool = optim._cpu_offload_pool
        pinned_bytes = get_pinned_pool_bytes(pool)

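    # Drop every reference so the paired run starts from a clean slate.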
    del optim, params, param_groups
    gc.collect()
    torch.cuda.empty_cache()

    set_ns_compile(True)
    return cpu_rss, pinned_bytes


def test_offload_cpu_cost_isolation(rank, world_size):
    """A/B test: measure CPU cost of offload by comparing ON vs OFF."""
    mesh = _make_mesh(world_size)

    dim0, dim1 = 2048, 4096
    num_params = 8
    num_steps = 3

    if rank == 0:
        logger.info("=" * 70)
        logger.info("A/B TEST: CPU MEMORY COST OF OFFLOAD (ON vs OFF)")
        logger.info("=" * 70)
        logger.info("Config: %d params of shape (%d, %d), %d ranks, %d steps",
                    num_params, dim0, dim1, world_size, num_steps)
        logger.info("Local param shape per rank: (%d, %d)", dim0 // world_size,
                    dim1)
        logger.info("-" * 70)

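    # Run A: offload OFF. Captures the baseline CPU growth from the CUDA
    # runtime, NCCL, DTensor machinery, and optimizer internals.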
    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_no_offload = get_cpu_rss_bytes()
    cpu_after_no_offload, _ = _run_muon_steps(mesh,
                                              dim0,
                                              dim1,
                                              num_params,
                                              num_steps,
                                              cpu_offload=False)
    cpu_growth_no_offload = cpu_after_no_offload - cpu_before_no_offload

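    # Run B: offload ON. Same workload, plus the pinned host buffers.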
    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_offload = get_cpu_rss_bytes()
    cpu_after_offload, pinned_bytes = _run_muon_steps(mesh,
                                                      dim0,
                                                      dim1,
                                                      num_params,
                                                      num_steps,
                                                      cpu_offload=True)
    cpu_growth_offload = cpu_after_offload - cpu_before_offload

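    # The B-minus-A delta isolates the CPU cost attributable to offload alone.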
    offload_delta = cpu_growth_offload - cpu_growth_no_offload

    if rank == 0:
        logger.info("CPU growth WITHOUT offload: %.2f MB",
                    cpu_growth_no_offload / 1024**2)
        logger.info("CPU growth WITH offload:    %.2f MB",
                    cpu_growth_offload / 1024**2)
        logger.info("-" * 70)
        logger.info("Pinned buffer size (expected): %.2f MB",
                    pinned_bytes / 1024**2)
        logger.info("Offload delta (WITH - WITHOUT): %.2f MB",
                    offload_delta / 1024**2)

        if pinned_bytes > 0:
            ratio = offload_delta / pinned_bytes
            logger.info("Ratio (delta / pinned buffer): %.2fx", ratio)

            if ratio > 1.5:
                logger.warning(
                    "Offload adds %.2f MB CPU memory but pinned buffer is "
                    "only %.2f MB (%.1f%% overhead beyond expected)",
                    offload_delta / 1024**2,
                    pinned_bytes / 1024**2,
                    (offload_delta - pinned_bytes) / pinned_bytes * 100,
                )
            else:
                logger.info("Offload CPU cost is within the expected range.")

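    # Hard check: a loose 3x bound on (delta / pinned) to catch gross
    # regressions such as an extra full host copy, without being sensitive
    # to allocator noise.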
    if rank == 0 and pinned_bytes > 0:
        ratio = offload_delta / pinned_bytes
        assert ratio < 3.0, (
            f"Offload CPU cost ({offload_delta / 1024**2:.2f} MB) is "
            f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB). "
            f"Expected < 3.0x.")

    if rank == 0:
        logger.info("PASSED: test_offload_cpu_cost_isolation")


def test_cpu_memory_peak_detailed(rank, world_size):
    """Detailed per-phase CPU memory tracking for offload."""
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)

    mesh = _make_mesh(world_size)

    dim0, dim1 = 2048, 4096
    num_params = 8

    gc.collect()
    torch.cuda.empty_cache()

    if rank == 0:
        logger.info("=" * 70)
        logger.info("DETAILED PER-PHASE CPU MEMORY TRACKING")
        logger.info("=" * 70)

    cpu_0 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 0] Baseline RSS: %.2f MB", cpu_0 / 1024**2)

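    # Phase 1: create the sharded (DTensor) parameters.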
    params, names = [], []
    for i in range(num_params):
        full = torch.randn(dim0, dim1, device="cuda")
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        params.append(p)
        names.append(f"layer.{i}.weight")

    gc.collect()
    cpu_1 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 1] After param creation: %.2f MB (+%.2f MB)",
                    cpu_1 / 1024**2, (cpu_1 - cpu_0) / 1024**2)

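    # Phase 2: build the optimizer and enable CPU offload.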
    param_groups = [{
        "params": params,
        "names": names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
    optim.turn_on_cpu_offload()

    gc.collect()
    cpu_2 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 2] After optimizer creation: %.2f MB (+%.2f MB)",
                    cpu_2 / 1024**2, (cpu_2 - cpu_1) / 1024**2)

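    # Phase 3: attach gradients.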
    for p in params:
        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                   mesh, [Shard(0)])

    gc.collect()
    cpu_3 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 3] After grad creation: %.2f MB (+%.2f MB)",
                    cpu_3 / 1024**2, (cpu_3 - cpu_2) / 1024**2)

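    # Phase 4: first step. The offload pool is initialized here, so this
    # phase should account for the pinned-buffer allocation.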
    optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_4 = get_cpu_rss_bytes()

    pool = optim._cpu_offload_pool
    pinned_bytes = get_pinned_pool_bytes(pool)

    if rank == 0:
        logger.info(
            "[Phase 4] After step 1 (init+offload): %.2f MB (+%.2f MB)",
            cpu_4 / 1024**2, (cpu_4 - cpu_3) / 1024**2)
        logger.info("  Pinned buffer size: %.2f MB", pinned_bytes / 1024**2)
        logger.info("  Step 1 growth vs pinned: %.2f MB extra",
                    (cpu_4 - cpu_3 - pinned_bytes) / 1024**2)

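    # Phase 5: second step. Growth should be near zero once the pinned
    # buffers are reused.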
    for p in params:
        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                   mesh, [Shard(0)])
    optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_5 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 5] After step 2: %.2f MB (+%.2f MB)",
                    cpu_5 / 1024**2, (cpu_5 - cpu_4) / 1024**2)

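    # Phase 6: third step. Steady state; further growth here would suggest
    # a leak.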
    for p in params:
        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                   mesh, [Shard(0)])
    optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_6 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 6] After step 3: %.2f MB (+%.2f MB)",
                    cpu_6 / 1024**2, (cpu_6 - cpu_5) / 1024**2)

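    # Summary: compare total growth against the pinned buffer size.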
    total_growth = cpu_6 - cpu_0
    if rank == 0:
        logger.info("-" * 70)
        logger.info("SUMMARY:")
        logger.info("  Total CPU growth: %.2f MB", total_growth / 1024**2)
        logger.info("  Pinned buffer:    %.2f MB", pinned_bytes / 1024**2)
        logger.info("  Overhead:         %.2f MB",
                    (total_growth - pinned_bytes) / 1024**2)
        if pinned_bytes > 0:
            logger.info("  Ratio:            %.2fx",
                        total_growth / pinned_bytes)
        logger.info("")
        logger.info("  NOTE: Overhead includes CUDA runtime, NCCL buffers,")
        logger.info("  DTensor metadata, and optimizer internals, NOT just")
        logger.info("  offload cost. Use the A/B test for isolated measurement.")

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_cpu_memory_peak_detailed")


def test_offload_cpu_cost_mixed(rank, world_size):
    """A/B test for mixed Muon + AdamW offload CPU cost."""
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    mesh = _make_mesh(world_size)

    muon_dim0, muon_dim1 = 2048, 4096
    num_muon = 8
    adamw_dim = 4096
    num_adamw = 8
    num_steps = 3

    def run_mixed(cpu_offload):
        set_ns_compile(False)
        torch.manual_seed(42)
        gc.collect()
        torch.cuda.empty_cache()

        muon_params, muon_names = [], []
        for i in range(num_muon):
            full = torch.randn(muon_dim0, muon_dim1, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            muon_params.append(p)
            muon_names.append(f"layer.{i}.weight")

        adamw_params = []
        for _ in range(num_adamw):
            full = torch.randn(adamw_dim, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            adamw_params.append(p)

        param_groups = [
            {
                "params": muon_params,
                "names": muon_names,
                "use_muon": True,
                "lr": 0.02,
                "weight_decay": 0.01,
                "momentum": 0.95,
                "nesterov": True,
                "ns_steps": 5,
                "none_grad": False,
                "adamw_betas": (0.9, 0.95),
                "adamw_eps": 1e-8,
            },
            {
                "params": adamw_params,
                "use_muon": False,
                "lr": 1e-3,
                "weight_decay": 0.01,
                "adamw_betas": (0.9, 0.95),
                "adamw_eps": 1e-8,
            },
        ]

        optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
        if cpu_offload:
            optim.turn_on_cpu_offload()

        for _ in range(num_steps):
            for p in muon_params:
                p.grad = distribute_tensor(
                    torch.randn(muon_dim0, muon_dim1, device="cuda"), mesh,
                    [Shard(0)])
            for p in adamw_params:
                p.grad = distribute_tensor(
                    torch.randn(adamw_dim, device="cuda"), mesh, [Shard(0)])
            optim.step()
            torch.cuda.synchronize()

        gc.collect()
        cpu_rss = get_cpu_rss_bytes()

        pinned_bytes = 0
        if cpu_offload and optim._cpu_offload_pool is not None:
            pinned_bytes = get_pinned_pool_bytes(optim._cpu_offload_pool)

        del optim, muon_params, adamw_params, param_groups
        gc.collect()
        torch.cuda.empty_cache()
        set_ns_compile(True)
        return cpu_rss, pinned_bytes

    if rank == 0:
        logger.info("=" * 70)
        logger.info("A/B TEST: CPU COST OF MIXED OFFLOAD (Muon + AdamW)")
        logger.info("=" * 70)

    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_no = get_cpu_rss_bytes()
    cpu_after_no, _ = run_mixed(False)
    growth_no = cpu_after_no - cpu_before_no

    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_yes = get_cpu_rss_bytes()
    cpu_after_yes, pinned_bytes = run_mixed(True)
    growth_yes = cpu_after_yes - cpu_before_yes

    delta = growth_yes - growth_no

    if rank == 0:
        logger.info("CPU growth WITHOUT offload: %.2f MB", growth_no / 1024**2)
        logger.info("CPU growth WITH offload:    %.2f MB",
                    growth_yes / 1024**2)
        logger.info("Pinned buffer size:         %.2f MB",
                    pinned_bytes / 1024**2)
        logger.info("Offload delta:              %.2f MB", delta / 1024**2)
        if pinned_bytes > 0:
            logger.info("Ratio (delta / pinned):     %.2fx",
                        delta / pinned_bytes)

    if rank == 0 and pinned_bytes > 0:
        ratio = delta / pinned_bytes
        assert ratio < 3.0, (
            f"Mixed offload CPU cost ({delta / 1024**2:.2f} MB) is "
            f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB)."
        )

    if rank == 0:
        logger.info("PASSED: test_offload_cpu_cost_mixed")


def test_pinned_memory_rss_overhead(rank, world_size):
    """Isolated check: does cudaHostAlloc itself cause ~2x RSS overhead?"""
    sizes_mb = [8, 16, 32, 64, 128]

    if rank == 0:
        logger.info("=" * 70)
        logger.info("ISOLATED TEST: PINNED MEMORY RSS OVERHEAD")
        logger.info("=" * 70)

    for size_mb in sizes_mb:
        numel = size_mb * 1024 * 1024 // 4  # float32 elements for size_mb MB

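        # Method 1: allocate pinned memory directly (cudaHostAlloc under the
        # hood, via pin_memory=True).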
        gc.collect()
        torch.cuda.empty_cache()
        rss_before = get_cpu_rss_bytes()
        t1 = torch.empty(numel,
                         dtype=torch.float32,
                         device="cpu",
                         pin_memory=True)
        rss_after = get_cpu_rss_bytes()
        rss_growth_direct = rss_after - rss_before
        del t1
        gc.collect()

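        # Method 2: allocate pageable memory, then .pin_memory(), which
        # returns a pinned copy; both tensors are briefly live at once.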
        gc.collect()
        torch.cuda.empty_cache()
        rss_before2 = get_cpu_rss_bytes()
        t2 = torch.empty(numel, dtype=torch.float32, device="cpu").pin_memory()
        rss_after2 = get_cpu_rss_bytes()
        rss_growth_copy = rss_after2 - rss_before2
        del t2
        gc.collect()

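        # Method 3 (control): a regular pageable allocation.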
        gc.collect()
        torch.cuda.empty_cache()
        rss_before3 = get_cpu_rss_bytes()
        t3 = torch.empty(numel, dtype=torch.float32, device="cpu")
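        # Pageable memory is lazily backed; touch every page so RSS actually
        # reflects the allocation.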
        t3.fill_(1.0)
        rss_after3 = get_cpu_rss_bytes()
        rss_growth_regular = rss_after3 - rss_before3
        del t3
        gc.collect()

        if rank == 0:
            logger.info(
                "%3d MB: pin_memory=True → RSS +%.1f MB (%.2fx) | "
                ".pin_memory() → RSS +%.1f MB (%.2fx) | "
                "regular → RSS +%.1f MB (%.2fx)",
                size_mb,
                rss_growth_direct / 1024**2,
                rss_growth_direct / (size_mb * 1024**2),
                rss_growth_copy / 1024**2,
                rss_growth_copy / (size_mb * 1024**2),
                rss_growth_regular / 1024**2,
                rss_growth_regular / (size_mb * 1024**2),
            )

    if rank == 0:
        logger.info("PASSED: test_pinned_memory_rss_overhead")


def main():
    rank, world_size = _setup()

    try:
        test_pinned_memory_rss_overhead(rank, world_size)
        test_cpu_memory_peak_detailed(rank, world_size)
        test_offload_cpu_cost_isolation(rank, world_size)
        test_offload_cpu_cost_mixed(rank, world_size)

        if rank == 0:
            logger.info("=" * 50)
            logger.info("ALL CPU MEMORY PEAK TESTS PASSED")
            logger.info("=" * 50)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()