| """Configuration management for BitTransformerLM.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Any, Dict, Optional |
|
|
| import torch |
|
|
| from .types import ( |
| AttentionMask, |
| ChunkSize, |
| DeviceType, |
| DiffusionConfig, |
| GenerationConfig, |
| HiddenSize, |
| NumHeads, |
| NumLayers, |
| QuantizationConfig, |
| SafetyThresholds, |
| SequenceLength, |
| ) |
|
|
|
|
| @dataclass |
class ModelConfig:
    """Configuration for BitTransformerLM model architecture.

    Attributes:
        d_model: Model dimension for embeddings and attention.
        nhead: Number of attention heads.
        num_layers: Number of transformer layers.
        dim_feedforward: Dimension of feedforward networks.
        max_seq_len: Maximum sequence length for positional encoding.
        lambda_K: Weight for negentropy metric in telemetry.
        lambda_C: Weight for complexity metric in telemetry.
        lambda_S: Weight for symbiosis metric in telemetry.
        reversible: Enable reversible layers for memory efficiency.
        use_checkpoint: Use gradient checkpointing.
        use_autocast: Use automatic mixed precision.
        use_act: Enable Adaptive Computation Time.
        act_threshold: ACT halting threshold.
        chunk_size: Chunk size for chunked attention (None for full attention).
        overlap: Overlap size for chunked attention.
        full_attn_logging: Log full attention matrices for telemetry.
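
    Example (illustrative; shows the ``to_dict``/``from_dict`` round trip):
        >>> config = ModelConfig(d_model=64, nhead=4)
        >>> ModelConfig.from_dict(config.to_dict()) == config
        True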
| """ |
|
|
| d_model: HiddenSize = 128 |
| nhead: NumHeads = 8 |
| num_layers: NumLayers = 4 |
| dim_feedforward: int = 512 |
| max_seq_len: SequenceLength = 1024 |
| lambda_K: float = 1.0 |
| lambda_C: float = 1.0 |
| lambda_S: float = 1.0 |
| reversible: bool = False |
| use_checkpoint: bool = True |
| use_autocast: bool = False |
| use_act: bool = False |
| act_threshold: float = 0.9 |
| chunk_size: ChunkSize = None |
| overlap: int = 0 |
| full_attn_logging: Optional[bool] = None |
|
|
| def to_dict(self) -> Dict[str, Any]: |
| """Convert config to dictionary.""" |
| return { |
| "d_model": self.d_model, |
| "nhead": self.nhead, |
| "num_layers": self.num_layers, |
| "dim_feedforward": self.dim_feedforward, |
| "max_seq_len": self.max_seq_len, |
| "lambda_K": self.lambda_K, |
| "lambda_C": self.lambda_C, |
| "lambda_S": self.lambda_S, |
| "reversible": self.reversible, |
| "use_checkpoint": self.use_checkpoint, |
| "use_autocast": self.use_autocast, |
| "use_act": self.use_act, |
| "act_threshold": self.act_threshold, |
| "chunk_size": self.chunk_size, |
| "overlap": self.overlap, |
| "full_attn_logging": self.full_attn_logging, |
| } |
|
|
| @classmethod |
| def from_dict(cls, config_dict: Dict[str, Any]) -> ModelConfig: |
| """Create config from dictionary.""" |
| return cls(**config_dict) |
|
|
|
|
| @dataclass |
| class TrainingConfig: |
| """Configuration for training BitTransformerLM. |
| |
| Attributes: |
| epochs: Number of training epochs. |
| batch_size: Training batch size. |
| learning_rate: Initial learning rate. |
| weight_decay: Weight decay for regularization. |
| gradient_clip_val: Gradient clipping value. |
| warmup_steps: Number of warmup steps for learning rate. |
| accumulate_grad_batches: Number of gradient accumulation steps. |
| amp: Enable automatic mixed precision. |
| compile_model: Enable PyTorch 2.0 compilation. |
| log_every_n_steps: Logging frequency. |
| val_check_interval: Validation check frequency. |
| save_top_k: Number of best checkpoints to save. |
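
    Example (illustrative; the effective batch size is ``batch_size *
    accumulate_grad_batches``):
        >>> cfg = TrainingConfig(batch_size=8, accumulate_grad_batches=4)
        >>> cfg.batch_size * cfg.accumulate_grad_batches
        32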
| """ |
|
|
| epochs: int = 10 |
| batch_size: int = 8 |
| learning_rate: float = 1e-3 |
| weight_decay: float = 0.01 |
| gradient_clip_val: float = 1.0 |
| warmup_steps: int = 100 |
| accumulate_grad_batches: int = 1 |
| amp: bool = False |
| compile_model: bool = False |
| log_every_n_steps: int = 50 |
| val_check_interval: float = 1.0 |
| save_top_k: int = 3 |
|
|
|
|
| @dataclass |
| class SafetyConfig: |
| """Configuration for safety monitoring and thresholds. |
| |
| Attributes: |
| enable_safety: Enable safety monitoring. |
| k_threshold: Negentropy threshold for safety gate. |
| c_threshold: Complexity threshold for safety gate. |
| s_threshold: Symbiosis threshold for safety gate. |
| strict_mode: Enable strict safety enforcement. |
| retry_attempts: Number of retry attempts for failed safety checks. |
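
    Example (illustrative; ``to_thresholds`` packages the three gate values):
        >>> SafetyConfig(k_threshold=0.2).to_thresholds()
        {'k_threshold': 0.2, 'c_threshold': 0.3, 's_threshold': 0.5}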
| """ |
|
|
| enable_safety: bool = True |
| k_threshold: float = 0.1 |
| c_threshold: float = 0.3 |
| s_threshold: float = 0.5 |
| strict_mode: bool = False |
| retry_attempts: int = 3 |
|
|
| def to_thresholds(self) -> SafetyThresholds: |
| """Convert to SafetyThresholds type.""" |
| return { |
| "k_threshold": self.k_threshold, |
| "c_threshold": self.c_threshold, |
| "s_threshold": self.s_threshold, |
| } |
|
|
|
|
| @dataclass |
| class DataConfig: |
| """Configuration for data processing and loading. |
| |
| Attributes: |
| dataset_path: Path to training dataset. |
| val_dataset_path: Path to validation dataset. |
| num_workers: Number of data loader workers. |
| pin_memory: Pin memory for data loading. |
| prefetch_factor: Prefetch factor for data loading. |
| max_sequence_length: Maximum sequence length to process. |
| compression_prob: Probability of using compressed data. |
| use_parity: Enable parity bit protection. |
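
    Example (illustrative; the dataset path is hypothetical):
        >>> cfg = DataConfig(dataset_path=Path("data/train.bits"))
        >>> cfg.use_parity
        True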
| """ |
|
|
| dataset_path: Optional[Path] = None |
| val_dataset_path: Optional[Path] = None |
| num_workers: int = 0 |
| pin_memory: bool = True |
| prefetch_factor: int = 2 |
| max_sequence_length: int = 1024 |
| compression_prob: float = 0.5 |
| use_parity: bool = True |
|
|
|
|
| @dataclass |
| class ExperimentConfig: |
| """Complete configuration for BitTransformerLM experiments. |
| |
| Attributes: |
| model: Model configuration. |
| training: Training configuration. |
| safety: Safety configuration. |
| data: Data configuration. |
| device: Target device for training. |
| seed: Random seed for reproducibility. |
| experiment_name: Name of the experiment. |
| output_dir: Directory for saving outputs. |
| resume_from_checkpoint: Path to checkpoint to resume from. |
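
    Example (illustrative; note that construction resolves ``device="auto"``
    and creates ``output_dir`` as a side effect):
        >>> cfg = ExperimentConfig()
        >>> cfg.device in ("cuda", "mps", "cpu")
        True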
| """ |
|
|
| model: ModelConfig = field(default_factory=ModelConfig) |
| training: TrainingConfig = field(default_factory=TrainingConfig) |
| safety: SafetyConfig = field(default_factory=SafetyConfig) |
| data: DataConfig = field(default_factory=DataConfig) |
| device: DeviceType = "auto" |
| seed: int = 42 |
| experiment_name: str = "bit_transformer_experiment" |
| output_dir: Path = Path("./outputs") |
| resume_from_checkpoint: Optional[Path] = None |
|
|
| def __post_init__(self): |
| """Post-initialization to handle device selection and path creation.""" |
| |
| if self.device == "auto": |
| if torch.cuda.is_available(): |
| self.device = "cuda" |
| elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
| self.device = "mps" |
| else: |
| self.device = "cpu" |
|
|
| |
| self.output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| def to_dict(self) -> Dict[str, Any]: |
| """Convert complete config to dictionary.""" |
| return { |
| "model": self.model.to_dict(), |
| "training": self.training.__dict__, |
| "safety": self.safety.__dict__, |
| "data": self.data.__dict__, |
| "device": str(self.device), |
| "seed": self.seed, |
| "experiment_name": self.experiment_name, |
| "output_dir": str(self.output_dir), |
| "resume_from_checkpoint": str(self.resume_from_checkpoint) if self.resume_from_checkpoint else None, |
| } |


def get_small_config() -> ExperimentConfig:
    """Get configuration for small-scale experiments."""
    return ExperimentConfig(
        model=ModelConfig(
            d_model=64,
            nhead=4,
            num_layers=2,
            dim_feedforward=256,
            max_seq_len=256,
        ),
        training=TrainingConfig(
            batch_size=4,
            learning_rate=1e-3,
            epochs=5,
        ),
    )


def get_medium_config() -> ExperimentConfig:
    """Get configuration for medium-scale experiments."""
    return ExperimentConfig(
        model=ModelConfig(
            d_model=128,
            nhead=8,
            num_layers=4,
            dim_feedforward=512,
            max_seq_len=1024,
        ),
        training=TrainingConfig(
            batch_size=8,
            learning_rate=1e-3,
            epochs=10,
        ),
    )


def get_large_config() -> ExperimentConfig:
    """Get configuration for large-scale experiments."""
    return ExperimentConfig(
        model=ModelConfig(
            d_model=256,
            nhead=16,
            num_layers=8,
            dim_feedforward=1024,
            max_seq_len=2048,
            reversible=True,
            chunk_size=512,
        ),
        training=TrainingConfig(
            batch_size=16,
            learning_rate=5e-4,
            epochs=20,
            amp=True,
            compile_model=True,
        ),
    )


def get_config_from_env() -> ExperimentConfig:
    """Load configuration from environment variables.

    Recognized variables: BT_D_MODEL, BT_NUM_LAYERS, BT_NHEAD, BT_BATCH_SIZE,
    BT_LEARNING_RATE, BT_EPOCHS, BT_DEVICE, and BT_OUTPUT_DIR. Unset (or
    empty) variables leave the corresponding defaults untouched.
    """
    config = ExperimentConfig()

    # Model overrides.
    if d_model := os.getenv("BT_D_MODEL"):
        config.model.d_model = int(d_model)
    if num_layers := os.getenv("BT_NUM_LAYERS"):
        config.model.num_layers = int(num_layers)
    if nhead := os.getenv("BT_NHEAD"):
        config.model.nhead = int(nhead)

    # Training overrides.
    if batch_size := os.getenv("BT_BATCH_SIZE"):
        config.training.batch_size = int(batch_size)
    if learning_rate := os.getenv("BT_LEARNING_RATE"):
        config.training.learning_rate = float(learning_rate)
    if epochs := os.getenv("BT_EPOCHS"):
        config.training.epochs = int(epochs)

    # Device override.
    if device := os.getenv("BT_DEVICE"):
        config.device = device

    # Output directory override.
    if output_dir := os.getenv("BT_OUTPUT_DIR"):
        config.output_dir = Path(output_dir)

    return config
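

if __name__ == "__main__":
    # Illustrative smoke test (not part of the library API): build each preset
    # plus the env-derived config and print a one-line summary. Constructing a
    # config creates its output_dir; any BT_* variables set in the environment
    # show up in the "env" line.
    for name, factory in [
        ("small", get_small_config),
        ("medium", get_medium_config),
        ("large", get_large_config),
        ("env", get_config_from_env),
    ]:
        cfg = factory()
        print(
            f"{name}: d_model={cfg.model.d_model}, "
            f"layers={cfg.model.num_layers}, device={cfg.device}"
        )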