| | """ |
| | RLE Compression Extension for BitTransformerLM |
| | ============================================== |
| | |
| | Advanced Run-Length Encoding compression module with multiple encoding schemes, |
| | adaptive compression, and training integration for BitTransformerLM. |
| | |
| | Key features: |
| | - Multiple RLE encoding schemes (basic, delta, hierarchical) |
| | - Adaptive compression with quality thresholds |
| | - Training integration with compression-aware loss |
| | - Batch processing and vectorized operations |
| | - Compatible with BitTransformerLM's training infrastructure |
| | """ |
| |
|
import math
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
| |
|
| |
|
class RLEEncoder:
    """
    Advanced Run-Length Encoder with multiple encoding schemes.

    Supports:
    - Basic RLE: (value, count) pairs
    - Delta RLE: differences between consecutive values, then basic RLE
    - Hierarchical RLE: basic RLE applied recursively
    - Adaptive RLE: picks whichever of the above yields the smallest output

    Stream-format notes (bug fixes vs. the previous implementation):
    - ``encode_basic_rle`` now always emits (value, count) pairs. The old
      format mixed bare literals with pairs and the decoder had to guess
      (``count > 1`` meant "run"), which silently corrupted any stream whose
      values could exceed 1 — in particular the delta scheme (shifted deltas
      near 128) and every hierarchical level after the first.
    - ``encode_hierarchical_rle`` now prepends one byte recording how many
      RLE levels were actually applied, so the decoder can invert exactly
      instead of blindly decoding ``hierarchical_levels`` times.

    Values are assumed to fit in a byte (0..255); bit sequences trivially do.
    """

    def __init__(
        self,
        scheme: str = "adaptive",
        min_run_length: int = 2,
        max_value: int = 255,
        delta_threshold: float = 0.7,
        hierarchical_levels: int = 2,
    ):
        """
        Args:
            scheme: Encoding scheme ('basic', 'delta', 'hierarchical', 'adaptive')
            min_run_length: Retained for backward compatibility. The pair-based
                stream no longer emits bare literals, so every run (even of
                length 1) is stored as a (value, count) pair.
            max_value: Maximum value for encoding
            delta_threshold: Compression ratio threshold for delta encoding
            hierarchical_levels: Maximum number of levels for hierarchical encoding
        """
        self.scheme = scheme
        self.min_run_length = min_run_length
        self.max_value = max_value
        self.delta_threshold = delta_threshold
        self.hierarchical_levels = hierarchical_levels

        # Running totals consumed by get_compression_stats().
        self.stats = {
            "total_compressions": 0,
            "total_original_size": 0,
            "total_compressed_size": 0,
            "scheme_usage": defaultdict(int),
        }

    @staticmethod
    def _fit_length(result: torch.Tensor, target_length: Optional[int]) -> torch.Tensor:
        """Trim or zero-pad a 1-D tensor to ``target_length`` (no-op when None)."""
        if target_length is None:
            return result
        if result.numel() > target_length:
            return result[:target_length]
        if result.numel() < target_length:
            return F.pad(result, (0, target_length - result.numel()))
        return result

    def encode_basic_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Basic RLE: an unambiguous stream of (value, count) byte pairs.

        Runs longer than 255 are split across multiple pairs so the count
        always fits in one byte. Returns an empty uint8 tensor for empty input.
        """
        if data.numel() == 0:
            return torch.tensor([], dtype=torch.uint8)

        data_flat = data.flatten()
        encoded: List[int] = []

        current_val = data_flat[0].item()
        current_count = 1

        for i in range(1, len(data_flat)):
            val = data_flat[i].item()
            if val == current_val and current_count < 255:
                current_count += 1
            else:
                encoded.extend([current_val, current_count])
                current_val = val
                current_count = 1

        # Flush the final run.
        encoded.extend([current_val, current_count])
        return torch.tensor(encoded, dtype=torch.uint8)

    def decode_basic_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
        """Decode a (value, count) pair stream produced by encode_basic_rle."""
        if encoded.numel() == 0:
            return torch.tensor([], dtype=torch.long)

        decoded: List[int] = []
        i = 0
        while i + 1 < len(encoded):
            val = encoded[i].item()
            count = encoded[i + 1].item()
            decoded.extend([val] * count)
            i += 2
        # A trailing odd byte should not occur with a well-formed stream,
        # but tolerate it rather than dropping data.
        if i < len(encoded):
            decoded.append(encoded[i].item())

        return self._fit_length(torch.tensor(decoded, dtype=torch.long), target_length)

    def encode_delta_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Delta RLE: store the first value, then basic RLE of (delta + 128).

        Deltas outside [-128, 127] are clamped (lossy); the intended bit-level
        data only produces deltas in {-1, 0, 1}.
        """
        if data.numel() <= 1:
            return self.encode_basic_rle(data)

        data_flat = data.flatten()
        # First delta is 0 by construction (diff against a prepended copy of
        # the first element); the true first value is stored separately below.
        deltas = torch.diff(data_flat, prepend=data_flat[0:1])
        shifted_deltas = torch.clamp(deltas + 128, 0, 255)

        delta_encoded = self.encode_basic_rle(shifted_deltas)
        return torch.cat([data_flat[0:1].to(torch.uint8), delta_encoded])

    def decode_delta_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
        """Decode delta run-length encoded data."""
        if encoded.numel() <= 1:
            return self.decode_basic_rle(encoded, target_length)

        first_val = encoded[0].item()
        deltas = self.decode_basic_rle(encoded[1:]).float() - 128

        if deltas.numel() > 0:
            deltas[0] = first_val  # seed the cumulative sum with the stored first value
            result = torch.cumsum(deltas, dim=0).long()
        else:
            result = torch.tensor([first_val], dtype=torch.long)

        return self._fit_length(result, target_length)

    def encode_hierarchical_rle(self, data: torch.Tensor) -> torch.Tensor:
        """Apply basic RLE repeatedly; prepend one byte with the level count.

        Recursion stops early when a level fails to shrink the stream by at
        least 10%. The header byte lets the decoder invert exactly the number
        of levels that were applied.
        """
        current_data = data.clone()
        levels_applied = 0

        for _ in range(self.hierarchical_levels):
            encoded = self.encode_basic_rle(current_data)
            if encoded.numel() >= current_data.numel() * 0.9:
                break  # diminishing returns; stop recursing
            current_data = encoded
            levels_applied += 1

        header = torch.tensor([levels_applied], dtype=torch.uint8)
        return torch.cat([header, current_data.to(torch.uint8)])

    def decode_hierarchical_rle(self, encoded: torch.Tensor, target_length: Optional[int] = None, levels: int = None) -> torch.Tensor:
        """Decode hierarchical RLE; the level count is read from the stream header.

        The ``levels`` argument is retained for backward compatibility but is
        ignored: the header byte is authoritative.
        """
        if encoded.numel() == 0:
            return self._fit_length(torch.tensor([], dtype=torch.long), target_length)

        levels_applied = int(encoded[0].item())
        current_data = encoded[1:]

        for _ in range(levels_applied):
            current_data = self.decode_basic_rle(current_data)

        return self._fit_length(current_data.long(), target_length)

    def encode(self, data: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Encode data using the configured scheme.

        Args:
            data: Input tensor to compress

        Returns:
            Tuple of (encoded_data, metadata); metadata carries the scheme
            actually used plus the original shape needed for decode().
        """
        original_shape = data.shape
        original_size = data.numel()

        if self.scheme == "basic":
            encoded = self.encode_basic_rle(data)
            scheme_used = "basic"
        elif self.scheme == "delta":
            encoded = self.encode_delta_rle(data)
            scheme_used = "delta"
        elif self.scheme == "hierarchical":
            encoded = self.encode_hierarchical_rle(data)
            scheme_used = "hierarchical"
        elif self.scheme == "adaptive":
            # Try every scheme and keep the smallest output.
            candidates = {
                "basic": self.encode_basic_rle(data),
                "delta": self.encode_delta_rle(data),
                "hierarchical": self.encode_hierarchical_rle(data),
            }
            best_scheme = min(candidates.keys(), key=lambda k: candidates[k].numel())
            encoded = candidates[best_scheme]
            scheme_used = best_scheme
        else:
            raise ValueError(f"Unknown encoding scheme: {self.scheme}")

        # Update running statistics.
        self.stats["total_compressions"] += 1
        self.stats["total_original_size"] += original_size
        self.stats["total_compressed_size"] += encoded.numel()
        self.stats["scheme_usage"][scheme_used] += 1

        metadata = {
            "scheme": scheme_used,
            "original_shape": original_shape,
            "original_size": original_size,
            "compressed_size": encoded.numel(),
            "compression_ratio": encoded.numel() / original_size if original_size > 0 else 1.0,
        }

        return encoded, metadata

    def decode(self, encoded: torch.Tensor, metadata: Dict[str, Any]) -> torch.Tensor:
        """
        Decode compressed data using metadata.

        Args:
            encoded: Compressed data
            metadata: Metadata from encoding

        Returns:
            Decoded tensor, restored to the original shape when possible

        Raises:
            ValueError: if the metadata names an unknown scheme.
        """
        scheme = metadata["scheme"]
        original_shape = metadata["original_shape"]
        target_length = math.prod(original_shape) if original_shape else None

        if scheme == "basic":
            decoded = self.decode_basic_rle(encoded, target_length)
        elif scheme == "delta":
            decoded = self.decode_delta_rle(encoded, target_length)
        elif scheme == "hierarchical":
            decoded = self.decode_hierarchical_rle(encoded, target_length)
        else:
            raise ValueError(f"Unknown decoding scheme: {scheme}")

        # Restore the original (possibly multi-dimensional) shape.
        if original_shape and decoded.numel() >= math.prod(original_shape):
            decoded = decoded[:math.prod(original_shape)].reshape(original_shape)

        return decoded

    def get_compression_stats(self) -> Dict[str, float]:
        """Get aggregate compression statistics across all encode() calls."""
        if self.stats["total_original_size"] == 0:
            return {"average_compression_ratio": 1.0, "total_savings": 0.0}

        avg_ratio = self.stats["total_compressed_size"] / self.stats["total_original_size"]
        total_savings = self.stats["total_original_size"] - self.stats["total_compressed_size"]

        return {
            "average_compression_ratio": avg_ratio,
            "total_savings": total_savings,
            "total_compressions": self.stats["total_compressions"],
            "scheme_usage": dict(self.stats["scheme_usage"]),
        }
| |
|
| |
|
class CompressedBitDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that applies RLE compression on-the-fly during training.

    This allows for memory-efficient storage of large bit sequences while
    maintaining fast access during training.

    Bug fixes vs. the previous implementation:
    - Compressed items now carry ``metadata["compressed"] = True`` so downstream
      consumers (e.g. the compression-aware loss, which checks that key) can
      actually detect them.
    - Returned metadata dicts are copies, so callers cannot mutate the cached
      entry (the old code flipped ``from_cache`` on the shared cached dict).
    - Eviction no longer requires a prior cache hit before new items can be
      cached once the cache is full.
    """

    def __init__(
        self,
        data: torch.Tensor,
        encoder: "RLEEncoder",
        compress_probability: float = 0.5,
        cache_size: int = 1000,
    ):
        """
        Args:
            data: Original bit sequence data (indexed along dim 0)
            encoder: RLE encoder instance (anything with an ``encode`` method)
            compress_probability: Probability of returning compressed data
            cache_size: Maximum number of compressed items to cache
        """
        self.data = data
        self.encoder = encoder
        self.compress_probability = compress_probability
        self.cache_size = cache_size
        # idx -> (compressed_tensor, metadata); bounded by cache_size.
        self.cache: Dict[int, Tuple[torch.Tensor, Dict[str, Any]]] = {}
        # idx -> number of cache hits; used to pick an eviction victim.
        self.access_count = defaultdict(int)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Get item with optional (randomized) compression.

        Returns:
            Tuple of (data, metadata) where metadata indicates whether the
            item is compressed and whether it came from the cache.
        """
        original_item = self.data[idx]

        # torch.rand in [0, 1): probability 1.0 always compresses, 0.0 never.
        if torch.rand(1).item() < self.compress_probability:
            if idx in self.cache:
                compressed, cached_meta = self.cache[idx]
                self.access_count[idx] += 1
                metadata = dict(cached_meta)  # copy: don't mutate the cached dict
                metadata["from_cache"] = True
                return compressed, metadata

            compressed, metadata = self.encoder.encode(original_item)
            # Mark as compressed so the compression-aware loss can see it.
            metadata["compressed"] = True

            if len(self.cache) < self.cache_size:
                self.cache[idx] = (compressed, dict(metadata))
            elif self.cache:
                # Evict the entry with the fewest recorded hits.
                victim = min(self.cache, key=lambda k: self.access_count[k])
                del self.cache[victim]
                self.access_count.pop(victim, None)
                self.cache[idx] = (compressed, dict(metadata))

            metadata["from_cache"] = False
            return compressed, metadata
        else:
            # Uncompressed pass-through.
            metadata = {
                "scheme": "uncompressed",
                "original_shape": original_item.shape,
                "compressed": False,
                "from_cache": False,
            }
            return original_item, metadata
| |
|
| |
|
def create_compression_aware_loss(
    base_loss_fn,
    compression_penalty: float = 0.01,
    quality_threshold: float = 0.8,
) -> callable:
    """
    Build a loss function that penalizes poor compression quality.

    Args:
        base_loss_fn: Base loss function (e.g., CrossEntropyLoss)
        compression_penalty: Weight applied to the averaged quality penalty
        quality_threshold: Compression ratios above this are penalized
            quadratically (ratio = compressed/original, so higher is worse)

    Returns:
        A callable ``(logits, targets, metadata_batch=None) -> loss`` that
        adds the penalty on top of ``base_loss_fn``.
    """
    def compression_aware_loss(
        logits: torch.Tensor,
        targets: torch.Tensor,
        metadata_batch: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """
        Compute the base loss plus a compression-quality penalty.

        Args:
            logits: Model output logits
            targets: Target labels
            metadata_batch: Per-item compression metadata (may be None)

        Returns:
            Adjusted loss tensor
        """
        loss = base_loss_fn(logits, targets)
        if metadata_batch is None:
            return loss

        # Only items flagged as compressed contribute to the penalty.
        compressed = [md for md in metadata_batch if md.get("compressed", False)]
        if not compressed:
            return loss

        # Quadratic penalty for every item whose ratio exceeds the threshold,
        # averaged over all compressed items in the batch.
        excess = 0.0
        for md in compressed:
            ratio = md.get("compression_ratio", 1.0)
            if ratio > quality_threshold:
                excess += (ratio - quality_threshold) ** 2

        return loss + compression_penalty * (excess / len(compressed))

    return compression_aware_loss
| |
|
| |
|
def integrate_rle_with_training(
    model,
    data: torch.Tensor,
    encoder_config: Optional[Dict[str, Any]] = None,
    compression_config: Optional[Dict[str, Any]] = None,
) -> Tuple[CompressedBitDataset, callable]:
    """
    Wire RLE compression into BitTransformerLM training.

    Args:
        model: BitTransformerLM model
        data: Training data tensor
        encoder_config: RLE encoder settings; sensible defaults when None
        compression_config: Compression-aware training settings; defaults when None

    Returns:
        Tuple of (compressed_dataset, compression_aware_loss_fn)
    """
    default_encoder_cfg = {
        "scheme": "adaptive",
        "min_run_length": 2,
        "delta_threshold": 0.7,
    }
    default_compression_cfg = {
        "compress_probability": 0.3,
        "compression_penalty": 0.01,
        "quality_threshold": 0.8,
        "cache_size": 1000,
    }
    # Only substitute defaults for None, so an explicit {} is respected.
    cfg_enc = default_encoder_cfg if encoder_config is None else encoder_config
    cfg_cmp = default_compression_cfg if compression_config is None else compression_config

    dataset = CompressedBitDataset(
        data,
        RLEEncoder(**cfg_enc),
        compress_probability=cfg_cmp["compress_probability"],
        cache_size=cfg_cmp["cache_size"],
    )

    loss_fn = create_compression_aware_loss(
        torch.nn.CrossEntropyLoss(),
        compression_penalty=cfg_cmp["compression_penalty"],
        quality_threshold=cfg_cmp["quality_threshold"],
    )

    return dataset, loss_fn
| |
|
| |
|
def benchmark_compression_schemes(
    test_data: torch.Tensor,
    schemes: Optional[List[str]] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Benchmark different compression schemes on test data.

    Args:
        test_data: Test data tensor
        schemes: Schemes to test; defaults to
            ["basic", "delta", "hierarchical", "adaptive"]. (Changed from a
            mutable default list to a None sentinel to avoid the shared
            mutable-default pitfall.)

    Returns:
        Dictionary with benchmark results for each scheme; failed schemes are
        reported with ``success: False`` and an ``error`` message rather than
        raising.
    """
    if schemes is None:
        schemes = ["basic", "delta", "hierarchical", "adaptive"]

    results: Dict[str, Dict[str, float]] = {}

    for scheme in schemes:
        encoder = RLEEncoder(scheme=scheme)

        try:
            # Round-trip the data and measure size and fidelity.
            compressed, metadata = encoder.encode(test_data)
            reconstructed = encoder.decode(compressed, metadata)

            compression_ratio = compressed.numel() / test_data.numel()
            reconstruction_error = torch.mean(
                (test_data.float() - reconstructed.float()) ** 2
            ).item()

            results[scheme] = {
                "compression_ratio": compression_ratio,
                "reconstruction_error": reconstruction_error,
                "compressed_size": compressed.numel(),
                "original_size": test_data.numel(),
                "success": True,
            }
        except Exception as e:  # benchmark reports failures instead of crashing
            results[scheme] = {
                "compression_ratio": 1.0,
                "reconstruction_error": float("inf"),
                "compressed_size": test_data.numel(),
                "original_size": test_data.numel(),
                "success": False,
                "error": str(e),
            }

    return results
| |
|
| |
|
| | |
def create_rle_training_config(
    scheme: str = "adaptive",
    compress_probability: float = 0.3,
    compression_penalty: float = 0.01,
    **kwargs
) -> Dict[str, Any]:
    """
    Assemble the configuration dict for RLE-enhanced training.

    Args:
        scheme: RLE encoding scheme
        compress_probability: Probability of compression during training
        compression_penalty: Loss penalty for compression artifacts
        **kwargs: Optional overrides (min_run_length, delta_threshold,
            hierarchical_levels, quality_threshold, cache_size)

    Returns:
        Dictionary with ``compression_type``, ``encoder_config`` and
        ``training_config`` sections.
    """
    encoder_config = {
        "scheme": scheme,
        "min_run_length": kwargs.get("min_run_length", 2),
        "delta_threshold": kwargs.get("delta_threshold", 0.7),
        "hierarchical_levels": kwargs.get("hierarchical_levels", 2),
    }
    training_config = {
        "compress_probability": compress_probability,
        "compression_penalty": compression_penalty,
        "quality_threshold": kwargs.get("quality_threshold", 0.8),
        "cache_size": kwargs.get("cache_size", 1000),
    }
    return {
        "compression_type": "rle",
        "encoder_config": encoder_config,
        "training_config": training_config,
    }
| |
|
| |
|
| | if __name__ == "__main__": |
| | |
| | print("Testing RLE Compression Module...") |
| | |
| | |
| | test_data = torch.randint(0, 2, (100,)) |
| | |
| | |
| | test_data[20:30] = 1 |
| | test_data[50:70] = 0 |
| | test_data[80:90] = 1 |
| | |
| | print(f"Original data shape: {test_data.shape}") |
| | print(f"Original data: {test_data[:20]}...") |
| | |
| | |
| | schemes = ["basic", "delta", "hierarchical", "adaptive"] |
| | |
| | for scheme in schemes: |
| | print(f"\nTesting {scheme} scheme:") |
| | encoder = RLEEncoder(scheme=scheme) |
| | |
| | try: |
| | |
| | compressed, metadata = encoder.encode(test_data) |
| | print(f" Compressed size: {compressed.numel()}") |
| | print(f" Compression ratio: {metadata['compression_ratio']:.3f}") |
| | |
| | |
| | reconstructed = encoder.decode(compressed, metadata) |
| | |
| | |
| | error = torch.mean((test_data.float() - reconstructed.float()) ** 2) |
| | print(f" Reconstruction error: {error.item():.6f}") |
| | |
| | if error.item() < 1e-6: |
| | print(" β
Perfect reconstruction") |
| | else: |
| | print(" β Reconstruction error detected") |
| | |
| | except Exception as e: |
| | print(f" β Error: {e}") |
| | |
| | |
| | print("\nBenchmarking compression schemes...") |
| | benchmark_results = benchmark_compression_schemes(test_data) |
| | |
| | for scheme, results in benchmark_results.items(): |
| | if results["success"]: |
| | print(f"{scheme:12}: ratio={results['compression_ratio']:.3f}, " |
| | f"error={results['reconstruction_error']:.6f}") |
| | else: |
| | print(f"{scheme:12}: FAILED - {results.get('error', 'Unknown error')}") |
| | |
| | print("\nRLE Compression Module test completed!") |