"""
BREAKTHROUGH BitTransformerLM Training Script
=============================================

Using the ACTUAL BitTransformerLM model and training infrastructure,
configured for the Fixed RL Adafactor breakthrough results.
"""
| |
|
import ast
import logging
import os
import sys
from pathlib import Path

import torch
from datasets import load_dataset
from huggingface_hub import login
| |
|
| | |
# Make the BitTransformerLM package importable when running from /data.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')
| |
|
| | from bit_transformer import ( |
| | BitTransformerLM, |
| | text_to_bits, |
| | train_loop, |
| | save_model, |
| | load_model, |
| | set_dropout |
| | ) |
| | from BTLM_Extensions import configure_adafactor_optimizer |
| |
|
| | |
# Log to both a file and the console so long training runs can be
# reviewed after the fact.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('breakthrough_training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
| |
|
def load_and_prepare_dataset():
    """Load the HF dataset and convert it to a tensor of bit sequences.

    Each sample's ``bit_sequence`` field is used when available; otherwise
    the sample's text is re-encoded with ``text_to_bits``.  Sequences are
    then chunked into fixed 512-bit windows with 50% overlap.

    Returns:
        torch.Tensor of dtype ``long`` and shape ``(num_windows, 512)``.

    Raises:
        ValueError: if no usable 512-bit window could be produced.
    """
    logger.info("Loading WCNegentropy/BitTransformerLM dataset...")

    # Authenticate with the Hub when a token is available (needed for
    # private/gated datasets); otherwise proceed anonymously.
    hf_token = os.getenv('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
    else:
        # Was print(); use the module logger like the rest of the script.
        logger.warning("HF_TOKEN environment variable not set")

    dataset = load_dataset("WCNegentropy/BitTransformerLM")
    train_data = dataset['train']

    logger.info(f"Dataset loaded: {len(train_data)} samples")

    bit_sequences = []
    for sample in train_data:
        if 'bit_sequence' in sample and sample['bit_sequence'] is not None:
            bits = sample['bit_sequence']
            if isinstance(bits, str):
                # BUG FIX: was eval() under a bare except.  literal_eval
                # only parses Python literals and cannot execute arbitrary
                # code embedded in the (untrusted) dataset.
                try:
                    bits = ast.literal_eval(bits)
                except (ValueError, SyntaxError):
                    bits = None
            if isinstance(bits, list) and len(bits) > 0:
                bit_sequences.append(bits)
            else:
                # Bit sequence unusable: fall back to re-encoding the text.
                text = sample.get('original_text', '')
                if text:
                    bit_sequences.append(text_to_bits(text))
        else:
            # No bit_sequence field at all: derive bits from whichever
            # text field the sample carries.
            text = sample.get('text', '') or sample.get('original_text', '')
            if text:
                bit_sequences.append(text_to_bits(text))

    logger.info(f"Processed {len(bit_sequences)} bit sequences")

    # Chunk into fixed-length windows with 50% overlap (stride = 256).
    # Sequences shorter than max_len contribute no windows.
    max_len = 512
    training_sequences = []
    for bits in bit_sequences:
        for i in range(0, len(bits) - max_len + 1, max_len // 2):
            seq = bits[i:i + max_len]
            if len(seq) == max_len:
                training_sequences.append(seq)

    # Guard against an empty dataset: torch.tensor([]) would produce a
    # shape-(0,) tensor that fails later inside train_loop.
    if not training_sequences:
        raise ValueError("No 512-bit training windows could be built from the dataset")

    data_tensor = torch.tensor(training_sequences, dtype=torch.long)
    logger.info(f"Created training tensor: {data_tensor.shape}")

    return data_tensor
| |
|
def create_breakthrough_model():
    """Create the breakthrough BitTransformerLM configuration (~16M params).

    Returns:
        A freshly initialized ``BitTransformerLM`` instance.
    """
    logger.info("Creating breakthrough BitTransformerLM model...")

    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,        # reversible layers to cut activation memory
        use_checkpoint=True,    # gradient checkpointing
        use_autocast=True,      # mixed-precision forward
        use_act=True,           # adaptive computation time
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model created: {total_params:,} parameters")
    # BUG FIX: both branches of the original conditional logged the same
    # (mojibake) character, so the size check conveyed nothing.
    in_target = 15_000_000 <= total_params <= 17_000_000
    logger.info(f"Target: ~16M parameters - {'OK' if in_target else 'OUT OF RANGE'}")

    return model
| |
|
def main():
    """Run the end-to-end breakthrough training pipeline.

    Loads the dataset, builds the model, configures the Adafactor
    optimizer, trains via ``train_loop``, saves a checkpoint, and logs
    the final metrics.
    """
    logger.info("STARTING BREAKTHROUGH BITTRANSFORMERLM TRAINING!")
    logger.info("Using ACTUAL BitTransformerLM model and train_loop")

    data = load_and_prepare_dataset()
    model = create_breakthrough_model()

    logger.info("Configuring Fixed RL Adafactor optimizer...")
    optimizer, scheduler = configure_adafactor_optimizer(
        model,
        lr=1e-3,
        weight_decay=0.01,
        total_steps=5000,
    )
    logger.info("Fixed RL Adafactor configured with LR=0.001")

    training_config = {
        'epochs': 20,
        'batch_size': 4,
        'accum_steps': 4,      # effective batch size = 4 * 4 = 16
        'amp': True,
        'log': True,
        'compress_prob': 0.0,  # compression augmentation disabled for this run
        'optimizer': optimizer,
        'scheduler': scheduler,
    }

    logger.info(f"Training configuration: {training_config}")
    logger.info("Starting training loop...")

    metrics = train_loop(
        model=model,
        data=data,
        **training_config,
    )

    # BUG FIX: parents=True so the save succeeds even when the parent
    # directory does not exist yet (exist_ok alone raised FileNotFoundError).
    checkpoint_dir = Path('/data/BitTransformerLM/checkpoints')
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    model_path = checkpoint_dir / 'breakthrough_model.pt'
    save_model(model, model_path)
    logger.info(f"Model saved to: {model_path}")

    # train_loop is expected to return one metrics dict per logging step;
    # report the last one.  TODO confirm the exact return schema.
    if metrics:
        final_metrics = metrics[-1]
        logger.info("TRAINING COMPLETED!")
        logger.info(f"Final raw_loss: {final_metrics['raw_loss']:.6f}")
        logger.info(f"Final raw_acc: {final_metrics['raw_acc']:.3f}")

        if final_metrics['raw_loss'] < 3.0:
            logger.info("BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")

    logger.info("Breakthrough training completed successfully!")
| |
|
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()