"""Knowledge distillation from a teacher LLM into Bee-Nano. Runs on MacBook MPS / CPU. Downloads a small teacher (SmolLM2-135M), generates logits on TinyStories, and distills them into Bee using soft-target cross-entropy (temperature-scaled KL divergence). This is how Bee learns WITHOUT weeks of pre-training on a GPU cluster. """ import argparse import json import logging import os import sys import time from pathlib import Path import torch import torch.nn.functional as F from datasets import load_dataset from torch.utils.data import DataLoader from transformers import AutoTokenizer, AutoModelForCausalLM sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from bee.register import register from bee.config import BeeConfig from bee.modeling_bee import BeeForCausalLM register() logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s") logger = logging.getLogger("bee.distill") def get_args(): parser = argparse.ArgumentParser(description="Distill teacher into Bee-Nano") parser.add_argument("--teacher", type=str, default="HuggingFaceTB/SmolLM2-135M", help="HF teacher model") parser.add_argument("--dataset", type=str, default="roneneldan/TinyStories", help="Dataset for distillation") parser.add_argument("--output_dir", type=str, required=True) parser.add_argument("--max_seq_length", type=int, default=256) parser.add_argument("--batch_size", type=int, default=2) parser.add_argument("--num_steps", type=int, default=500) parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--temperature", type=float, default=2.0, help="Softmax temperature for distillation") parser.add_argument("--alpha", type=float, default=0.7, help="Weight for distillation loss (1-alpha for ground-truth CE)") parser.add_argument("--device", type=str, default="mps" if torch.backends.mps.is_available() else "cpu") parser.add_argument("--save_every", type=int, default=100) return parser.parse_args() def distill_step(student, teacher, input_ids, attention_mask, temperature, alpha): """Single distillation step. Returns loss dict.""" with torch.no_grad(): teacher_out = teacher(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) teacher_logits = teacher_out.logits / temperature teacher_probs = F.softmax(teacher_logits, dim=-1) student_out = student(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) student_logits = student_out.logits / temperature # Distillation loss: KL(student || teacher) on shifted targets shift_student = student_logits[:, :-1, :].contiguous().view(-1, student_logits.size(-1)) shift_teacher = teacher_probs[:, 1:, :].contiguous().view(-1, teacher_probs.size(-1)) distill_loss = F.kl_div( F.log_softmax(shift_student, dim=-1), shift_teacher, reduction="batchmean", ) * (temperature ** 2) # Ground-truth CE shift_labels = input_ids[:, 1:].contiguous().view(-1) ce_loss = F.cross_entropy(shift_student, shift_labels, ignore_index=-100) loss = alpha * distill_loss + (1 - alpha) * ce_loss return {"loss": loss, "distill": distill_loss.item(), "ce": ce_loss.item()} def main(): args = get_args() os.makedirs(args.output_dir, exist_ok=True) logger.info("Loading teacher: %s", args.teacher) teacher = AutoModelForCausalLM.from_pretrained(args.teacher, trust_remote_code=True) teacher_tokenizer = AutoTokenizer.from_pretrained(args.teacher, trust_remote_code=True) if teacher_tokenizer.pad_token is None: teacher_tokenizer.pad_token = teacher_tokenizer.eos_token teacher = teacher.to(args.device).eval() # Freeze teacher for p in teacher.parameters(): p.requires_grad = False logger.info("Initializing Bee-Nano student") student_cfg = BeeConfig( vocab_size=teacher_tokenizer.vocab_size, hidden_size=512, num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024, max_position_embeddings=2048, ) student = BeeForCausalLM(student_cfg).to(args.device) n_params = sum(p.numel() for p in student.parameters()) logger.info("Student params: %.2fM", n_params / 1e6) optimizer = torch.optim.AdamW(student.parameters(), lr=args.learning_rate) scaler = torch.cuda.amp.GradScaler() if args.device == "cuda" else None logger.info("Loading dataset: %s", args.dataset) ds = load_dataset(args.dataset, split="train", streaming=True) def tokenize(ex): return teacher_tokenizer(ex["text"], truncation=True, max_length=args.max_seq_length, padding="max_length") ds = ds.map(tokenize, remove_columns=["text"]) def collate_fn(examples): input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples]) attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples]) return {"input_ids": input_ids, "attention_mask": attention_mask} loader = DataLoader(ds, batch_size=args.batch_size, collate_fn=collate_fn) logger.info("Starting distillation: %d steps", args.num_steps) step = 0 losses = [] start_time = time.perf_counter() for batch in loader: if step >= args.num_steps: break input_ids = batch["input_ids"].to(args.device) attention_mask = batch["attention_mask"].to(args.device) optimizer.zero_grad() if scaler: with torch.cuda.amp.autocast(): loss_dict = distill_step(student, teacher, input_ids, attention_mask, args.temperature, args.alpha) scaler.scale(loss_dict["loss"]).backward() scaler.step(optimizer) scaler.update() else: loss_dict = distill_step(student, teacher, input_ids, attention_mask, args.temperature, args.alpha) loss_dict["loss"].backward() optimizer.step() losses.append(loss_dict["loss"].item()) step += 1 if step % 10 == 0: recent = losses[-10:] logger.info("Step %d | loss=%.4f | distill=%.4f | ce=%.4f | tok/s=%.1f", step, sum(recent) / len(recent), loss_dict["distill"], loss_dict["ce"], (step * args.batch_size * args.max_seq_length) / (time.perf_counter() - start_time), ) if step % args.save_every == 0: ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{step}") os.makedirs(ckpt_dir, exist_ok=True) student.save_pretrained(ckpt_dir) teacher_tokenizer.save_pretrained(ckpt_dir) logger.info("Saved checkpoint to %s", ckpt_dir) # Final save student.save_pretrained(args.output_dir) teacher_tokenizer.save_pretrained(args.output_dir) # Save loss curve with open(os.path.join(args.output_dir, "loss_curve.json"), "w") as f: json.dump({"steps": list(range(1, len(losses) + 1)), "losses": losses}, f) logger.info("Distillation complete. Final avg loss (last 50): %.4f", sum(losses[-50:]) / min(len(losses), 50)) if __name__ == "__main__": main()