"""Train Bee AGI — full pre-training with MoE, SSM, Memory, Reasoning, Domain Experts, Compression, and Self-Healing.
This script implements a meta-learning-aware training loop where the model
learns to improve itself through:
- Curriculum difficulty scaling
- Online data mixture rebalancing (based on domain router confidence)
- Self-healing diagnostics (gradient checks, LR auto-tune, rollback)
- Compression-aware loss (hierarchical VQ reconstruction)
- Auxiliary MoE load-balancing losses
"""
import argparse
import logging
import math
import os
import sys
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from datasets import load_dataset, interleave_datasets
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.agi_register import register_agi
from bee.agi_config import BeeAGIConfig
from bee.agi_model import BeeAGIForCausalLM
from bee.self_heal import BeeSelfHealEngine
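# Register the Bee AGI architecture with transformers before any config/model is
# instantiated (register_agi is assumed to hook the custom classes into the
# Auto* factories; see bee/agi_register.py).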
register_agi()
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.train_agi")
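
# Illustrative sketch only (assumed interface, not called by this script): the
# Switch-Transformer-style auxiliary load-balancing loss the module docstring
# refers to. It assumes the router yields per-token softmax probabilities over
# experts plus top-k expert indices; BeeAGIForCausalLM is expected to compute
# its own version internally, so this helper just documents the math.
def moe_load_balancing_loss(router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> torch.Tensor:
    """router_probs: (num_tokens, num_experts); expert_indices: (num_tokens, k)."""
    # f_i: fraction of tokens dispatched to expert i (top-k hits averaged over tokens).
    dispatch = F.one_hot(expert_indices, num_experts).float().sum(dim=1)
    tokens_per_expert = dispatch.mean(dim=0)
    # P_i: mean router probability mass assigned to expert i.
    mean_router_prob = router_probs.mean(dim=0)
    # num_experts * sum(f_i * P_i): minimized when both are uniform at 1/num_experts.
    return num_experts * torch.sum(tokens_per_expert * mean_router_prob)
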
def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train Bee AGI from scratch")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--tokenizer_name", type=str, default="HuggingFaceTB/SmolLM2-135M")
    parser.add_argument("--vocab_size", type=int, default=49152)
    parser.add_argument("--hidden_size", type=int, default=2048)
    parser.add_argument("--num_layers", type=int, default=24)
    parser.add_argument("--num_heads", type=int, default=16)
    parser.add_argument("--num_kv_heads", type=int, default=4)
    parser.add_argument("--intermediate_size", type=int, default=5632)
    parser.add_argument("--max_seq_length", type=int, default=8192)
    parser.add_argument("--num_experts", type=int, default=8)
    parser.add_argument("--experts_per_tok", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--warmup_steps", type=int, default=2000)
    parser.add_argument("--max_steps", type=int, default=100000)
    parser.add_argument("--save_steps", type=int, default=2000)
    parser.add_argument("--eval_steps", type=int, default=2000)
    parser.add_argument("--logging_steps", type=int, default=50)
    # BooleanOptionalAction (Python 3.9+) so a default of True can actually be
    # disabled (--no-bf16 etc.); store_true with default=True could never be off.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--hub_model_id", type=str, default=None)
    # Data mixing
    parser.add_argument("--data_sources", type=str, nargs="+", default=[
        "roneneldan/TinyStories",
        "openwebtext",
        "codeparrot/github-code",
    ])
    parser.add_argument("--data_probs", type=float, nargs="+", default=None)
    parser.add_argument("--domain_tuning", action=argparse.BooleanOptionalAction, default=True)
    return parser.parse_args()
class BeeAGITrainer(Trainer):
    """Custom trainer with self-healing, meta-learning signals, and domain rebalancing."""

    def __init__(self, *args, self_heal: Optional[BeeSelfHealEngine] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.self_heal = self_heal
        # One loss history per routing domain; guard against configs without domains.
        self.domain_loss_tracker = {d: [] for d in getattr(self.model.config, "domains", [])}

    def training_step(self, model, inputs, num_items_in_batch=None):
        model.train()
        inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
        if self.args.n_gpu > 1:
            loss = loss.mean()  # average over devices under DataParallel
        if self.use_apex:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)
        # Gradient norm feeds the healing diagnostics. Note this clips on every
        # micro-step, before gradient accumulation finishes; the Trainer's own
        # max_grad_norm clipping at the optimizer step remains authoritative.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
        # Self-heal diagnostics: inspect loss/grad/LR and let the engine react
        # (LR auto-tune, rollback, etc.).
        if self.self_heal is not None:
            step = self.state.global_step
            lr = self.optimizer.param_groups[0]["lr"]
            snapshot = self.self_heal.diagnose(step, loss.item(), grad_norm, lr)
            heal_report = self.self_heal.heal(self.optimizer, snapshot)
            if heal_report["actions"]:
                logger.info("Self-heal actions at step %d: %s", step, heal_report["actions"])
        # Recent Trainer releases rescale the returned loss for gradient
        # accumulation; keep this in sync with the installed transformers version.
        return loss.detach()
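
    # Illustrative sketch only (not called by the training loop): how the
    # per-domain loss tracker could drive the online mixture rebalancing the
    # class docstring mentions. The real hook, if any, lives in the Bee data
    # pipeline; the name and the softmax-over-losses heuristic are assumptions.
    def domain_mixture_weights(self, temperature: float = 1.0) -> dict:
        """Upweight domains whose recent mean loss is higher."""
        recent = {d: v[-100:] for d, v in self.domain_loss_tracker.items() if v}
        if not recent:
            return {}
        means = {d: sum(v) / len(v) for d, v in recent.items()}
        exps = {d: math.exp(m / temperature) for d, m in means.items()}
        total = sum(exps.values())
        return {d: e / total for d, e in exps.items()}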
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        # Emit a periodic health summary alongside the standard eval metrics.
        if self.self_heal is not None:
            summary = self.self_heal.get_summary()
            logger.info("Health summary: %s", summary)
        return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def main():
    args = get_args()
    set_seed(args.seed)
    config = BeeAGIConfig(
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_layers,
        num_attention_heads=args.num_heads,
        num_key_value_heads=args.num_kv_heads,
        intermediate_size=args.intermediate_size,
        max_position_embeddings=args.max_seq_length,
        num_experts=args.num_experts,
        num_experts_per_tok=args.experts_per_tok,
        tie_word_embeddings=False,
    )
    logger.info("Initializing Bee AGI with config: %s", config.to_dict())
    model = BeeAGIForCausalLM(config)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("Model parameters: %.2fB", n_params / 1e9)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Load and interleave streaming datasets; skip sources that fail to load.
    logger.info("Loading datasets: %s", args.data_sources)
    datasets = []
    for ds_name in args.data_sources:
        try:
            ds = load_dataset(ds_name, split="train", streaming=True)
            datasets.append(ds)
        except Exception as e:
            logger.warning("Failed to load %s: %s", ds_name, e)
    if len(datasets) > 1:
        # Fall back to a uniform mixture when no probabilities were given, or
        # when dropped sources left --data_probs with a length mismatch.
        if args.data_probs and len(args.data_probs) == len(datasets):
            probs = args.data_probs
        else:
            if args.data_probs:
                logger.warning("--data_probs length does not match loaded datasets; using uniform mixture")
            probs = [1.0 / len(datasets)] * len(datasets)
        train_ds = interleave_datasets(datasets, probabilities=probs, seed=args.seed)
    elif datasets:
        train_ds = datasets[0]
    else:
        raise RuntimeError("No datasets loaded successfully")

    def tokenize_function(examples):
        # Batched map: sources expose their text under different column names.
        for key in ("text", "content", "code"):
            if key in examples:
                return tokenizer(examples[key], truncation=True, max_length=args.max_seq_length)
        raise KeyError(f"No known text column in batch: {list(examples.keys())}")

    # Streaming datasets may not expose .features until resolved; only pass
    # remove_columns when the schema is known.
    remove_cols = list(datasets[0].features.keys()) if datasets[0].features else None
    train_ds = train_ds.map(tokenize_function, batched=True, remove_columns=remove_cols)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
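    # mlm=False makes the collator copy input_ids into labels (padding masked to
    # -100); the model applies the causal shift internally.
    # Note: eval_steps and domain_tuning are parsed but not wired into this run
    # (no eval split is loaded), and with a streaming IterableDataset the Trainer
    # cannot infer epoch length, so max_steps bounds training rather than
    # num_train_epochs.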
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        max_steps=args.max_steps,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
        save_strategy="steps",
        bf16=args.bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        gradient_checkpointing=args.gradient_checkpointing,
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )
    # Enable self-healing: the engine drives the trainer-side diagnostics above,
    # while enable_self_heal (assumed to wire the model-side hooks; see
    # bee/self_heal.py) activates them inside the network itself.
    heal_dir = os.path.join(args.output_dir, "self_heal")
    self_heal = BeeSelfHealEngine(model, heal_dir, auto_tune_lr=True)
    model.enable_self_heal(heal_dir, auto_tune_lr=True)
    trainer = BeeAGITrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        self_heal=self_heal,
    )
    logger.info("=== Starting Bee AGI Training ===")
    trainer.train()
    logger.info("Training complete. Saving final model to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    self_heal.export_health_log(os.path.join(args.output_dir, "health_log.jsonl"))
    logger.info("Health log exported.")
if __name__ == "__main__":
    main()
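
# Example launch (single node; values are the script defaults, paths are placeholders):
#   python scripts/train_agi.py \
#       --output_dir ./checkpoints/bee-agi \
#       --batch_size 4 --gradient_accumulation_steps 8 \
#       --learning_rate 3e-4 --max_steps 100000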