| """Knowledge distillation from a teacher LLM into Bee-Nano. |
| |
| Runs on MacBook MPS / CPU. Downloads a small teacher (SmolLM2-135M), |
| generates logits on TinyStories, and distills them into Bee using |
| soft-target cross-entropy (temperature-scaled KL divergence). |
| |
| This is how Bee learns WITHOUT weeks of pre-training on a GPU cluster. |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| from datasets import load_dataset |
| from torch.utils.data import DataLoader |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| from bee.register import register |
| from bee.config import BeeConfig |
| from bee.modeling_bee import BeeForCausalLM |
|
|
# Run the Bee registration hook imported from bee.register before any model is
# built. NOTE(review): presumably this wires the Bee classes into the
# transformers Auto* registries -- confirm in bee/register.py.
register()

# Root logging setup: INFO and above, with timestamp, level and logger name
# on every line.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    level=logging.INFO,
)

# Logger shared by every function in this script.
logger = logging.getLogger("bee.distill")
|
|
|
|
def get_args():
    """Parse the command-line options for a distillation run.

    Returns:
        argparse.Namespace: parsed options with the attributes ``teacher``,
        ``dataset``, ``output_dir`` (required), ``max_seq_length``,
        ``batch_size``, ``num_steps``, ``learning_rate``, ``temperature``,
        ``alpha``, ``device`` and ``save_every``.
    """
    # Default to the MPS backend when PyTorch reports it as available.
    fallback_device = "mps" if torch.backends.mps.is_available() else "cpu"

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        ("--teacher", {"type": str, "default": "HuggingFaceTB/SmolLM2-135M", "help": "HF teacher model"}),
        ("--dataset", {"type": str, "default": "roneneldan/TinyStories", "help": "Dataset for distillation"}),
        ("--output_dir", {"type": str, "required": True}),
        ("--max_seq_length", {"type": int, "default": 256}),
        ("--batch_size", {"type": int, "default": 2}),
        ("--num_steps", {"type": int, "default": 500}),
        ("--learning_rate", {"type": float, "default": 5e-4}),
        ("--temperature", {"type": float, "default": 2.0, "help": "Softmax temperature for distillation"}),
        ("--alpha", {"type": float, "default": 0.7, "help": "Weight for distillation loss (1-alpha for ground-truth CE)"}),
        ("--device", {"type": str, "default": fallback_device}),
        ("--save_every", {"type": int, "default": 100}),
    )

    cli = argparse.ArgumentParser(description="Distill teacher into Bee-Nano")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
|
|
|
|
def distill_step(student, teacher, input_ids, attention_mask, temperature, alpha):
    """Run one forward pass of teacher and student and build the distillation loss.

    The loss mixes two terms, both averaged over non-padding positions only:

    * a temperature-softened KL divergence between the teacher's and the
      student's next-token distributions, multiplied by ``temperature ** 2``;
    * a cross-entropy of the student's *unscaled* logits against the actual
      next tokens taken from ``input_ids``.

    Args:
        student: callable model returning an object with a ``.logits`` tensor;
            gradients flow through it.
        teacher: callable model returning an object with a ``.logits`` tensor;
            evaluated under ``torch.no_grad()``.
        input_ids: integer token ids, indexed as ``[batch, position]``.
        attention_mask: same layout as ``input_ids``; non-zero marks real
            tokens. ``None`` treats every position as real.
        temperature: softmax temperature applied to both models' logits for
            the KL term only.
        alpha: weight of the KL term; the cross-entropy gets ``1 - alpha``.

    Returns:
        dict with ``"loss"`` (tensor to call ``backward()`` on) and the
        Python floats ``"distill"`` and ``"ce"``.
    """
    with torch.no_grad():
        teacher_logits = teacher(
            input_ids=input_ids, attention_mask=attention_mask, use_cache=False
        ).logits

    student_logits = student(
        input_ids=input_ids, attention_mask=attention_mask, use_cache=False
    ).logits

    # In a causal LM the logits at position t score token t+1 -- for the
    # teacher exactly as for the student. Both are therefore cut with the SAME
    # slice (the last position has no target); only the labels are shifted.
    student_next = student_logits[:, :-1, :]
    teacher_next = teacher_logits[:, :-1, :]
    next_tokens = input_ids[:, 1:]

    # A position contributes only when both the token fed in and the token to
    # be predicted are real (non-padding) tokens.
    if attention_mask is None:
        valid = torch.ones_like(next_tokens, dtype=torch.bool)
    else:
        mask = attention_mask.bool()
        valid = mask[:, :-1] & mask[:, 1:]

    flat_student = student_next.reshape(-1, student_next.size(-1))
    flat_teacher = teacher_next.reshape(-1, teacher_next.size(-1))
    flat_valid = valid.reshape(-1)
    # Clamp so that a fully padded batch yields a zero loss instead of NaN.
    n_valid = flat_valid.sum().clamp(min=1)

    # Soft-target term: KL(teacher || student) at temperature T, rescaled by
    # T^2 so its gradient magnitude stays comparable to the hard-label term.
    student_log_probs = F.log_softmax(flat_student / temperature, dim=-1)
    teacher_probs = F.softmax(flat_teacher / temperature, dim=-1)
    kl_per_token = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(dim=-1)
    distill_loss = kl_per_token.masked_fill(~flat_valid, 0.0).sum() / n_valid * (temperature ** 2)

    # Hard-label term on the raw (temperature-free) student logits; padded
    # targets are replaced by the ignore index so they add nothing to the sum.
    flat_labels = next_tokens.reshape(-1).masked_fill(~flat_valid, -100)
    ce_loss = F.cross_entropy(flat_student, flat_labels, ignore_index=-100, reduction="sum") / n_valid

    loss = alpha * distill_loss + (1 - alpha) * ce_loss
    return {"loss": loss, "distill": distill_loss.item(), "ce": ce_loss.item()}
|
|
|
|
def _build_loader(args, tokenizer):
    """Return a DataLoader streaming fixed-length tokenized batches of ``args.dataset``."""
    stream = load_dataset(args.dataset, split="train", streaming=True)

    def tokenize(example):
        # Truncate and pad every example to exactly max_seq_length tokens so
        # the batches can be assembled without any dynamic padding.
        return tokenizer(
            example["text"],
            truncation=True,
            max_length=args.max_seq_length,
            padding="max_length",
        )

    def collate_fn(examples):
        return {
            key: torch.stack([torch.tensor(ex[key]) for ex in examples])
            for key in ("input_ids", "attention_mask")
        }

    stream = stream.map(tokenize, remove_columns=["text"])
    return DataLoader(stream, batch_size=args.batch_size, collate_fn=collate_fn)


def _save_checkpoint(student, tokenizer, directory):
    """Write the student weights and the tokenizer files into ``directory``."""
    os.makedirs(directory, exist_ok=True)
    student.save_pretrained(directory)
    tokenizer.save_pretrained(directory)


def main():
    """Distill the teacher model into a freshly initialised Bee-Nano student.

    Reads its options from the command line (see ``get_args``), trains for at
    most ``--num_steps`` optimizer steps on the streamed dataset, writes a
    checkpoint every ``--save_every`` steps (``0`` or less disables periodic
    checkpoints) and finally stores the student, the tokenizer and a
    ``loss_curve.json`` file in ``--output_dir``.
    """
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    logger.info("Loading teacher: %s", args.teacher)
    # NOTE(security): trust_remote_code=True lets the model repository run its
    # own Python code on this machine. Only point --teacher at repositories
    # you trust.
    teacher = AutoModelForCausalLM.from_pretrained(args.teacher, trust_remote_code=True)
    teacher_tokenizer = AutoTokenizer.from_pretrained(args.teacher, trust_remote_code=True)
    if teacher_tokenizer.pad_token is None:
        teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
    teacher = teacher.to(args.device).eval()

    # The teacher only provides targets; it is never updated.
    for param in teacher.parameters():
        param.requires_grad = False

    logger.info("Initializing Bee-Nano student")
    # The KL term compares student and teacher logits element-wise, so the
    # student's vocabulary must match the width of the teacher's output layer.
    # That is the model config's vocab_size, which can differ from
    # tokenizer.vocab_size (added tokens, padded embedding matrices).
    vocab_size = getattr(teacher.config, "vocab_size", None) or teacher_tokenizer.vocab_size
    student_cfg = BeeConfig(
        vocab_size=vocab_size,
        hidden_size=512,
        num_hidden_layers=8,
        num_attention_heads=8,
        intermediate_size=1024,
        max_position_embeddings=2048,
    )
    student = BeeForCausalLM(student_cfg).to(args.device)
    student.train()
    n_params = sum(p.numel() for p in student.parameters())
    logger.info("Student params: %.2fM", n_params / 1e6)

    optimizer = torch.optim.AdamW(student.parameters(), lr=args.learning_rate)
    # Mixed precision only on CUDA; accept indexed devices such as "cuda:0".
    use_amp = str(args.device).startswith("cuda")
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    logger.info("Loading dataset: %s", args.dataset)
    loader = _build_loader(args, teacher_tokenizer)

    logger.info("Starting distillation: %d steps", args.num_steps)
    step = 0
    losses = []
    real_tokens = 0
    start_time = time.perf_counter()

    for batch in loader:
        if step >= args.num_steps:
            break

        # Count non-padding tokens while the mask is still on the CPU, so the
        # throughput figure is not inflated by padding.
        real_tokens += int(batch["attention_mask"].sum())
        input_ids = batch["input_ids"].to(args.device)
        attention_mask = batch["attention_mask"].to(args.device)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.cuda.amp.autocast():
                loss_dict = distill_step(student, teacher, input_ids, attention_mask, args.temperature, args.alpha)
            scaler.scale(loss_dict["loss"]).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict = distill_step(student, teacher, input_ids, attention_mask, args.temperature, args.alpha)
            loss_dict["loss"].backward()
            optimizer.step()

        losses.append(loss_dict["loss"].item())
        step += 1

        if step % 10 == 0:
            recent = losses[-10:]
            elapsed = max(time.perf_counter() - start_time, 1e-9)
            logger.info(
                "Step %d | loss=%.4f | distill=%.4f | ce=%.4f | tok/s=%.1f",
                step,
                sum(recent) / len(recent),
                loss_dict["distill"],
                loss_dict["ce"],
                real_tokens / elapsed,
            )

        # save_every <= 0 disables periodic checkpoints instead of raising
        # ZeroDivisionError on the modulo.
        if args.save_every > 0 and step % args.save_every == 0:
            ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{step}")
            _save_checkpoint(student, teacher_tokenizer, ckpt_dir)
            logger.info("Saved checkpoint to %s", ckpt_dir)

    if step < args.num_steps:
        logger.warning("Dataset exhausted after %d of %d requested steps", step, args.num_steps)

    # Final artefacts: model + tokenizer, then the per-step loss curve.
    _save_checkpoint(student, teacher_tokenizer, args.output_dir)

    with open(os.path.join(args.output_dir, "loss_curve.json"), "w", encoding="utf-8") as f:
        json.dump({"steps": list(range(1, len(losses) + 1)), "losses": losses}, f)

    if losses:
        tail = losses[-50:]
        logger.info("Distillation complete. Final avg loss (last 50): %.4f", sum(tail) / len(tail))
    else:
        # Nothing was trained (num_steps <= 0 or an empty dataset); there is
        # no average to report.
        logger.warning("Distillation complete, but no training steps were run")


if __name__ == "__main__":
    main()
|
|