"""Train Bee AGI — full pre-training with MoE, SSM, Memory, Reasoning, Domain
Experts, Compression, and Self-Healing.

The script initialises a fresh ``BeeAGIForCausalLM``, streams and interleaves
one or more Hugging Face datasets, and trains with ``BeeAGITrainer``, a
``Trainer`` subclass that adds:

- Self-healing diagnostics once per optimizer step: the step's loss, the
  global gradient norm and the current LR are handed to
  ``BeeSelfHealEngine.diagnose``; ``BeeSelfHealEngine.heal`` may then act on
  the optimizer (presumably LR auto-tuning / rollback — see bee.self_heal).
- A health summary logged before every evaluation and exported as JSONL at
  the end of training.

Architecture-specific auxiliary objectives (MoE load balancing, hierarchical
VQ reconstruction, ...) are presumably folded into the loss returned by the
model's forward pass — confirm in ``bee.agi_model``; this script adds none.

NOTE(review): curriculum difficulty scaling and online data-mixture
rebalancing are not implemented here; the mixture is fixed by
``--data_probs`` and ``BeeAGITrainer.domain_loss_tracker`` is never filled.
"""

import argparse
import logging
import math
import os
import sys
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from datasets import load_dataset, interleave_datasets
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
    get_linear_schedule_with_warmup,
)

# Make the repository root importable so ``bee`` resolves when this script is
# executed directly (it lives one directory below the package root).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from bee.agi_register import register_agi  # noqa: E402
from bee.agi_config import BeeAGIConfig  # noqa: E402
from bee.agi_model import BeeAGIForCausalLM  # noqa: E402
from bee.self_heal import BeeSelfHealEngine  # noqa: E402

register_agi()

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.train_agi")

# Columns probed, in order, for the raw text of an example. Different sources
# use different names (e.g. TinyStories/openwebtext use "text", github-code
# uses "code"), so each source is normalised to a single "text" column.
TEXT_COLUMNS = ("text", "content", "code")


def get_args() -> argparse.Namespace:
    """Parse and validate the command-line options for a from-scratch run.

    Returns:
        The parsed namespace.

    Exits (via ``parser.error``) when ``--data_probs`` does not have one
    non-negative weight per ``--data_sources`` entry or sums to zero.
    """
    parser = argparse.ArgumentParser(description="Train Bee AGI from scratch")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--tokenizer_name", type=str, default="HuggingFaceTB/SmolLM2-135M")
    parser.add_argument("--vocab_size", type=int, default=49152)
    parser.add_argument("--hidden_size", type=int, default=2048)
    parser.add_argument("--num_layers", type=int, default=24)
    parser.add_argument("--num_heads", type=int, default=16)
    parser.add_argument("--num_kv_heads", type=int, default=4)
    parser.add_argument("--intermediate_size", type=int, default=5632)
    parser.add_argument("--max_seq_length", type=int, default=8192)
    parser.add_argument("--num_experts", type=int, default=8)
    parser.add_argument("--experts_per_tok", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--num_train_epochs", type=int, default=1,
                        help="Ignored by Trainer whenever --max_steps > 0.")
    parser.add_argument("--warmup_steps", type=int, default=2000)
    parser.add_argument("--max_steps", type=int, default=100000)
    parser.add_argument("--save_steps", type=int, default=2000)
    parser.add_argument("--eval_steps", type=int, default=2000,
                        help="Currently unused: no evaluation dataset is configured.")
    parser.add_argument("--logging_steps", type=int, default=50)
    # These switches default to on. BooleanOptionalAction keeps the old
    # ``--bf16`` spelling working and adds ``--no-bf16``; the previous
    # ``store_true`` + ``default=True`` combination could never be disabled.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--hub_model_id", type=str, default=None)
    # Data mixing
    parser.add_argument("--data_sources", type=str, nargs="+", default=[
        "roneneldan/TinyStories",
        "openwebtext",
        "codeparrot/github-code",
    ])
    parser.add_argument("--data_probs", type=float, nargs="+", default=None,
                        help="Sampling weight per --data_sources entry (renormalised).")
    parser.add_argument("--domain_tuning", action=argparse.BooleanOptionalAction, default=True,
                        help="Currently unused by this script.")
    args = parser.parse_args()

    if args.data_probs is not None:
        if len(args.data_probs) != len(args.data_sources):
            parser.error(
                f"--data_probs has {len(args.data_probs)} values but "
                f"--data_sources has {len(args.data_sources)}"
            )
        if any(p < 0 for p in args.data_probs) or sum(args.data_probs) <= 0:
            parser.error("--data_probs must be non-negative with a positive sum")
    return args


def _global_grad_norm(parameters) -> float:
    """Return the global L2 norm of all existing ``.grad`` tensors.

    Unlike ``clip_grad_norm_`` this never rescales the gradients. Parameters
    without a gradient are skipped; 0.0 is returned when none have one.
    """
    norms = [p.grad.detach().norm(2) for p in parameters if p.grad is not None]
    if not norms:
        return 0.0
    # Gradients may live on different devices (model parallelism).
    device = norms[0].device
    return torch.stack([n.to(device) for n in norms]).norm(2).item()


class BeeAGITrainer(Trainer):
    """Trainer that runs self-healing diagnostics once per optimizer step."""

    def __init__(self, *args, self_heal: Optional[BeeSelfHealEngine] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.self_heal = self_heal
        # Per-domain loss history; reserved for mixture rebalancing, not
        # populated anywhere in this script.
        self.domain_loss_tracker = {d: [] for d in self.model.config.domains}
        # Running sum of micro-batch losses in the current accumulation window.
        self._window_loss = None

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Run one micro-batch through the stock Trainer step, then self-heal.

        Forward, backward, multi-GPU averaging, apex/accelerate handling and
        gradient-accumulation loss normalisation are all delegated to
        ``Trainer.training_step``. Self-heal diagnostics run only when the
        accumulation window is complete (``accelerator.sync_gradients``), so
        the gradient norm covers the fully accumulated gradient and
        ``diagnose`` is called once per ``global_step``.

        Returns:
            The detached loss tensor produced by the parent step.
        """
        loss = super().training_step(model, inputs, num_items_in_batch)
        if self.self_heal is None:
            return loss

        # Trainer logs sum(returned losses) / optimizer_steps, so the sum over
        # one window is presumably the mean loss of that optimizer step.
        self._window_loss = loss if self._window_loss is None else self._window_loss + loss
        if not getattr(self.accelerator, "sync_gradients", True):
            return loss  # still accumulating; diagnose at the window's end
        window_loss = self._window_loss.item()
        self._window_loss = None

        # Measure only: Trainer clips the accumulated gradient itself
        # (max_grad_norm) right before the optimizer step. Clipping here on
        # every micro-batch rescaled partial sums and double-clipped.
        # NOTE(review): under fp16 GradScaler the grads are still scaled at
        # this point; bf16 (the default here) uses no scaler.
        grad_norm = _global_grad_norm(model.parameters())
        step = self.state.global_step
        lr = self.optimizer.param_groups[0]["lr"]
        snapshot = self.self_heal.diagnose(step, window_loss, grad_norm, lr)
        # NOTE(review): the LR scheduler rewrites param_group["lr"] on every
        # step, so an LR change made by heal() may be overwritten by the next
        # scheduler.step() — confirm how BeeSelfHealEngine persists it.
        heal_report = self.self_heal.heal(self.optimizer, snapshot)
        if heal_report["actions"]:
            logger.info("Self-heal actions at step %d: %s", step, heal_report["actions"])
        return loss

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Log the self-heal health summary, then run the standard evaluation."""
        if self.self_heal is not None:
            logger.info("Health summary: %s", self.self_heal.get_summary())
        return super().evaluate(
            eval_dataset=eval_dataset,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )


def _to_text(example: dict) -> dict:
    """Return ``{"text": ...}`` from the first non-empty text-like column."""
    for column in TEXT_COLUMNS:
        value = example.get(column)
        if isinstance(value, str) and value:
            return {"text": value}
    return {"text": ""}


def _load_text_stream(args: argparse.Namespace):
    """Stream, normalise and interleave the configured training sources.

    Each source is reduced to a single non-empty ``text`` column *before*
    interleaving, so sources with different schemas mix cleanly instead of
    yielding a column union full of ``None`` values. Sources that fail to
    load are skipped with a warning, and the remaining sampling weights are
    renormalised.

    Raises:
        RuntimeError: if no source loads, or every loaded source has weight 0.
    """
    weights = args.data_probs or [1.0] * len(args.data_sources)
    sources, kept_weights = [], []
    for name, weight in zip(args.data_sources, weights):
        try:
            ds = load_dataset(name, split="train", streaming=True)
        except Exception as exc:  # best-effort: skip unavailable sources
            logger.warning("Failed to load %s: %s", name, exc)
            continue
        ds = ds.map(_to_text).select_columns(["text"]).filter(lambda ex: bool(ex["text"]))
        sources.append(ds)
        kept_weights.append(weight)

    if not sources:
        raise RuntimeError("No datasets loaded successfully")
    if len(sources) == 1:
        return sources[0]
    total = sum(kept_weights)
    if total <= 0:
        raise RuntimeError("All successfully loaded datasets have zero sampling probability")
    probs = [w / total for w in kept_weights]
    return interleave_datasets(sources, probabilities=probs, seed=args.seed)


def main() -> None:
    """Build tokenizer, model, data stream and trainer; train; save artefacts."""
    args = get_args()
    set_seed(args.seed)

    # Load the tokenizer first: it is cheap and fails fast, unlike
    # initialising a multi-billion-parameter model.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if len(tokenizer) > args.vocab_size:
        raise ValueError(
            f"Tokenizer {args.tokenizer_name!r} has {len(tokenizer)} tokens but "
            f"--vocab_size is {args.vocab_size}; token ids would overflow the embeddings"
        )

    config = BeeAGIConfig(
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_layers,
        num_attention_heads=args.num_heads,
        num_key_value_heads=args.num_kv_heads,
        intermediate_size=args.intermediate_size,
        max_position_embeddings=args.max_seq_length,
        num_experts=args.num_experts,
        num_experts_per_tok=args.experts_per_tok,
        tie_word_embeddings=False,
    )
    logger.info("Initializing Bee AGI with config: %s", config.to_dict())
    model = BeeAGIForCausalLM(config)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("Model parameters: %.2fB", n_params / 1e9)

    logger.info("Loading datasets: %s", args.data_sources)
    train_ds = _load_text_stream(args)

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=args.max_seq_length)

    # After normalisation "text" is the only column, so dropping it leaves
    # just the tokenizer outputs for the collator.
    train_ds = train_ds.map(tokenize_function, batched=True, remove_columns=["text"])

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        max_steps=args.max_steps,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
        save_strategy="steps",
        bf16=args.bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        gradient_checkpointing=args.gradient_checkpointing,
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )

    # Enable self-healing
    heal_dir = os.path.join(args.output_dir, "self_heal")
    self_heal = BeeSelfHealEngine(model, heal_dir, auto_tune_lr=True)
    # NOTE(review): this also enables a model-side self-heal hook on the same
    # directory; confirm it does not duplicate the trainer-side engine above.
    model.enable_self_heal(heal_dir, auto_tune_lr=True)

    trainer = BeeAGITrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=data_collator,
        # ``tokenizer=`` is deprecated in favour of ``processing_class``.
        processing_class=tokenizer,
        self_heal=self_heal,
    )

    logger.info("=== Starting Bee AGI Training ===")
    trainer.train()

    logger.info("Training complete. Saving final model to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    self_heal.export_health_log(os.path.join(args.output_dir, "health_log.jsonl"))
    logger.info("Health log exported.")


if __name__ == "__main__":
    main()