import torch
import torch.nn as nn

from omini.rotation.layer import Linear, Rotation


def test_rotation_merge():
    """
    Test that merging the rotation adapter produces the same output as the unmerged version.
    """
    print("="*60)
    print("Testing Rotation Layer Merge")
    print("="*60)
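
    # Fixed seed so the adapter initialization and the test inputs are reproducible.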
    torch.manual_seed(42)

    in_features = 512
    out_features = 1024
    r = 4
    num_rotations = 4
    T = 1.0
    batch_size = 8
    seq_len = 16
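
    # Wrap a plain nn.Linear in the rotation adapter layer under test.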
    base_layer = nn.Linear(in_features, out_features, bias=True)
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=T,
        num_rotations=num_rotations,
    )
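
    # A fixed random input reused for every comparison below.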
    x = torch.randn(batch_size, seq_len, in_features)

    print("\n" + "-"*60)
    print("Test 1: Computing output BEFORE merge")
    print("-"*60)
    rotation_layer.eval()
    with torch.no_grad():
        output_before = rotation_layer(x)

    print(f"Output shape: {output_before.shape}")
    print(f"Output mean: {output_before.mean().item():.6f}")
    print(f"Output std: {output_before.std().item():.6f}")
    print(f"Output min: {output_before.min().item():.6f}")
    print(f"Output max: {output_before.max().item():.6f}")
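
    # Snapshot the base weights so we can verify they are restored after unmerge.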
    original_weight = base_layer.weight.data.clone()

    print("\n" + "-"*60)
    print("Test 2: Merging adapter")
    print("-"*60)
    rotation_layer.merge(safe_merge=True, adapter_names=["default"])
    print("✓ Adapter merged successfully")
    print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")

    weight_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight change: {weight_diff:.6e}")

    print("\n" + "-"*60)
    print("Test 3: Computing output AFTER merge")
    print("-"*60)
    with torch.no_grad():
        output_after = rotation_layer(x)

    print(f"Output shape: {output_after.shape}")
    print(f"Output mean: {output_after.mean().item():.6f}")
    print(f"Output std: {output_after.std().item():.6f}")
    print(f"Output min: {output_after.min().item():.6f}")
    print(f"Output max: {output_after.max().item():.6f}")

    print("\n" + "-"*60)
    print("Test 4: Comparing outputs")
    print("-"*60)
    abs_diff = (output_after - output_before).abs()
    rel_diff = abs_diff / (output_before.abs() + 1e-8)

    max_abs_diff = abs_diff.max().item()
    mean_abs_diff = abs_diff.mean().item()
    max_rel_diff = rel_diff.max().item()
    mean_rel_diff = rel_diff.mean().item()

    print(f"Max absolute difference: {max_abs_diff:.6e}")
    print(f"Mean absolute difference: {mean_abs_diff:.6e}")
    print(f"Max relative difference: {max_rel_diff:.6e}")
    print(f"Mean relative difference: {mean_rel_diff:.6e}")
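
    # Folding the adapter into the base weight reorders floating point operations,
    # so allow a small numerical tolerance rather than requiring exact equality.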
    atol = 1e-4
    rtol = 1e-3
    are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)

    if are_close:
        print(f"\n✅ PASS: Outputs match (within atol={atol}, rtol={rtol})")
    else:
        print("\n❌ FAIL: Outputs differ significantly")
        print(f"    Expected: atol < {atol}, rtol < {rtol}")
        print(f"    Got: max_abs_diff = {max_abs_diff:.6e}, max_rel_diff = {max_rel_diff:.6e}")

    print("\n" + "-"*60)
    print("Test 5: Testing unmerge")
    print("-"*60)
    rotation_layer.unmerge()
    print("✓ Adapter unmerged")
    print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")

    with torch.no_grad():
        output_unmerged = rotation_layer(x)

    unmerge_diff = (output_unmerged - output_before).abs().max().item()
    print(f"Max difference after unmerge: {unmerge_diff:.6e}")

    unmerge_close = torch.allclose(output_before, output_unmerged, atol=atol, rtol=rtol)
    if unmerge_close:
        print("✅ PASS: Unmerge restored original behavior")
    else:
        print("❌ FAIL: Unmerge did not restore original behavior")
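
    # unmerge() should restore the original base weights as well, not just the outputs.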
    weight_restored_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight difference after unmerge: {weight_restored_diff:.6e}")

    weight_restored = torch.allclose(base_layer.weight.data, original_weight, atol=1e-5)
    if weight_restored:
        print("✅ PASS: Original weights restored")
    else:
        print("❌ FAIL: Original weights not fully restored")

    print("\n" + "="*60)
    print("Test Summary")
    print("="*60)
    return are_close and unmerge_close and weight_restored


def test_multiple_merges():
    """
    Test merging and unmerging multiple times.
    """
    print("\n" + "="*60)
    print("Testing Multiple Merge/Unmerge Cycles")
    print("="*60)

    torch.manual_seed(42)

    in_features = 256
    out_features = 512
    r = 4
    num_rotations = 4

    base_layer = nn.Linear(in_features, out_features, bias=True)
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=1.0,
        num_rotations=num_rotations,
    )

    x = torch.randn(4, 8, in_features)
    rotation_layer.eval()

    with torch.no_grad():
        original_output = rotation_layer(x)
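
    # Repeated merge/unmerge cycles should keep reproducing the original output,
    # i.e. no numerical drift may accumulate in the base weights.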
    all_passed = True
    for cycle in range(3):
        print(f"\nCycle {cycle + 1}:")

        rotation_layer.merge(safe_merge=True)
        with torch.no_grad():
            merged_output = rotation_layer(x)

        merge_close = torch.allclose(original_output, merged_output, atol=1e-4, rtol=1e-3)
        print(f"  Merge: {'✅ PASS' if merge_close else '❌ FAIL'}")

        rotation_layer.unmerge()
        with torch.no_grad():
            unmerged_output = rotation_layer(x)

        unmerge_close = torch.allclose(original_output, unmerged_output, atol=1e-4, rtol=1e-3)
        print(f"  Unmerge: {'✅ PASS' if unmerge_close else '❌ FAIL'}")

        all_passed = all_passed and merge_close and unmerge_close

    return all_passed


def test_with_different_dtypes():
    """
    Test merging with different data types.
    """
    print("\n" + "="*60)
    print("Testing Different Data Types")
    print("="*60)

    torch.manual_seed(42)

    dtypes = [torch.float32, torch.float16, torch.bfloat16]
    all_passed = True

    for dtype in dtypes:
        print(f"\nTesting with dtype: {dtype}")

        in_features = 256
        out_features = 512
        r = 4
        num_rotations = 4

        base_layer = nn.Linear(in_features, out_features, bias=True)
        base_layer = base_layer.to(dtype)

        rotation_layer = Linear(
            base_layer=base_layer,
            adapter_name="default",
            r=r,
            T=1.0,
            num_rotations=num_rotations,
        )
        rotation_layer = rotation_layer.to(dtype)

        x = torch.randn(4, 8, in_features, dtype=dtype)
        rotation_layer.eval()

        with torch.no_grad():
            output_before = rotation_layer(x)
            rotation_layer.merge(safe_merge=True)
            output_after = rotation_layer(x)
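
        # Half-precision types need looser tolerances than float32.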
        if dtype == torch.float32:
            atol, rtol = 1e-5, 1e-4
        elif dtype == torch.float16:
            atol, rtol = 1e-2, 1e-2
        else:
            atol, rtol = 1e-2, 1e-2

        are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)

        if are_close:
            print("  ✅ PASS")
        else:
            max_diff = (output_after - output_before).abs().max().item()
            print(f"  ❌ FAIL (max diff: {max_diff:.6e})")

        all_passed = all_passed and are_close

    return all_passed


if __name__ == "__main__":
    print("\n" + "="*60)
    print("ROTATION LAYER MERGE TEST SUITE")
    print("="*60)

    results = {}
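
    # Run each test group and record its pass/fail result for the final summary.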
results["basic_merge"] = test_rotation_merge() |
|
|
results["multiple_cycles"] = test_multiple_merges() |
|
|
results["different_dtypes"] = test_with_different_dtypes() |
|
|
|
|
|
|
|
|
print("\n" + "="*60) |
|
|
print("FINAL SUMMARY") |
|
|
print("="*60) |
|
|
|
|
|
for test_name, passed in results.items(): |
|
|
status = "β
PASS" if passed else "β FAIL" |
|
|
print(f"{test_name}: {status}") |
|
|
|
|
|
all_passed = all(results.values()) |
|
|
print("\n" + "="*60) |
|
|
if all_passed: |
|
|
print("π ALL TESTS PASSED!") |
|
|
else: |
|
|
print("β οΈ SOME TESTS FAILED") |
|
|
print("="*60) |