| | |
| | """ |
| | Test script for BTLM_Extensions |
| | =============================== |
| | |
| | Quick test to verify all extensions are working properly. |
| | """ |
| |
|
| | import sys |
| | import os |
| | import torch |
| | import torch.nn as nn |
| |
|
| | |
# Make the extension package (under /data) and the BitTransformerLM checkout
# importable from their fixed locations; both are appended, in this order.
sys.path.extend(['/data', '/data/BitTransformerLM'])
| |
|
def test_imports():
    """Check that every public name this suite relies on can be imported.

    Imports ``Muon``, ``Lion``, ``Adafactor``, the three
    ``configure_*_optimizer`` helpers, ``RLEEncoder``, ``extension_manager``
    and ``get_package_info`` from ``BTLM_Extensions``.

    Returns:
        True if the import succeeded, False otherwise (the error is printed).
    """
    print("Testing imports...")

    try:
        # The names are imported only to prove they exist; they are unused here.
        from BTLM_Extensions import (  # noqa: F401
            Muon, Lion, Adafactor,
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer,
            RLEEncoder,
            extension_manager,
            get_package_info,
        )
    except Exception as e:
        # Broad on purpose: any error raised while importing the package
        # (not only ImportError) means the extensions are unusable.
        print(f"❌ Import failed: {e}")
        return False

    print("✅ All imports successful")
    return True
| |
|
def test_optimizers():
    """Run one training step with each BTLM_Extensions optimizer.

    For Muon, Lion and Adafactor, the matching ``configure_*_optimizer``
    helper is called on a fresh two-layer MLP with ``total_steps=100``. Then a
    single forward/backward pass on a random batch of 4 is followed by
    ``optimizer.step()``, plus ``scheduler.step()`` when a scheduler is
    returned.

    Returns:
        True only if the helpers imported and every optimizer completed its
        step; False otherwise. Each individual failure is printed.
    """
    print("\nTesting optimizers...")

    try:
        from BTLM_Extensions import (
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer,
        )
    except Exception as e:
        print(f"❌ Optimizer test failed: {e}")
        return False

    optimizers_to_test = [
        ("muon", configure_muon_optimizer, {"lr": 1e-3}),
        ("lion", configure_lion_optimizer, {"lr": 1e-4}),
        ("adafactor", configure_adafactor_optimizer, {"lr": 1e-3}),
    ]

    all_ok = True
    for name, config_fn, kwargs in optimizers_to_test:
        # Build a fresh model per optimizer. If one run fails part-way through
        # (e.g. after backward() but before zero_grad()), it must not leave
        # stale gradients or modified weights behind for the next optimizer.
        model = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 2),
        )
        try:
            optimizer, scheduler = config_fn(model, total_steps=100, **kwargs)

            x = torch.randn(4, 10)
            y = torch.randint(0, 2, (4,))

            pred = model(x)
            loss = nn.functional.cross_entropy(pred, y)
            loss.backward()

            optimizer.step()
            if scheduler:
                scheduler.step()
            optimizer.zero_grad()
        except Exception as e:
            print(f"❌ {name.capitalize()} optimizer failed: {e}")
            # Previously the failure was only printed and the test still passed.
            all_ok = False
        else:
            print(f"✅ {name.capitalize()} optimizer working")

    return all_ok
| |
|
def test_rle_compression():
    """Round-trip a bit tensor through each RLE scheme and run the benchmark.

    The test data is 50 random bits drawn from a fixed seed, with two forced
    runs (``[10:20] = 1`` and ``[30:40] = 0``). For each of the ``"basic"``,
    ``"delta"`` and ``"adaptive"`` schemes, the data is encoded and decoded
    with ``RLEEncoder``. A reconstruction is accepted when its mean squared
    error against the input is below 1e-6. ``benchmark_compression_schemes``
    is then run on the same data.

    Returns:
        True only if the import succeeded, every scheme round-tripped and the
        benchmark completed; False otherwise. Each individual failure is
        printed.
    """
    print("\nTesting RLE compression...")

    try:
        from BTLM_Extensions import RLEEncoder, benchmark_compression_schemes
    except Exception as e:
        print(f"❌ RLE compression test failed: {e}")
        return False

    # A seeded local generator makes failures reproducible and leaves the
    # global RNG state untouched.
    generator = torch.Generator().manual_seed(0)
    test_data = torch.randint(0, 2, (50,), generator=generator)
    # Force long runs so the encoders have something to compress.
    test_data[10:20] = 1
    test_data[30:40] = 0

    all_ok = True
    for scheme in ["basic", "delta", "adaptive"]:
        try:
            encoder = RLEEncoder(scheme=scheme)
            compressed, metadata = encoder.encode(test_data)
            reconstructed = encoder.decode(compressed, metadata)

            # Values are 0/1. If decode returns the input's shape, a single
            # wrong bit already gives an MSE of 1/50.
            error = torch.mean((test_data.float() - reconstructed.float()) ** 2)

            if error.item() < 1e-6:
                print(f"✅ RLE {scheme} scheme working (ratio: {metadata['compression_ratio']:.3f})")
            else:
                print(f"❌ RLE {scheme} scheme reconstruction error: {error.item()}")
                all_ok = False
        except Exception as e:
            print(f"❌ RLE {scheme} scheme failed: {e}")
            all_ok = False

    try:
        results = benchmark_compression_schemes(test_data)
        print(f"✅ RLE benchmark completed ({len(results)} schemes tested)")
    except Exception as e:
        print(f"❌ RLE benchmark failed: {e}")
        all_ok = False

    return all_ok
| |
|
def test_integration():
    """Check package metadata and the extension manager's registries.

    Prints ``name`` and ``version`` from ``get_package_info()``. Also prints
    the sizes of ``extension_manager.SUPPORTED_OPTIMIZERS`` and
    ``extension_manager.SUPPORTED_COMPRESSION``.

    Returns:
        True if all of the above could be read, False otherwise (the error is
        printed).
    """
    print("\nTesting integration features...")

    try:
        from BTLM_Extensions import extension_manager, get_package_info

        info = get_package_info()
        print(f"✅ Package info: {info['name']} v{info['version']}")

        optimizers = extension_manager.SUPPORTED_OPTIMIZERS
        compression = extension_manager.SUPPORTED_COMPRESSION
        print(f"✅ Extension manager: {len(optimizers)} optimizers, {len(compression)} compression schemes")

        return True

    except Exception as e:
        print(f"❌ Integration test failed: {e}")
        return False
| |
|
def test_bittransformerlm_integration():
    """Train BitTransformerLM for one step with the Muon optimizer.

    The test is skipped, and counted as passing, when ``bit_transformer``
    cannot be imported. Otherwise it builds a small model and configures
    Muon through ``BTLM_Extensions.configure_optimizer``. It then runs a
    next-bit-prediction forward/backward pass on random bits and takes one
    optimizer step (plus a scheduler step if one is returned).

    Returns:
        True on success or when BitTransformerLM is unavailable; False if
        anything else fails (the error is printed).
    """
    print("\nTesting BitTransformerLM integration...")

    # Only a missing BitTransformerLM justifies a skip. This import has its own
    # try so that an ImportError from BTLM_Extensions (e.g. a missing
    # configure_optimizer) is reported as a failure instead of a silent pass.
    try:
        from bit_transformer import BitTransformerLM
    except ImportError:
        print("⚠️ BitTransformerLM not available, skipping integration test")
        return True

    try:
        from BTLM_Extensions import configure_optimizer

        model = BitTransformerLM(
            d_model=64,
            nhead=4,
            num_layers=2,
            dim_feedforward=128,
            max_seq_len=32,
        )

        optimizer, scheduler = configure_optimizer("muon", model, lr=1e-3, total_steps=10)

        test_bits = torch.randint(0, 2, (2, 16))
        # The second returned value (telemetry) is not checked by this test.
        logits, _telemetry = model(test_bits)

        # Next-bit prediction: the logits at position t are scored against the
        # bit at t + 1. This assumes logits are (batch, seq, 2) — TODO confirm
        # against BitTransformerLM.
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = test_bits[:, 1:].reshape(-1)
        loss = nn.functional.cross_entropy(pred, target)

        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        print(f"✅ BitTransformerLM integration working (loss: {loss.item():.4f})")
        return True

    except Exception as e:
        print(f"❌ BitTransformerLM integration failed: {e}")
        return False
| |
|
def main():
    """Run every test in order and print a pass/fail summary.

    A test counts as passed when it returns a truthy value. An exception that
    escapes a test is printed as a crash and counted as a failure, so one
    broken test does not stop the rest.

    Returns:
        0 if all tests passed, 1 otherwise (used as the process exit code).
    """
    print("BTLM_Extensions Test Suite")
    print("=" * 40)

    tests = [
        test_imports,
        test_optimizers,
        test_rle_compression,
        test_integration,
        test_bittransformerlm_integration,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            print(f"❌ Test {test.__name__} crashed: {e}")

    print("\n" + "=" * 40)
    print(f"Test Results: {passed}/{total} passed")

    if passed == total:
        print("🎉 All tests passed! Extensions are working correctly.")
        return 0
    else:
        print("⚠️ Some tests failed. Check the output above.")
        return 1
| |
|
# Script entry point: the suite's summary result becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())