| | |
| | """ |
| | Test WrinkleBrane Optimizations |
| | Validate performance and fidelity improvements from optimizations. |
| | """ |
| |
|
| | import sys |
| | from pathlib import Path |
| | sys.path.append(str(Path(__file__).resolve().parent / "src")) |
| |
|
| | import torch |
| | import numpy as np |
| | import time |
| | from wrinklebrane.membrane_bank import MembraneBank |
| | from wrinklebrane.codes import hadamard_codes |
| | from wrinklebrane.slicer import make_slicer |
| | from wrinklebrane.write_ops import store_pairs |
| | from wrinklebrane.metrics import psnr, ssim |
| | from wrinklebrane.optimizations import ( |
| | compute_adaptive_alphas, |
| | generate_extended_codes, |
| | HierarchicalMembraneBank, |
| | optimized_store_pairs |
| | ) |
| |
|
def _build_adaptive_alpha_patterns(K, H, W, device):
    """Build K structured (H, W) patterns whose energy grows with their index.

    Pattern ``i`` uses the scale ``0.1 + 0.3 * i`` and cycles through three
    shapes so the stored set spans a wide range of energies:

    * ``i % 3 == 0``: filled disk of radius ``3 + i // 3`` at the image centre;
    * ``i % 3 == 1``: centred square of side ``4 + i // 3`` at half scale;
    * ``i % 3 == 2``: diagonal line shifted down ``i // 3`` rows at 1/10 scale.

    Returns:
        Float tensor of shape (K, H, W) on ``device``.
    """
    rows = torch.arange(H, device=device).view(-1, 1)
    cols = torch.arange(W, device=device).view(1, -1)

    patterns = torch.zeros(K, H, W, device=device)
    for i in range(K):
        energy_scale = 0.1 + i * 0.3
        step = i // 3
        kind = i % 3
        if kind == 0:
            # Columns are measured against the horizontal centre (W // 2) and
            # rows against the vertical centre (H // 2).
            radius = 3 + step
            disk = (cols - W // 2) ** 2 + (rows - H // 2) ** 2 <= radius ** 2
            patterns[i][disk] = energy_scale
        elif kind == 1:
            size = 4 + step
            top = (H - size) // 2
            left = (W - size) // 2
            patterns[i, top:top + size, left:left + size] = energy_scale * 0.5
        else:
            # Row index is d + step and column index is d; clip the length so
            # both stay in bounds.
            length = max(0, min(H - step, W))
            d = torch.arange(length, device=device)
            patterns[i, d + step, d] = energy_scale * 0.1
    return patterns


def test_adaptive_alphas():
    """Compare retrieval fidelity of uniform vs adaptive per-pattern alphas.

    Stores the same K structured patterns into two freshly allocated membrane
    banks: once with every alpha equal to 1, and once with the alphas returned
    by ``compute_adaptive_alphas``. Both banks are read back through the same
    code slicer, and per-pattern and average PSNR are printed.

    Returns:
        True if the adaptive alphas give a strictly higher average PSNR than
        the uniform ones, else False.
    """
    print("🧪 Testing Adaptive Alpha Scaling...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 32, 16, 16, 8

    bank_uniform = MembraneBank(L, H, W, device=device)
    bank_adaptive = MembraneBank(L, H, W, device=device)
    bank_uniform.allocate(B)
    bank_adaptive.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    patterns = _build_adaptive_alpha_patterns(K, H, W, device)
    keys = torch.arange(K, device=device)

    # Baseline: every pattern written with unit strength.
    uniform_alphas = torch.ones(K, device=device)
    M_uniform = store_pairs(bank_uniform.read(), C, keys, patterns, uniform_alphas)
    bank_uniform.write(M_uniform - bank_uniform.read())
    uniform_readouts = slicer(bank_uniform.read()).squeeze(0)

    adaptive_alphas = compute_adaptive_alphas(patterns, C, keys)
    M_adaptive = store_pairs(bank_adaptive.read(), C, keys, patterns, adaptive_alphas)
    bank_adaptive.write(M_adaptive - bank_adaptive.read())
    adaptive_readouts = slicer(bank_adaptive.read()).squeeze(0)

    # Transfer to host once rather than once per pattern.
    patterns_np = patterns.cpu().numpy()
    uniform_np = uniform_readouts.cpu().numpy()
    adaptive_np = adaptive_readouts.cpu().numpy()

    uniform_psnr = []
    adaptive_psnr = []

    print(" Pattern-by-pattern comparison:")
    for i in range(K):
        u_psnr = psnr(patterns_np[i], uniform_np[i])
        a_psnr = psnr(patterns_np[i], adaptive_np[i])

        uniform_psnr.append(u_psnr)
        adaptive_psnr.append(a_psnr)

        energy = torch.norm(patterns[i]).item()
        print(f" Pattern {i}: Energy={energy:.3f}, Alpha={float(adaptive_alphas[i]):.3f}")
        print(f" Uniform PSNR: {u_psnr:.1f}dB, Adaptive PSNR: {a_psnr:.1f}dB")

    avg_uniform = np.mean(uniform_psnr)
    avg_adaptive = np.mean(adaptive_psnr)
    improvement = avg_adaptive - avg_uniform

    # PSNR can be 0 or infinite (perfect reconstruction); a relative figure
    # is meaningless then.
    if np.isfinite(avg_uniform) and avg_uniform != 0:
        relative = f"{improvement / avg_uniform * 100:.1f}%"
    else:
        relative = "n/a"

    print("\n Results Summary:")
    print(f" Uniform alphas: {avg_uniform:.1f}dB average PSNR")
    print(f" Adaptive alphas: {avg_adaptive:.1f}dB average PSNR")
    print(f" Improvement: {improvement:.1f}dB ({relative})")

    return improvement > 0
| |
|
| |
|
def test_extended_codes():
    """Check orthogonality and retrieval fidelity of extended code books.

    For a fixed membrane depth ``L`` and several code counts ``K`` (both at or
    below and above ``L``), the check does the following for each K:

    * builds codes with ``generate_extended_codes``;
    * measures how far the leading, potentially orthogonal columns deviate
      from an identity Gram matrix;
    * stores random patterns into a fresh bank and reads them back.

    Returns:
        Dict mapping each tested K to a dict with keys ``"orthogonality_error"``,
        ``"avg_psnr"``, ``"min_psnr"`` and ``"std_psnr"``.
    """
    print("\n🧪 Testing Extended Code Generation...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L = 32
    test_Ks = [16, 32, 64, 128]
    B, H, W = 1, 8, 8

    results = {}

    for K in test_Ks:
        print(f" Testing L={L}, K={K} (capacity: {K/L:.1f}x)")

        C = generate_extended_codes(L, K, method="auto", device=device)
        num_codes = C.shape[1]

        # At most L columns can be mutually orthogonal in an L-dimensional
        # space, so only the leading min(L, num_codes) columns are compared
        # against the identity. Sizing by C.shape[1] instead of K keeps the
        # Gram matrix and the identity the same shape even if fewer than K
        # codes are returned.
        n_ortho = min(L, num_codes)
        C_ortho = C[:, :n_ortho]
        gram = C_ortho.T @ C_ortho
        identity = torch.eye(n_ortho, device=device, dtype=C.dtype)
        orthogonality_error = torch.norm(gram - identity).item()

        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(B)
        slicer = make_slicer(C)

        # Never store more patterns than there are codes to key them.
        actual_K = min(K, num_codes)
        patterns = torch.rand(actual_K, H, W, device=device)
        keys = torch.arange(actual_K, device=device)
        alphas = torch.ones(actual_K, device=device)

        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
        readouts = slicer(bank.read()).squeeze(0)

        # zip stops at the actual_K stored patterns even if the slicer
        # returns one readout per code.
        patterns_np = patterns.cpu().numpy()
        readouts_np = readouts.cpu().numpy()
        psnr_values = [psnr(p, r) for p, r in zip(patterns_np, readouts_np)]

        avg_psnr = np.mean(psnr_values)
        min_psnr = np.min(psnr_values)
        std_psnr = np.std(psnr_values)

        results[K] = {
            "orthogonality_error": orthogonality_error,
            "avg_psnr": avg_psnr,
            "min_psnr": min_psnr,
            "std_psnr": std_psnr,
        }

        print(f" Orthogonality error: {orthogonality_error:.6f}")
        print(f" PSNR: {avg_psnr:.1f}±{std_psnr:.1f}dB (min: {min_psnr:.1f}dB)")

    return results
| |
|
| |
|
def test_hierarchical_memory():
    """Compare a flat MembraneBank against a 3-level HierarchicalMembraneBank.

    Stores K patterns of three kinds into both banks: dense noise, a noisy
    centre block, and a small constant centre square. Reports storage time and
    an estimate of membrane memory for each bank, and the read-back PSNR of
    the flat bank.

    Returns:
        True if the hierarchical bank's estimated memory is smaller than the
        flat bank's, else False.
    """
    print("\n🧪 Testing Hierarchical Memory Bank...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, H, W = 64, 32, 32
    K = 32

    def _sync():
        # CUDA kernels are launched asynchronously; wait for them so the
        # wall-clock timings cover the actual work, not just the launches.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    hierarchical_bank = HierarchicalMembraneBank(L, H, W, levels=3, device=device)
    hierarchical_bank.allocate(1)

    regular_bank = MembraneBank(L, H, W, device=device)
    regular_bank.allocate(1)

    # Centre-block bounds; the block shape is derived from these bounds so it
    # always matches the slice, even when H or W is not divisible by 4.
    top, bottom = H // 4, 3 * H // 4
    left, right = W // 4, 3 * W // 4

    patterns = torch.zeros(K, H, W, device=device)
    for i in range(K):
        if i < K // 3:
            patterns[i] = torch.rand(H, W, device=device)
        elif i < 2 * K // 3:
            patterns[i, top:bottom, left:right] = torch.rand(
                bottom - top, right - left, device=device
            )
        else:
            patterns[i, H // 2 - 2:H // 2 + 2, W // 2 - 2:W // 2 + 2] = 1.0

    keys = torch.arange(K, device=device)

    C_regular = hadamard_codes(L, K).to(device)
    slicer_regular = make_slicer(C_regular)
    alphas_regular = torch.ones(K, device=device)

    # Time only store + write so the measurement is comparable with the
    # store_hierarchical call below, which is timed without any readout.
    _sync()
    start_time = time.perf_counter()
    M_regular = store_pairs(regular_bank.read(), C_regular, keys, patterns, alphas_regular)
    regular_bank.write(M_regular - regular_bank.read())
    _sync()
    regular_time = time.perf_counter() - start_time

    regular_readouts = slicer_regular(regular_bank.read()).squeeze(0)

    _sync()
    start_time = time.perf_counter()
    hierarchical_bank.store_hierarchical(patterns, keys)
    _sync()
    hierarchical_time = time.perf_counter() - start_time

    # Memory estimate for batch size 1: L*H*W cells per bank.
    # NOTE(review): assumes 4-byte (float32) storage — confirm against MembraneBank.
    bytes_per_value = 4
    regular_memory = L * H * W * bytes_per_value
    hierarchical_memory = sum(
        bank.L * H * W * bytes_per_value for bank in hierarchical_bank.banks
    )
    memory_savings = (regular_memory - hierarchical_memory) / regular_memory * 100

    patterns_np = patterns.cpu().numpy()
    readouts_np = regular_readouts.cpu().numpy()
    regular_psnr = [psnr(patterns_np[i], readouts_np[i]) for i in range(K)]
    avg_regular_psnr = np.mean(regular_psnr)

    print(" Regular Bank:")
    print(f" Storage time: {regular_time*1000:.2f}ms")
    print(f" Memory usage: {regular_memory/1e6:.2f}MB")
    print(f" Average PSNR: {avg_regular_psnr:.1f}dB")

    print(" Hierarchical Bank:")
    print(f" Storage time: {hierarchical_time*1000:.2f}ms")
    print(f" Memory usage: {hierarchical_memory/1e6:.2f}MB")
    print(f" Memory savings: {memory_savings:.1f}%")
    print(f" Levels: {hierarchical_bank.levels}")

    for i, bank in enumerate(hierarchical_bank.banks):
        level_fraction = bank.L / hierarchical_bank.total_L
        print(f" Level {i}: L={bank.L} ({level_fraction:.1%})")

    return memory_savings > 0
| |
|
| |
|
def test_optimized_storage():
    """Compare the baseline store_pairs path with optimized_store_pairs.

    Stores K patterns of mixed amplitude, including sparse ones, into two fresh
    banks:

    * once with ``store_pairs`` and unit alphas;
    * once with ``optimized_store_pairs`` (adaptive alphas, sparsity
      threshold 0.01).

    Both are read back through the same slicer, and time and average PSNR are
    reported.

    Returns:
        True if the optimized pipeline's average PSNR is strictly higher than
        the baseline's, else False.
    """
    print("\n🧪 Testing Optimized Storage Pipeline...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 64, 32, 32, 48

    def _sync():
        # CUDA work is asynchronous; synchronise so timings include it.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    bank_original = MembraneBank(L, H, W, device=device)
    bank_optimized = MembraneBank(L, H, W, device=device)
    bank_original.allocate(B)
    bank_optimized.allocate(B)

    C = generate_extended_codes(L, K, method="auto", device=device)
    slicer = make_slicer(C)

    # Amplitude of the dense pattern kinds keyed by i % 4; the remaining kind
    # (i % 4 == 3) is a sparse pattern.
    dense_scales = {0: 2.0, 1: 1.0, 2: 0.5}
    patterns = []
    for i in range(K):
        scale = dense_scales.get(i % 4)
        if scale is not None:
            pattern = torch.rand(H, W, device=device) * scale
        else:
            # About 5% of pixels receive a random value. The mask is drawn
            # once and the values are sized from that same mask; drawing a
            # second independent mask to size the values (as before) almost
            # never gives a matching count and made the assignment fail.
            pattern = torch.zeros(H, W, device=device)
            mask = torch.rand(H, W, device=device) > 0.95
            pattern[mask] = torch.rand(int(mask.sum()), device=device)
        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    _sync()
    start_time = time.perf_counter()
    alphas_original = torch.ones(K, device=device)
    M_original = store_pairs(bank_original.read(), C, keys, patterns, alphas_original)
    bank_original.write(M_original - bank_original.read())
    original_readouts = slicer(bank_original.read()).squeeze(0)
    _sync()
    original_time = time.perf_counter() - start_time

    _sync()
    start_time = time.perf_counter()
    M_optimized = optimized_store_pairs(
        bank_optimized.read(), C, keys, patterns,
        adaptive_alphas=True, sparsity_threshold=0.01
    )
    bank_optimized.write(M_optimized - bank_optimized.read())
    optimized_readouts = slicer(bank_optimized.read()).squeeze(0)
    _sync()
    optimized_time = time.perf_counter() - start_time

    patterns_np = patterns.cpu().numpy()
    original_np = original_readouts.cpu().numpy()
    optimized_np = optimized_readouts.cpu().numpy()

    original_psnr = [psnr(patterns_np[i], original_np[i]) for i in range(K)]
    optimized_psnr = [psnr(patterns_np[i], optimized_np[i]) for i in range(K)]

    avg_original = np.mean(original_psnr)
    avg_optimized = np.mean(optimized_psnr)
    fidelity_improvement = avg_optimized - avg_original

    # Guard the divisions: the baseline time can be 0 on a coarse clock, and
    # PSNR can be 0 or infinite.
    if original_time > 0:
        speed_improvement = (original_time - optimized_time) / original_time * 100
    else:
        speed_improvement = 0.0
    if np.isfinite(avg_original) and avg_original != 0:
        relative = f"{fidelity_improvement / avg_original * 100:.1f}%"
    else:
        relative = "n/a"

    print(" Original Pipeline:")
    print(f" Time: {original_time*1000:.2f}ms")
    print(f" Average PSNR: {avg_original:.1f}dB")

    print(" Optimized Pipeline:")
    print(f" Time: {optimized_time*1000:.2f}ms")
    print(f" Average PSNR: {avg_optimized:.1f}dB")

    print(" Improvements:")
    # ':+' prints the correct sign for regressions instead of "+-x".
    print(f" Fidelity: {fidelity_improvement:+.1f}dB ({relative})")
    print(f" Speed: {speed_improvement:.1f}% {'faster' if speed_improvement > 0 else 'slower'}")

    return fidelity_improvement > 0
| |
|
| |
|
def main():
    """Run the optimization test suite and print a pass/fail summary.

    Seeds torch and NumPy for reproducibility, runs the four optimization
    checks in order, and counts how many report an improvement. Any exception
    raised by a check aborts the suite; its traceback is printed.

    Returns:
        False if a check raised; otherwise True if at least one check
        succeeded, else False.
    """
    print("🚀 WrinkleBrane Optimization Test Suite")
    print("=" * 50)

    torch.manual_seed(42)
    np.random.seed(42)

    # The extended-codes check passes only if every tested K exceeds this
    # average PSNR.
    min_extended_psnr_db = 50

    def extended_codes_ok():
        results = test_extended_codes()
        return all(r["avg_psnr"] > min_extended_psnr_db for r in results.values())

    # (check, message on success, message on failure), run in this order.
    checks = [
        (test_adaptive_alphas,
         "✅ Adaptive alpha scaling: IMPROVED PERFORMANCE",
         "⚠️ Adaptive alpha scaling: NO IMPROVEMENT"),
        (extended_codes_ok,
         "✅ Extended code generation: WORKING",
         "⚠️ Extended code generation: QUALITY ISSUES"),
        (test_hierarchical_memory,
         "✅ Hierarchical memory: MEMORY SAVINGS",
         "⚠️ Hierarchical memory: NO SAVINGS"),
        (test_optimized_storage,
         "✅ Optimized storage pipeline: IMPROVED FIDELITY",
         "⚠️ Optimized storage pipeline: NO IMPROVEMENT"),
    ]
    total_tests = len(checks)
    success_count = 0

    try:
        for check, success_msg, failure_msg in checks:
            if check():
                print(success_msg)
                success_count += 1
            else:
                print(failure_msg)
    except Exception as e:
        # Top-level boundary of the script: report the failure with its
        # traceback and signal failure to the caller.
        print(f"\n❌ Optimization tests failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("\n" + "=" * 50)
    print(f"🎯 Optimization Results: {success_count}/{total_tests} improvements successful")

    if success_count == total_tests:
        print("🎉 ALL OPTIMIZATIONS WORKING PERFECTLY!")
    elif success_count > total_tests // 2:
        print("✅ MAJORITY OF OPTIMIZATIONS SUCCESSFUL")
    else:
        print("⚠️ Mixed results - some optimizations need work")

    return success_count > 0
| |
|
| |
|
if __name__ == "__main__":
    # Propagate the suite outcome as the process exit status so shells and CI
    # can detect failure; previously main()'s return value was discarded and
    # the script always exited 0.
    sys.exit(0 if main() else 1)