| | """
|
| | Unit tests for BitLinear and MultiTernaryLinear layers.
|
| |
|
| | These tests are here to validate the nn.Module implementations and their compatibility with standard PyTorch workflows. Here are the following test cases:
|
| |
|
| | TestBitLinear (8 tests)
|
| | 1. test_initialization - Verifies layer initializes with correct shapes
|
| | 2. test_no_bias_initialization - Tests initialization without bias parameter
|
| | 3. test_forward_shape - Validates output shape correctness
|
| | 4. test_compatibility_with_nn_linear - Tests interface compatibility with nn.Linear
|
| | 5. test_from_linear_conversion - Verifies conversion from nn.Linear to BitLinear
|
| | 6. test_parameter_count - Validates parameter count calculation
|
| | 7. test_weight_values_are_ternary - Ensures weights are in {-1, 0, +1}
|
| | 8. test_gradient_flow - Tests gradient flow for QAT support
|
| |
|
| | TestMultiTernaryLinear (5 tests)
|
| | 1. test_initialization - Verifies k-component initialization
|
| | 2. test_forward_shape - Tests forward pass output shape
|
| | 3. test_k_components - Validates k-component tensor shapes
|
| | 4. test_from_linear_conversion - Tests conversion with k parameter
|
| | 5. test_better_approximation_with_more_k - Validates error decreases with larger k
|
| |
|
| | TestConversionUtilities (3 tests)
|
| | 1. test_convert_simple_model - Tests conversion of Sequential models
|
| | 2. test_convert_nested_model - Tests conversion of nested module hierarchies
|
| | 3. test_inplace_conversion - Tests in-place vs. copy conversion modes
|
| |
|
| | TestLayerIntegration (3 tests)
|
| | 1. test_in_transformer_block - Tests BitLinear in Transformer FFN block
|
| | 2. test_training_step - Validates full training loop compatibility
|
| | 3. test_save_and_load - Tests model serialization and deserialization
|
| |
|
| | TestPerformanceComparison (2 tests - skipped)
|
| | 1. test_memory_usage - Performance benchmark (run manually)
|
| | 2. test_inference_speed - Performance benchmark (run manually)
|
| | """
|
| |
|
import pytest
import torch
import torch.nn as nn

from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
|
| |
|
| |
|
class TestBitLinear:
    """Unit tests for the BitLinear layer."""

    def test_initialization(self):
        """A freshly constructed layer exposes the expected attributes."""
        bl = BitLinear(512, 1024)
        assert bl.bias is not None
        assert (bl.in_features, bl.out_features) == (512, 1024)
        assert tuple(bl.W_ternary.shape) == (1024, 512)
        assert tuple(bl.gamma.shape) == (1024,)

    def test_no_bias_initialization(self):
        """Constructing with bias=False leaves the bias unset."""
        assert BitLinear(512, 1024, bias=False).bias is None

    def test_forward_shape(self):
        """The forward pass maps (..., in_features) to (..., out_features)."""
        inputs = torch.randn(32, 128, 512)
        result = BitLinear(512, 1024)(inputs)
        assert result.shape == (32, 128, 1024)

    def test_compatibility_with_nn_linear(self):
        """BitLinear is shape-compatible as a drop-in for nn.Linear."""
        x = torch.randn(32, 512)
        dense_out = nn.Linear(512, 512)(x)
        ternary_out = BitLinear(512, 512)(x)
        assert dense_out.shape == ternary_out.shape

    def test_from_linear_conversion(self):
        """from_linear preserves dimensions and yields a usable layer."""
        converted = BitLinear.from_linear(nn.Linear(512, 1024))
        assert converted.in_features == 512
        assert converted.out_features == 1024

        # The converted layer must still run a forward pass.
        assert converted(torch.randn(16, 512)).shape == (16, 1024)

    def test_parameter_count(self):
        """Parameter total is W_ternary (512*512) + gamma (512) + bias (512)."""
        bl = BitLinear(512, 512, bias=True)
        total = sum(p.numel() for p in bl.parameters())
        assert total == 512 * 512 + 512 + 512

    def test_weight_values_are_ternary(self):
        """Every stored weight is one of {-1, 0, +1}."""
        distinct = torch.unique(BitLinear(512, 512).W_ternary).tolist()
        assert set(distinct) <= {-1.0, 0.0, 1.0}

    def test_gradient_flow(self):
        """Backprop reaches the input and both learnable tensors (QAT)."""
        bl = BitLinear(256, 128)
        x = torch.randn(8, 256, requires_grad=True)
        bl(x).sum().backward()

        assert x.grad is not None
        assert bl.W_ternary.grad is not None
        assert bl.gamma.grad is not None
|
| |
|
| |
|
class TestMultiTernaryLinear:
    """Tests for MultiTernaryLinear layer."""

    def test_initialization(self):
        """Test layer initialization with k components."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        assert layer.in_features == 512
        assert layer.out_features == 1024
        assert layer.k == 4
        # One ternary weight matrix and one scale vector per component.
        assert layer.W_ternary.shape == (4, 1024, 512)
        assert layer.gammas.shape == (4, 1024)

    def test_forward_shape(self):
        """Test forward pass shape."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        x = torch.randn(32, 128, 512)
        output = layer(x)
        assert output.shape == (32, 128, 1024)

    def test_k_components(self):
        """Test that layer uses k ternary components."""
        layer = MultiTernaryLinear(512, 512, k=3)
        assert layer.W_ternary.shape == (3, 512, 512)
        assert layer.gammas.shape == (3, 512)

    def test_from_linear_conversion(self):
        """Test converting nn.Linear to MultiTernaryLinear."""
        linear = nn.Linear(512, 1024)
        multi_ternary = MultiTernaryLinear.from_linear(linear, k=4)
        assert multi_ternary.k == 4
        assert multi_ternary.in_features == 512
        assert multi_ternary.out_features == 1024

    def test_better_approximation_with_more_k(self):
        """Test that larger k provides better approximation of dense layer.

        The RNG is seeded so the strict-inequality chain below is
        deterministic: with unseeded weights and inputs, the errors for
        adjacent k values can occasionally tie or invert, flaking the test.
        """
        torch.manual_seed(0)
        linear = nn.Linear(512, 512)
        x = torch.randn(16, 512)
        out_dense = linear(x)

        errors = []
        for k in [1, 2, 4]:
            multi_ternary = MultiTernaryLinear.from_linear(linear, k=k)
            out_ternary = multi_ternary(x)
            # Store a plain float so the comparison below is scalar.
            errors.append(torch.norm(out_dense - out_ternary).item())

        # Approximation error must strictly decrease as k grows.
        assert errors[0] > errors[1] > errors[2]
|
| |
|
| |
|
class TestConversionUtilities:
    """Tests for model conversion utilities."""

    def test_convert_simple_model(self):
        """Conversion swaps Linear layers in a Sequential, leaving others alone."""
        original = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
        )
        converted = convert_linear_to_bitlinear(original, inplace=False)

        assert isinstance(converted[0], BitLinear)
        assert isinstance(converted[1], nn.ReLU)
        assert isinstance(converted[2], BitLinear)

    def test_convert_nested_model(self):
        """Conversion recurses into nested submodule hierarchies."""
        class NestedModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = nn.Linear(256, 512)
                self.submodule = nn.Sequential(
                    nn.Linear(512, 512),
                    nn.ReLU(),
                )
                self.layer2 = nn.Linear(512, 128)

        converted = convert_linear_to_bitlinear(NestedModel(), inplace=False)

        # Every Linear, at any nesting depth, must have been replaced.
        for module in (converted.layer1, converted.submodule[0], converted.layer2):
            assert isinstance(module, BitLinear)

    def test_inplace_conversion(self):
        """inplace=False returns a new model; inplace=True mutates the original."""
        source = nn.Sequential(nn.Linear(256, 256))
        copied = convert_linear_to_bitlinear(source, inplace=False)
        assert copied is not source
        assert isinstance(source[0], nn.Linear)
        assert isinstance(copied[0], BitLinear)

        mutated = nn.Sequential(nn.Linear(256, 256))
        returned = convert_linear_to_bitlinear(mutated, inplace=True)
        assert returned is mutated
        assert isinstance(mutated[0], BitLinear)
|
| |
|
| |
|
class TestLayerIntegration:
    """Integration tests for layers in realistic scenarios."""

    def test_in_transformer_block(self):
        """Test BitLinear in a Transformer feed-forward (FFN) block."""

        class TransformerFFN(nn.Module):
            def __init__(self, d_model=256, d_ff=1024):
                super().__init__()
                self.fc1 = BitLinear(d_model, d_ff)
                self.relu = nn.ReLU()
                self.fc2 = BitLinear(d_ff, d_model)
                self.dropout = nn.Dropout(0.1)

            def forward(self, x):
                return self.dropout(self.fc2(self.relu(self.fc1(x))))

        model = TransformerFFN()

        batch_size, seq_len, d_model = 8, 32, 256
        x = torch.randn(batch_size, seq_len, d_model)
        output = model(x)

        # The FFN preserves the (batch, seq, d_model) shape.
        assert output.shape == (batch_size, seq_len, d_model)

        # Both projections must keep ternary weights.
        assert set(model.fc1.W_ternary.unique().tolist()).issubset({-1.0, 0.0, 1.0})
        assert set(model.fc2.W_ternary.unique().tolist()).issubset({-1.0, 0.0, 1.0})

    def test_training_step(self):
        """Test that layers work in a training loop."""
        model = nn.Sequential(
            BitLinear(128, 256),
            nn.ReLU(),
            BitLinear(256, 10),
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        x = torch.randn(16, 128)
        output = model(x)

        target = torch.randint(0, 10, (16,))
        loss = nn.functional.cross_entropy(output, target)

        optimizer.zero_grad()
        loss.backward()

        # Gradients must reach the quantized parameters (QAT support).
        assert model[0].W_ternary.grad is not None
        assert model[0].gamma.grad is not None

        optimizer.step()

        assert torch.isfinite(loss)

    def test_save_and_load(self):
        """Test saving and loading models with BitLinear layers."""
        import tempfile
        import os

        model = nn.Sequential(
            BitLinear(128, 256),
            nn.ReLU(),
            BitLinear(256, 64),
        )

        # Grab a temp path, then save AFTER the handle is closed: writing to
        # a still-open NamedTemporaryFile fails on Windows (file locked).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as f:
            temp_path = f.name
        torch.save(model.state_dict(), temp_path)

        try:
            model_loaded = nn.Sequential(
                BitLinear(128, 256),
                nn.ReLU(),
                BitLinear(256, 64),
            )
            model_loaded.load_state_dict(torch.load(temp_path))

            # Round-tripped parameters must match the originals exactly.
            assert torch.allclose(model[0].W_ternary, model_loaded[0].W_ternary)
            assert torch.allclose(model[0].gamma, model_loaded[0].gamma)
            assert torch.allclose(model[2].W_ternary, model_loaded[2].W_ternary)
            assert torch.allclose(model[2].gamma, model_loaded[2].gamma)

            # And both models must produce identical outputs.
            x = torch.randn(8, 128)
            with torch.no_grad():
                out1 = model(x)
                out2 = model_loaded(x)
            assert torch.allclose(out1, out2)
        finally:
            os.unlink(temp_path)
|
| |
|
| |
|
| |
|
class TestPerformanceComparison:
    """Tests comparing BitLinear to standard nn.Linear."""

    @pytest.mark.skip("Performance test - run manually")
    def test_memory_usage(self):
        """Compare memory usage of BitLinear vs. nn.Linear."""
        # Benchmark placeholder; intentionally skipped in CI.
        pass

    @pytest.mark.skip("Performance test - run manually")
    def test_inference_speed(self):
        """Compare inference speed (when CUDA kernels are implemented)."""
        # Benchmark placeholder; intentionally skipped in CI.
        pass
|
| |
|