| """Train Bee AGI — full pre-training with MoE, SSM, Memory, Reasoning, Domain Experts, Compression, and Self-Healing. |
| |
| This script implements a meta-learning-aware training loop where the model |
| learns to improve itself through: |
| - Curriculum difficulty scaling |
| - Online data mixture rebalancing (based on domain router confidence) |
| - Self-healing diagnostics (gradient checks, LR auto-tune, rollback) |
| - Compression-aware loss (hierarchical VQ reconstruction) |
| - Auxiliary MoE load-balancing losses |
| """ |

import argparse
import logging
import math
import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from datasets import load_dataset, interleave_datasets
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
)

# Make the repo root importable so `bee` resolves when running this file directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.agi_register import register_agi
from bee.agi_config import BeeAGIConfig
from bee.agi_model import BeeAGIForCausalLM
from bee.self_heal import BeeSelfHealEngine

register_agi()

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.train_agi")

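
# Hedged reference sketch (not wired into the training loop): the module
# docstring lists auxiliary MoE load-balancing losses, which BeeAGIForCausalLM
# presumably adds to its own loss internally. For illustration, this is the
# Switch-Transformer-style penalty num_experts * sum_i(f_i * P_i), where f_i is
# the fraction of tokens dispatched to expert i and P_i its mean router
# probability; the function name and signature here are assumptions.
def moe_load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    """router_logits: (num_tokens, num_experts) pre-softmax router scores."""
    probs = F.softmax(router_logits, dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)                   # (num_tokens, top_k)
    dispatch = F.one_hot(selected, num_experts).amax(dim=1).float()  # (num_tokens, num_experts)
    fraction_dispatched = dispatch.mean(dim=0)  # f_i: share of tokens routed to expert i
    mean_router_prob = probs.mean(dim=0)        # P_i: average router probability for expert i
    return num_experts * torch.sum(fraction_dispatched * mean_router_prob)
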
def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train Bee AGI from scratch")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--tokenizer_name", type=str, default="HuggingFaceTB/SmolLM2-135M")
    parser.add_argument("--vocab_size", type=int, default=49152)
    parser.add_argument("--hidden_size", type=int, default=2048)
    parser.add_argument("--num_layers", type=int, default=24)
    parser.add_argument("--num_heads", type=int, default=16)
    parser.add_argument("--num_kv_heads", type=int, default=4)
    parser.add_argument("--intermediate_size", type=int, default=5632)
    parser.add_argument("--max_seq_length", type=int, default=8192)
    parser.add_argument("--num_experts", type=int, default=8)
    parser.add_argument("--experts_per_tok", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--warmup_steps", type=int, default=2000)
    parser.add_argument("--max_steps", type=int, default=100000)
    parser.add_argument("--save_steps", type=int, default=2000)
    parser.add_argument("--eval_steps", type=int, default=2000)
    parser.add_argument("--logging_steps", type=int, default=50)
    # BooleanOptionalAction so defaults of True can actually be disabled
    # (--no-bf16 etc.); `store_true` with default=True made the flag a no-op.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--hub_model_id", type=str, default=None)

    parser.add_argument("--data_sources", type=str, nargs="+", default=[
        "roneneldan/TinyStories",
        "openwebtext",
        "codeparrot/github-code",
    ])
    parser.add_argument("--data_probs", type=float, nargs="+", default=None)
    parser.add_argument("--domain_tuning", action=argparse.BooleanOptionalAction, default=True)
    return parser.parse_args()

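
# Hedged sketch (assumed, not shown anywhere in this script): the module
# docstring mentions curriculum difficulty scaling. A minimal version ramps the
# effective sequence length linearly with training progress; the floor/ceiling
# defaults here are illustrative, not Bee's actual schedule.
def curriculum_seq_length(step: int, max_steps: int, floor: int = 512, ceiling: int = 8192) -> int:
    progress = min(step / max(max_steps, 1), 1.0)
    return int(floor + (ceiling - floor) * progress)
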
class BeeAGITrainer(Trainer):
    """Custom trainer with self-healing, meta-learning signals, and domain rebalancing."""

    def __init__(self, *args, self_heal: BeeSelfHealEngine = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.self_heal = self_heal
        # Per-domain loss history; consumed by propose_domain_weights() below.
        self.domain_loss_tracker = {d: [] for d in self.model.config.domains}

    def training_step(self, model, inputs, num_items_in_batch=None):
        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        if self.args.n_gpu > 1:
            loss = loss.mean()

        if self.use_apex:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        # Clip through the accelerator so gradients are unscaled first under
        # mixed precision. Unlike the stock Trainer, which clips only at
        # accumulation boundaries, this clips every micro-step so the
        # self-heal engine sees a fresh gradient norm on each call.
        # (Accelerate returns None under DeepSpeed, which clips internally.)
        grad_norm = self.accelerator.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
        grad_norm = grad_norm.item() if grad_norm is not None else 0.0

        # Feed loss/gradient/LR telemetry to the self-heal engine, which may
        # auto-tune the LR or roll back optimizer state when it sees trouble.
        if self.self_heal is not None:
            step = self.state.global_step
            lr = self.optimizer.param_groups[0]["lr"]
            snapshot = self.self_heal.diagnose(step, loss.item(), grad_norm, lr)
            heal_report = self.self_heal.heal(self.optimizer, snapshot)
            if heal_report["actions"]:
                logger.info("Self-heal actions at step %d: %s", step, heal_report["actions"])

        return loss.detach()

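    def propose_domain_weights(self, window: int = 100, temperature: float = 1.0) -> dict:
        """Hedged sketch: __init__ builds domain_loss_tracker, but the original
        script never consumes it. One plausible rebalancing rule, assumed here,
        is a softmax over recent mean losses so harder domains are sampled
        more; a callback could feed these weights back into the interleave
        probabilities. The method name and temperature are illustrative."""
        means = {
            d: sum(v[-window:]) / len(v[-window:]) if v else 0.0
            for d, v in self.domain_loss_tracker.items()
        }
        exps = {d: math.exp(m / temperature) for d, m in means.items()}
        total = sum(exps.values()) or 1.0
        return {d: e / total for d, e in exps.items()}
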
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        # Surface the accumulated health telemetry alongside each evaluation.
        if self.self_heal is not None:
            summary = self.self_heal.get_summary()
            logger.info("Health summary: %s", summary)
        return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

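
# Hedged sketch: the module docstring lists a compression-aware loss from
# hierarchical VQ reconstruction, presumably computed inside the model. One VQ
# level typically looks like this (nearest-code lookup, reconstruction and
# commitment terms, straight-through estimator); the name vq_terms and the
# beta weighting are illustrative assumptions, not the model's real API.
def vq_terms(hidden: torch.Tensor, codebook: torch.Tensor, beta: float = 0.25):
    """hidden: (num_tokens, dim); codebook: (num_codes, dim)."""
    codes = torch.cdist(hidden, codebook).argmin(dim=-1)
    quantized = codebook[codes]
    recon_loss = F.mse_loss(quantized, hidden.detach())          # pull codes toward encodings
    commit_loss = beta * F.mse_loss(hidden, quantized.detach())  # keep encodings near their code
    quantized = hidden + (quantized - hidden).detach()           # straight-through gradient
    return quantized, recon_loss + commit_loss
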
def main():
    args = get_args()
    set_seed(args.seed)

    config = BeeAGIConfig(
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_layers,
        num_attention_heads=args.num_heads,
        num_key_value_heads=args.num_kv_heads,
        intermediate_size=args.intermediate_size,
        max_position_embeddings=args.max_seq_length,
        num_experts=args.num_experts,
        num_experts_per_tok=args.experts_per_tok,
        tie_word_embeddings=False,
    )

    logger.info("Initializing Bee AGI with config: %s", config.to_dict())
    model = BeeAGIForCausalLM(config)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("Model parameters: %.2fB", n_params / 1e9)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

| logger.info("Loading datasets: %s", args.data_sources) |
| datasets = [] |
| for ds_name in args.data_sources: |
| try: |
| ds = load_dataset(ds_name, split="train", streaming=True) |
| datasets.append(ds) |
| except Exception as e: |
| logger.warning("Failed to load %s: %s", ds_name, e) |
|
|
| if len(datasets) > 1: |
| probs = args.data_probs or [1.0 / len(datasets)] * len(datasets) |
| train_ds = interleave_datasets(datasets, probabilities=probs, seed=args.seed) |
| elif datasets: |
| train_ds = datasets[0] |
| else: |
| raise RuntimeError("No datasets loaded successfully") |
|
|
| def tokenize_function(examples): |
| text = examples.get("text", examples.get("content", examples.get("code", ""))) |
| return tokenizer(text, truncation=True, max_length=args.max_seq_length) |
|
|
| train_ds = train_ds.map(tokenize_function, batched=True, remove_columns=list(datasets[0].features.keys()) if datasets else []) |
|
|
| data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) |
|
|
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        max_steps=args.max_steps,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        save_strategy="steps",
        bf16=args.bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        gradient_checkpointing=args.gradient_checkpointing,
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )

    # One engine instance drives the trainer-side diagnostics; the model also
    # enables its own internal self-heal hooks against the same directory.
    heal_dir = os.path.join(args.output_dir, "self_heal")
    self_heal = BeeSelfHealEngine(model, heal_dir, auto_tune_lr=True)
    model.enable_self_heal(heal_dir, auto_tune_lr=True)

    trainer = BeeAGITrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        self_heal=self_heal,
    )

    logger.info("=== Starting Bee AGI Training ===")
    trainer.train()
    logger.info("Training complete. Saving final model to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    self_heal.export_health_log(os.path.join(args.output_dir, "health_log.jsonl"))
    logger.info("Health log exported.")


if __name__ == "__main__":
    main()