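"""Test suite for the merge/unmerge behaviour of the omini rotation adapter.

Exercises omini.rotation.layer.Linear with three checks:
  1. Merging the adapter into the base nn.Linear leaves the forward output unchanged.
  2. Repeated merge/unmerge cycles do not drift away from the original output.
  3. The merge stays numerically close under float32, float16, and bfloat16.
"""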
import torch
import torch.nn as nn
from omini.rotation.layer import Linear
def test_rotation_merge():
"""
Test that merging rotation adapter produces the same output as the unmerged version.
"""
print("="*60)
print("Testing Rotation Layer Merge")
print("="*60)
# Set random seed for reproducibility
torch.manual_seed(42)
# Configuration
in_features = 512
out_features = 1024
r = 4
num_rotations = 4
T = 1.0
batch_size = 8
seq_len = 16
# Create base linear layer
base_layer = nn.Linear(in_features, out_features, bias=True)
# Create rotation layer
rotation_layer = Linear(
base_layer=base_layer,
adapter_name="default",
r=r,
T=T,
num_rotations=num_rotations
)
# Create random input
x = torch.randn(batch_size, seq_len, in_features)
# Test 1: Forward pass before merge
print("\n" + "-"*60)
print("Test 1: Computing output BEFORE merge")
print("-"*60)
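    # eval() guards against any train-mode behaviour in the adapter (e.g. dropout,
    # if present) introducing randomness between the pre- and post-merge passes.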
rotation_layer.eval()
with torch.no_grad():
output_before = rotation_layer(x)
print(f"Output shape: {output_before.shape}")
print(f"Output mean: {output_before.mean().item():.6f}")
print(f"Output std: {output_before.std().item():.6f}")
print(f"Output min: {output_before.min().item():.6f}")
print(f"Output max: {output_before.max().item():.6f}")
# Save original weight for verification
original_weight = base_layer.weight.data.clone()
# Test 2: Merge adapter
print("\n" + "-"*60)
print("Test 2: Merging adapter")
print("-"*60)
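    # merge() is expected to fold the adapter's contribution into base_layer.weight
    # in place, so later forward passes use only the plain linear path; safe_merge=True
    # should sanity-check the merged weights before committing them.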
rotation_layer.merge(safe_merge=True, adapter_names=["default"])
print(f"βœ“ Adapter merged successfully")
print(f"βœ“ Merged adapters: {rotation_layer.merged_adapters}")
# Check that weights have changed
weight_diff = (base_layer.weight.data - original_weight).abs().max().item()
print(f"Max weight change: {weight_diff:.6e}")
# Test 3: Forward pass after merge
print("\n" + "-"*60)
print("Test 3: Computing output AFTER merge")
print("-"*60)
with torch.no_grad():
output_after = rotation_layer(x)
print(f"Output shape: {output_after.shape}")
print(f"Output mean: {output_after.mean().item():.6f}")
print(f"Output std: {output_after.std().item():.6f}")
print(f"Output min: {output_after.min().item():.6f}")
print(f"Output max: {output_after.max().item():.6f}")
# Test 4: Compare outputs
print("\n" + "-"*60)
print("Test 4: Comparing outputs")
print("-"*60)
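    # If the merge is exact, the merged and unmerged paths compute the same function,
    # so any difference should be on the order of floating-point rounding error.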
# Compute differences
abs_diff = (output_after - output_before).abs()
rel_diff = abs_diff / (output_before.abs() + 1e-8)
max_abs_diff = abs_diff.max().item()
mean_abs_diff = abs_diff.mean().item()
max_rel_diff = rel_diff.max().item()
mean_rel_diff = rel_diff.mean().item()
print(f"Max absolute difference: {max_abs_diff:.6e}")
print(f"Mean absolute difference: {mean_abs_diff:.6e}")
print(f"Max relative difference: {max_rel_diff:.6e}")
print(f"Mean relative difference: {mean_rel_diff:.6e}")
# Check if outputs are close
atol = 1e-4 # Absolute tolerance
rtol = 1e-3 # Relative tolerance
are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)
if are_close:
        print(f"\nβœ… PASS: Outputs match within tolerance (atol={atol}, rtol={rtol})")
else:
print(f"\n❌ FAIL: Outputs differ significantly")
print(f" Expected: atol < {atol}, rtol < {rtol}")
print(f" Got: max_abs_diff = {max_abs_diff:.6e}, max_rel_diff = {max_rel_diff:.6e}")
# Test 5: Unmerge and verify
print("\n" + "-"*60)
print("Test 5: Testing unmerge")
print("-"*60)
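    # unmerge() should restore the pre-merge behaviour; the weight tensor itself is
    # compared against the saved copy in Test 6 below.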
rotation_layer.unmerge()
print(f"βœ“ Adapter unmerged")
print(f"βœ“ Merged adapters: {rotation_layer.merged_adapters}")
with torch.no_grad():
output_unmerged = rotation_layer(x)
unmerge_diff = (output_unmerged - output_before).abs().max().item()
print(f"Max difference after unmerge: {unmerge_diff:.6e}")
unmerge_close = torch.allclose(output_before, output_unmerged, atol=atol, rtol=rtol)
if unmerge_close:
print(f"βœ… PASS: Unmerge restored original behavior")
else:
print(f"❌ FAIL: Unmerge did not restore original behavior")
# Test 6: Verify weight restoration
weight_restored_diff = (base_layer.weight.data - original_weight).abs().max().item()
print(f"Max weight difference after unmerge: {weight_restored_diff:.6e}")
weight_restored = torch.allclose(base_layer.weight.data, original_weight, atol=1e-5)
if weight_restored:
print(f"βœ… PASS: Original weights restored")
else:
print(f"❌ FAIL: Original weights not fully restored")
print("\n" + "="*60)
print("Test Summary")
print("="*60)
return are_close and unmerge_close and weight_restored
def test_multiple_merges():
"""
Test merging and unmerging multiple times.
"""
print("\n" + "="*60)
print("Testing Multiple Merge/Unmerge Cycles")
print("="*60)
torch.manual_seed(42)
in_features = 256
out_features = 512
r = 4
num_rotations = 4
base_layer = nn.Linear(in_features, out_features, bias=True)
rotation_layer = Linear(
base_layer=base_layer,
adapter_name="default",
r=r,
T=1.0,
num_rotations=num_rotations
)
x = torch.randn(4, 8, in_features)
rotation_layer.eval()
# Get original output
with torch.no_grad():
original_output = rotation_layer(x)
# Test multiple cycles
all_passed = True
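    # Each cycle should reproduce the original output; repeated merge/unmerge must not
    # accumulate floating-point drift in the base weights.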
for cycle in range(3):
print(f"\nCycle {cycle + 1}:")
# Merge
rotation_layer.merge(safe_merge=True)
with torch.no_grad():
merged_output = rotation_layer(x)
merge_close = torch.allclose(original_output, merged_output, atol=1e-4, rtol=1e-3)
print(f" Merge: {'βœ… PASS' if merge_close else '❌ FAIL'}")
# Unmerge
rotation_layer.unmerge()
with torch.no_grad():
unmerged_output = rotation_layer(x)
unmerge_close = torch.allclose(original_output, unmerged_output, atol=1e-4, rtol=1e-3)
print(f" Unmerge: {'βœ… PASS' if unmerge_close else '❌ FAIL'}")
all_passed = all_passed and merge_close and unmerge_close
return all_passed
def test_with_different_dtypes():
"""
Test merging with different data types.
"""
print("\n" + "="*60)
print("Testing Different Data Types")
print("="*60)
torch.manual_seed(42)
dtypes = [torch.float32, torch.float16, torch.bfloat16]
all_passed = True
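    # Lower-precision dtypes need looser tolerances: both the merge arithmetic and the
    # forward pass accumulate more rounding error in float16/bfloat16.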
for dtype in dtypes:
print(f"\nTesting with dtype: {dtype}")
in_features = 256
out_features = 512
r = 4
num_rotations = 4
base_layer = nn.Linear(in_features, out_features, bias=True)
base_layer = base_layer.to(dtype)
rotation_layer = Linear(
base_layer=base_layer,
adapter_name="default",
r=r,
T=1.0,
num_rotations=num_rotations
)
rotation_layer = rotation_layer.to(dtype)
x = torch.randn(4, 8, in_features, dtype=dtype)
rotation_layer.eval()
with torch.no_grad():
output_before = rotation_layer(x)
rotation_layer.merge(safe_merge=True)
output_after = rotation_layer(x)
# Adjust tolerances based on dtype
if dtype == torch.float32:
atol, rtol = 1e-5, 1e-4
elif dtype == torch.float16:
atol, rtol = 1e-2, 1e-2
else: # bfloat16
atol, rtol = 1e-2, 1e-2
are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)
if are_close:
print(f" βœ… PASS")
else:
max_diff = (output_after - output_before).abs().max().item()
print(f" ❌ FAIL (max diff: {max_diff:.6e})")
all_passed = all_passed and are_close
return all_passed
if __name__ == "__main__":
print("\n" + "="*60)
print("ROTATION LAYER MERGE TEST SUITE")
print("="*60)
results = {}
# Run all tests
results["basic_merge"] = test_rotation_merge()
results["multiple_cycles"] = test_multiple_merges()
results["different_dtypes"] = test_with_different_dtypes()
# Print summary
print("\n" + "="*60)
print("FINAL SUMMARY")
print("="*60)
for test_name, passed in results.items():
status = "βœ… PASS" if passed else "❌ FAIL"
print(f"{test_name}: {status}")
all_passed = all(results.values())
print("\n" + "="*60)
if all_passed:
print("πŸŽ‰ ALL TESTS PASSED!")
else:
print("⚠️ SOME TESTS FAILED")
print("="*60)