import torch
import torch.nn as nn

from omini.rotation.layer import Linear, Rotation


def test_rotation_merge():
    """
    Test that merging the rotation adapter produces the same output as the
    unmerged version.
    """
    print("="*60)
    print("Testing Rotation Layer Merge")
    print("="*60)

    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Configuration
    in_features = 512
    out_features = 1024
    r = 4
    num_rotations = 4
    T = 1.0
    batch_size = 8
    seq_len = 16

    # Create base linear layer
    base_layer = nn.Linear(in_features, out_features, bias=True)

    # Create rotation layer
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=T,
        num_rotations=num_rotations,
    )

    # Create random input
    x = torch.randn(batch_size, seq_len, in_features)

    # Test 1: Forward pass before merge
    print("\n" + "-"*60)
    print("Test 1: Computing output BEFORE merge")
    print("-"*60)

    rotation_layer.eval()
    with torch.no_grad():
        output_before = rotation_layer(x)

    print(f"Output shape: {output_before.shape}")
    print(f"Output mean: {output_before.mean().item():.6f}")
    print(f"Output std: {output_before.std().item():.6f}")
    print(f"Output min: {output_before.min().item():.6f}")
    print(f"Output max: {output_before.max().item():.6f}")

    # Save original weight for verification
    original_weight = base_layer.weight.data.clone()

    # Test 2: Merge adapter
    print("\n" + "-"*60)
    print("Test 2: Merging adapter")
    print("-"*60)

    rotation_layer.merge(safe_merge=True, adapter_names=["default"])
    print("✓ Adapter merged successfully")
    print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")

    # Check that weights have changed
    weight_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight change: {weight_diff:.6e}")

    # Test 3: Forward pass after merge
    print("\n" + "-"*60)
    print("Test 3: Computing output AFTER merge")
    print("-"*60)

    with torch.no_grad():
        output_after = rotation_layer(x)

    print(f"Output shape: {output_after.shape}")
    print(f"Output mean: {output_after.mean().item():.6f}")
    print(f"Output std: {output_after.std().item():.6f}")
    print(f"Output min: {output_after.min().item():.6f}")
    print(f"Output max: {output_after.max().item():.6f}")

    # Test 4: Compare outputs
    print("\n" + "-"*60)
    print("Test 4: Comparing outputs")
    print("-"*60)

    # Compute differences
    abs_diff = (output_after - output_before).abs()
    rel_diff = abs_diff / (output_before.abs() + 1e-8)

    max_abs_diff = abs_diff.max().item()
    mean_abs_diff = abs_diff.mean().item()
    max_rel_diff = rel_diff.max().item()
    mean_rel_diff = rel_diff.mean().item()

    print(f"Max absolute difference: {max_abs_diff:.6e}")
    print(f"Mean absolute difference: {mean_abs_diff:.6e}")
    print(f"Max relative difference: {max_rel_diff:.6e}")
    print(f"Mean relative difference: {mean_rel_diff:.6e}")

    # Check if outputs are close
    atol = 1e-4  # Absolute tolerance
    rtol = 1e-3  # Relative tolerance
    are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)

    if are_close:
        print(f"\n✅ PASS: Outputs are identical (within atol={atol}, rtol={rtol})")
    else:
        print("\n❌ FAIL: Outputs differ significantly")
        print(f"   Expected: atol < {atol}, rtol < {rtol}")
        print(f"   Got: max_abs_diff = {max_abs_diff:.6e}, max_rel_diff = {max_rel_diff:.6e}")

    # Test 5: Unmerge and verify
    print("\n" + "-"*60)
    print("Test 5: Testing unmerge")
    print("-"*60)

    rotation_layer.unmerge()
    print("✓ Adapter unmerged")
    print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")

    with torch.no_grad():
        output_unmerged = rotation_layer(x)

    unmerge_diff = (output_unmerged - output_before).abs().max().item()
    print(f"Max difference after unmerge: {unmerge_diff:.6e}")
    unmerge_close = torch.allclose(output_before, output_unmerged, atol=atol, rtol=rtol)
    if unmerge_close:
        print("✅ PASS: Unmerge restored original behavior")
    else:
        print("❌ FAIL: Unmerge did not restore original behavior")

    # Test 6: Verify weight restoration
    weight_restored_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight difference after unmerge: {weight_restored_diff:.6e}")

    weight_restored = torch.allclose(base_layer.weight.data, original_weight, atol=1e-5)
    if weight_restored:
        print("✅ PASS: Original weights restored")
    else:
        print("❌ FAIL: Original weights not fully restored")

    print("\n" + "="*60)
    print("Test Summary")
    print("="*60)

    return are_close and unmerge_close and weight_restored


def test_multiple_merges():
    """
    Test merging and unmerging multiple times.
    """
    print("\n" + "="*60)
    print("Testing Multiple Merge/Unmerge Cycles")
    print("="*60)

    torch.manual_seed(42)

    in_features = 256
    out_features = 512
    r = 4
    num_rotations = 4

    base_layer = nn.Linear(in_features, out_features, bias=True)
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=1.0,
        num_rotations=num_rotations,
    )

    x = torch.randn(4, 8, in_features)
    rotation_layer.eval()

    # Get original output
    with torch.no_grad():
        original_output = rotation_layer(x)

    # Test multiple cycles
    all_passed = True
    for cycle in range(3):
        print(f"\nCycle {cycle + 1}:")

        # Merge
        rotation_layer.merge(safe_merge=True)
        with torch.no_grad():
            merged_output = rotation_layer(x)
        merge_close = torch.allclose(original_output, merged_output, atol=1e-4, rtol=1e-3)
        print(f"  Merge: {'✅ PASS' if merge_close else '❌ FAIL'}")

        # Unmerge
        rotation_layer.unmerge()
        with torch.no_grad():
            unmerged_output = rotation_layer(x)
        unmerge_close = torch.allclose(original_output, unmerged_output, atol=1e-4, rtol=1e-3)
        print(f"  Unmerge: {'✅ PASS' if unmerge_close else '❌ FAIL'}")

        all_passed = all_passed and merge_close and unmerge_close

    return all_passed


def test_with_different_dtypes():
    """
    Test merging with different data types.
    """
""" print("\n" + "="*60) print("Testing Different Data Types") print("="*60) torch.manual_seed(42) dtypes = [torch.float32, torch.float16, torch.bfloat16] all_passed = True for dtype in dtypes: print(f"\nTesting with dtype: {dtype}") in_features = 256 out_features = 512 r = 4 num_rotations = 4 base_layer = nn.Linear(in_features, out_features, bias=True) base_layer = base_layer.to(dtype) rotation_layer = Linear( base_layer=base_layer, adapter_name="default", r=r, T=1.0, num_rotations=num_rotations ) rotation_layer = rotation_layer.to(dtype) x = torch.randn(4, 8, in_features, dtype=dtype) rotation_layer.eval() with torch.no_grad(): output_before = rotation_layer(x) rotation_layer.merge(safe_merge=True) output_after = rotation_layer(x) # Adjust tolerances based on dtype if dtype == torch.float32: atol, rtol = 1e-5, 1e-4 elif dtype == torch.float16: atol, rtol = 1e-2, 1e-2 else: # bfloat16 atol, rtol = 1e-2, 1e-2 are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol) if are_close: print(f" ✅ PASS") else: max_diff = (output_after - output_before).abs().max().item() print(f" ❌ FAIL (max diff: {max_diff:.6e})") all_passed = all_passed and are_close return all_passed if __name__ == "__main__": print("\n" + "="*60) print("ROTATION LAYER MERGE TEST SUITE") print("="*60) results = {} # Run all tests results["basic_merge"] = test_rotation_merge() results["multiple_cycles"] = test_multiple_merges() results["different_dtypes"] = test_with_different_dtypes() # Print summary print("\n" + "="*60) print("FINAL SUMMARY") print("="*60) for test_name, passed in results.items(): status = "✅ PASS" if passed else "❌ FAIL" print(f"{test_name}: {status}") all_passed = all(results.values()) print("\n" + "="*60) if all_passed: print("🎉 ALL TESTS PASSED!") else: print("⚠️ SOME TESTS FAILED") print("="*60)