"""CPU memory peak vs offloaded tensor size verification. Compares CPU memory usage with turn_on_cpu_offload() vs no offload to isolate the actual CPU cost of offloading, separating it from CUDA runtime, NCCL, and DTensor overhead. Run with: torchrun --nproc-per-node=8 --local-ranks-filter=0 test/test_cpu_memory_peak.py """ import gc import logging import os import torch import torch.distributed as dist from torch.distributed.tensor import DTensor, Shard, distribute_tensor logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") def _setup(): dist.init_process_group(backend="nccl") rank = dist.get_rank() torch.cuda.set_device(rank % torch.cuda.device_count()) return rank, dist.get_world_size() def _make_mesh(world_size): return dist.init_device_mesh("cuda", (world_size, ), mesh_dim_names=("dp", )) def get_cpu_rss_bytes(): """Get current process RSS in bytes from /proc/self/statm.""" with open("/proc/self/statm") as f: pages = int(f.read().split()[1]) return pages * os.sysconf("SC_PAGE_SIZE") def get_pinned_pool_bytes(pool): """Get total pinned CPU buffer size from CPUOffloadPool.""" total = 0 for grp in pool._groups.values(): cpu_flat = grp["cpu_flat"] total += cpu_flat.numel() * cpu_flat.element_size() return total def _run_muon_steps(mesh, dim0, dim1, num_params, num_steps, cpu_offload): """Run Muon optimizer steps and return final CPU RSS.""" from optimizer.muon import Muon from optimizer.newton_schulz import set_ns_compile set_ns_compile(False) torch.manual_seed(42) gc.collect() torch.cuda.empty_cache() params, names = [], [] for i in range(num_params): full = torch.randn(dim0, dim1, device="cuda") dt = distribute_tensor(full, mesh, [Shard(0)]) p = torch.nn.Parameter(dt) params.append(p) names.append(f"layer.{i}.weight") param_groups = [{ "params": params, "names": names, "use_muon": True, "lr": 0.02, "weight_decay": 0.01, "momentum": 0.95, "nesterov": True, "ns_steps": 5, "none_grad": False, }] optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) if cpu_offload: optim.turn_on_cpu_offload() for step_idx in range(num_steps): for p in params: p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]) optim.step() torch.cuda.synchronize() gc.collect() cpu_rss = get_cpu_rss_bytes() pinned_bytes = 0 if cpu_offload and optim._cpu_offload_pool is not None: pool = optim._cpu_offload_pool pinned_bytes = get_pinned_pool_bytes(pool) # Cleanup. del optim, params, param_groups gc.collect() torch.cuda.empty_cache() set_ns_compile(True) return cpu_rss, pinned_bytes def test_offload_cpu_cost_isolation(rank, world_size): """A/B test: measure CPU cost of offload by comparing ON vs OFF.""" mesh = _make_mesh(world_size) dim0, dim1 = 2048, 4096 num_params = 8 num_steps = 3 if rank == 0: logger.info("=" * 70) logger.info("A/B TEST: CPU MEMORY COST OF OFFLOAD (ON vs OFF)") logger.info("=" * 70) logger.info("Config: %d params of shape (%d, %d), %d ranks, %d steps", num_params, dim0, dim1, world_size, num_steps) logger.info("Local param shape per rank: (%d, %d)", dim0 // world_size, dim1) logger.info("-" * 70) # Run WITHOUT offload first (baseline). gc.collect() torch.cuda.empty_cache() cpu_before_no_offload = get_cpu_rss_bytes() cpu_after_no_offload, _ = _run_muon_steps(mesh, dim0, dim1, num_params, num_steps, cpu_offload=False) cpu_growth_no_offload = cpu_after_no_offload - cpu_before_no_offload # Run WITH offload. 

    # Run WITH offload.
    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_offload = get_cpu_rss_bytes()
    cpu_after_offload, pinned_bytes = _run_muon_steps(
        mesh, dim0, dim1, num_params, num_steps, cpu_offload=True)
    cpu_growth_offload = cpu_after_offload - cpu_before_offload

    # Delta = additional CPU cost from offloading.
    offload_delta = cpu_growth_offload - cpu_growth_no_offload

    if rank == 0:
        logger.info("CPU growth WITHOUT offload: %.2f MB",
                    cpu_growth_no_offload / 1024**2)
        logger.info("CPU growth WITH offload: %.2f MB",
                    cpu_growth_offload / 1024**2)
        logger.info("-" * 70)
        logger.info("Pinned buffer size (expected): %.2f MB",
                    pinned_bytes / 1024**2)
        logger.info("Offload delta (WITH - WITHOUT): %.2f MB",
                    offload_delta / 1024**2)
        if pinned_bytes > 0:
            ratio = offload_delta / pinned_bytes
            logger.info("Ratio (delta / pinned buffer): %.2fx", ratio)
            if ratio > 1.5:
                logger.warning(
                    "Offload adds %.2f MB CPU memory but pinned buffer is "
                    "only %.2f MB (%.1f%% overhead beyond expected)",
                    offload_delta / 1024**2,
                    pinned_bytes / 1024**2,
                    (offload_delta - pinned_bytes) / pinned_bytes * 100,
                )
            else:
                logger.info("Offload CPU cost is within expected range.")

    # Only assert on rank 0 to avoid multi-rank assertion mismatches.
    if rank == 0 and pinned_bytes > 0:
        ratio = offload_delta / pinned_bytes
        assert ratio < 3.0, (
            f"Offload CPU cost ({offload_delta / 1024**2:.2f} MB) is "
            f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB). "
            f"Expected < 3.0x.")

    if rank == 0:
        logger.info("PASSED: test_offload_cpu_cost_isolation")


def test_cpu_memory_peak_detailed(rank, world_size):
    """Detailed per-phase CPU memory tracking for offload."""
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)
    dim0, dim1 = 2048, 4096
    num_params = 8

    gc.collect()
    torch.cuda.empty_cache()

    if rank == 0:
        logger.info("=" * 70)
        logger.info("DETAILED PER-PHASE CPU MEMORY TRACKING")
        logger.info("=" * 70)

    cpu_0 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 0] Baseline RSS: %.2f MB", cpu_0 / 1024**2)

    # Create params.
    params, names = [], []
    for i in range(num_params):
        full = torch.randn(dim0, dim1, device="cuda")
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        params.append(p)
        names.append(f"layer.{i}.weight")
    gc.collect()
    cpu_1 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 1] After param creation: %.2f MB (+%.2f MB)",
                    cpu_1 / 1024**2, (cpu_1 - cpu_0) / 1024**2)

    # Create optimizer.
    param_groups = [{
        "params": params,
        "names": names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
    optim.turn_on_cpu_offload()
    gc.collect()
    cpu_2 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 2] After optimizer creation: %.2f MB (+%.2f MB)",
                    cpu_2 / 1024**2, (cpu_2 - cpu_1) / 1024**2)

    # Set grads.
    for p in params:
        p.grad = distribute_tensor(
            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)])
    gc.collect()
    cpu_3 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 3] After grad creation: %.2f MB (+%.2f MB)",
                    cpu_3 / 1024**2, (cpu_3 - cpu_2) / 1024**2)

    # Step 1 (creates states + first offload).
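    # The first step both materializes optimizer state and (with offload on)
    # allocates the pinned CPU pool, so Phase 4 growth is expected to be at
    # least pinned_bytes; the "extra" line below shows anything beyond that.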
    optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_4 = get_cpu_rss_bytes()
    pool = optim._cpu_offload_pool
    pinned_bytes = get_pinned_pool_bytes(pool)
    if rank == 0:
        logger.info(
            "[Phase 4] After step 1 (init+offload): %.2f MB (+%.2f MB)",
            cpu_4 / 1024**2, (cpu_4 - cpu_3) / 1024**2)
        logger.info("  Pinned buffer size: %.2f MB", pinned_bytes / 1024**2)
        logger.info("  Step 1 growth vs pinned: %.2f MB extra",
                    (cpu_4 - cpu_3 - pinned_bytes) / 1024**2)

    # Step 2 (reload + compute + offload).
    for p in params:
        p.grad = distribute_tensor(
            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)])
    optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_5 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 5] After step 2: %.2f MB (+%.2f MB)",
                    cpu_5 / 1024**2, (cpu_5 - cpu_4) / 1024**2)

    # Step 3.
    for p in params:
        p.grad = distribute_tensor(
            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)])
    optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_6 = get_cpu_rss_bytes()
    if rank == 0:
        logger.info("[Phase 6] After step 3: %.2f MB (+%.2f MB)",
                    cpu_6 / 1024**2, (cpu_6 - cpu_5) / 1024**2)

    # Summary.
    total_growth = cpu_6 - cpu_0
    if rank == 0:
        logger.info("-" * 70)
        logger.info("SUMMARY:")
        logger.info("  Total CPU growth: %.2f MB", total_growth / 1024**2)
        logger.info("  Pinned buffer: %.2f MB", pinned_bytes / 1024**2)
        logger.info("  Overhead: %.2f MB",
                    (total_growth - pinned_bytes) / 1024**2)
        if pinned_bytes > 0:
            logger.info("  Ratio: %.2fx", total_growth / pinned_bytes)
        logger.info("")
        logger.info("  NOTE: Overhead includes CUDA runtime, NCCL buffers,")
        logger.info("  DTensor metadata, and optimizer internals, NOT just")
        logger.info("  offload cost. Use the A/B test for isolated measurement.")

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_cpu_memory_peak_detailed")


def test_offload_cpu_cost_mixed(rank, world_size):
    """A/B test for mixed Muon + AdamW offload CPU cost."""
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    mesh = _make_mesh(world_size)
    muon_dim0, muon_dim1 = 2048, 4096
    num_muon = 8
    adamw_dim = 4096
    num_adamw = 8
    num_steps = 3

    def run_mixed(cpu_offload):
        set_ns_compile(False)
        torch.manual_seed(42)
        gc.collect()
        torch.cuda.empty_cache()

        muon_params, muon_names = [], []
        for i in range(num_muon):
            full = torch.randn(muon_dim0, muon_dim1, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            muon_params.append(p)
            muon_names.append(f"layer.{i}.weight")

        adamw_params = []
        for i in range(num_adamw):
            full = torch.randn(adamw_dim, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            adamw_params.append(p)

        param_groups = [
            {
                "params": muon_params,
                "names": muon_names,
                "use_muon": True,
                "lr": 0.02,
                "weight_decay": 0.01,
                "momentum": 0.95,
                "nesterov": True,
                "ns_steps": 5,
                "none_grad": False,
                "adamw_betas": (0.9, 0.95),
                "adamw_eps": 1e-8,
            },
            {
                "params": adamw_params,
                "use_muon": False,
                "lr": 1e-3,
                "weight_decay": 0.01,
                "adamw_betas": (0.9, 0.95),
                "adamw_eps": 1e-8,
            },
        ]
        optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
        if cpu_offload:
            optim.turn_on_cpu_offload()

        for _ in range(num_steps):
            for p in muon_params:
                p.grad = distribute_tensor(
                    torch.randn(muon_dim0, muon_dim1, device="cuda"),
                    mesh, [Shard(0)])
            for p in adamw_params:
                p.grad = distribute_tensor(
                    torch.randn(adamw_dim, device="cuda"), mesh, [Shard(0)])
            optim.step()
            torch.cuda.synchronize()

        gc.collect()
        cpu_rss = get_cpu_rss_bytes()
        pinned_bytes = 0
        if cpu_offload and optim._cpu_offload_pool is not None:
            pinned_bytes = get_pinned_pool_bytes(optim._cpu_offload_pool)

        del optim, muon_params, adamw_params, param_groups
        gc.collect()
        torch.cuda.empty_cache()
        set_ns_compile(True)
        return cpu_rss, pinned_bytes
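
    # Same A/B protocol as test_offload_cpu_cost_isolation, applied to a
    # mixed optimizer: both the Muon (2-D) and AdamW (1-D) param groups
    # contribute to the measured growth and, presumably, to the pinned pool.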

    if rank == 0:
        logger.info("=" * 70)
        logger.info("A/B TEST: CPU COST OF MIXED OFFLOAD (Muon + AdamW)")
        logger.info("=" * 70)

    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_no = get_cpu_rss_bytes()
    cpu_after_no, _ = run_mixed(False)
    growth_no = cpu_after_no - cpu_before_no

    gc.collect()
    torch.cuda.empty_cache()
    cpu_before_yes = get_cpu_rss_bytes()
    cpu_after_yes, pinned_bytes = run_mixed(True)
    growth_yes = cpu_after_yes - cpu_before_yes

    delta = growth_yes - growth_no

    if rank == 0:
        logger.info("CPU growth WITHOUT offload: %.2f MB", growth_no / 1024**2)
        logger.info("CPU growth WITH offload: %.2f MB", growth_yes / 1024**2)
        logger.info("Pinned buffer size: %.2f MB", pinned_bytes / 1024**2)
        logger.info("Offload delta: %.2f MB", delta / 1024**2)
        if pinned_bytes > 0:
            logger.info("Ratio (delta / pinned): %.2fx", delta / pinned_bytes)

    if rank == 0 and pinned_bytes > 0:
        ratio = delta / pinned_bytes
        assert ratio < 3.0, (
            f"Mixed offload CPU cost ({delta / 1024**2:.2f} MB) is "
            f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB)."
        )

    if rank == 0:
        logger.info("PASSED: test_offload_cpu_cost_mixed")


def test_pinned_memory_rss_overhead(rank, world_size):
    """Isolate: does cudaHostAlloc itself cause 2x RSS overhead?"""
    sizes_mb = [8, 16, 32, 64, 128]

    if rank == 0:
        logger.info("=" * 70)
        logger.info("ISOLATED TEST: PINNED MEMORY RSS OVERHEAD")
        logger.info("=" * 70)

    for size_mb in sizes_mb:
        numel = size_mb * 1024 * 1024 // 4  # float32

        # Test 1: pin_memory=True (direct pinned allocation).
        gc.collect()
        torch.cuda.empty_cache()
        rss_before = get_cpu_rss_bytes()
        t1 = torch.empty(numel, dtype=torch.float32, device="cpu",
                         pin_memory=True)
        rss_after = get_cpu_rss_bytes()
        rss_growth_direct = rss_after - rss_before
        del t1
        gc.collect()

        # Test 2: .pin_memory() (copy-based: allocates pageable memory first,
        # then copies into a fresh pinned buffer).
        gc.collect()
        torch.cuda.empty_cache()
        rss_before2 = get_cpu_rss_bytes()
        t2 = torch.empty(numel, dtype=torch.float32, device="cpu").pin_memory()
        rss_after2 = get_cpu_rss_bytes()
        rss_growth_copy = rss_after2 - rss_before2
        del t2
        gc.collect()

        # Test 3: regular (non-pinned) CPU allocation.
        gc.collect()
        torch.cuda.empty_cache()
        rss_before3 = get_cpu_rss_bytes()
        t3 = torch.empty(numel, dtype=torch.float32, device="cpu")
        # Touch all pages to ensure RSS reflects the actual allocation.
        t3.fill_(1.0)
        rss_after3 = get_cpu_rss_bytes()
        rss_growth_regular = rss_after3 - rss_before3
        del t3
        gc.collect()

        if rank == 0:
            logger.info(
                "%3d MB: pin_memory=True → RSS +%.1f MB (%.2fx) | "
                ".pin_memory() → RSS +%.1f MB (%.2fx) | "
                "regular → RSS +%.1f MB (%.2fx)",
                size_mb,
                rss_growth_direct / 1024**2,
                rss_growth_direct / (size_mb * 1024**2) if size_mb > 0 else 0,
                rss_growth_copy / 1024**2,
                rss_growth_copy / (size_mb * 1024**2) if size_mb > 0 else 0,
                rss_growth_regular / 1024**2,
                rss_growth_regular / (size_mb * 1024**2) if size_mb > 0 else 0,
            )

    if rank == 0:
        logger.info("PASSED: test_pinned_memory_rss_overhead")


def main():
    rank, world_size = _setup()
    try:
        test_pinned_memory_rss_overhead(rank, world_size)
        test_cpu_memory_peak_detailed(rank, world_size)
        test_offload_cpu_cost_isolation(rank, world_size)
        test_offload_cpu_cost_mixed(rank, world_size)
        if rank == 0:
            logger.info("=" * 50)
            logger.info("ALL CPU MEMORY PEAK TESTS PASSED")
            logger.info("=" * 50)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()