# bee/scripts/distill.py
"""Knowledge distillation from a teacher LLM into Bee-Nano.
Runs on MacBook MPS / CPU. Downloads a small teacher (SmolLM2-135M),
generates logits on TinyStories, and distills them into Bee using
soft-target cross-entropy (temperature-scaled KL divergence).
This is how Bee learns WITHOUT weeks of pre-training on a GPU cluster.
"""
import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM
register()
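# register() exposes the custom Bee classes to the transformers Auto* machinery (see bee/register.py).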
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.distill")
def get_args():
parser = argparse.ArgumentParser(description="Distill teacher into Bee-Nano")
parser.add_argument("--teacher", type=str, default="HuggingFaceTB/SmolLM2-135M", help="HF teacher model")
parser.add_argument("--dataset", type=str, default="roneneldan/TinyStories", help="Dataset for distillation")
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--max_seq_length", type=int, default=256)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--num_steps", type=int, default=500)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--temperature", type=float, default=2.0, help="Softmax temperature for distillation")
parser.add_argument("--alpha", type=float, default=0.7, help="Weight for distillation loss (1-alpha for ground-truth CE)")
parser.add_argument("--device", type=str, default="mps" if torch.backends.mps.is_available() else "cpu")
parser.add_argument("--save_every", type=int, default=100)
return parser.parse_args()
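
# Combined objective used in distill_step below:
#   loss = alpha * T^2 * KL(softmax(teacher/T) || softmax(student/T)) + (1 - alpha) * CE(student, next-token labels)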
def distill_step(student, teacher, input_ids, attention_mask, temperature, alpha):
    """Single distillation step: soft-target KL to the teacher plus hard-label CE. Returns loss dict."""
    with torch.no_grad():
        teacher_out = teacher(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
        teacher_probs = F.softmax(teacher_out.logits / temperature, dim=-1)
    student_out = student(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
    student_logits = student_out.logits
    # Distillation loss: KL(teacher || student) over temperature-softened distributions.
    # Logits at position t predict token t+1 for both models, so both take the same slice.
    shift_student = student_logits[:, :-1, :].contiguous().view(-1, student_logits.size(-1))
    shift_teacher = teacher_probs[:, :-1, :].contiguous().view(-1, teacher_probs.size(-1))
    distill_loss = F.kl_div(
        F.log_softmax(shift_student / temperature, dim=-1),
        shift_teacher,
        reduction="batchmean",
    ) * (temperature ** 2)
    # Hard-label CE on the unscaled student logits; padded positions are ignored via -100 labels.
    shift_labels = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)
    ce_loss = F.cross_entropy(shift_student, shift_labels.contiguous().view(-1), ignore_index=-100)
    loss = alpha * distill_loss + (1 - alpha) * ce_loss
    return {"loss": loss, "distill": distill_loss.item(), "ce": ce_loss.item()}
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
logger.info("Loading teacher: %s", args.teacher)
teacher = AutoModelForCausalLM.from_pretrained(args.teacher, trust_remote_code=True)
teacher_tokenizer = AutoTokenizer.from_pretrained(args.teacher, trust_remote_code=True)
if teacher_tokenizer.pad_token is None:
teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
teacher = teacher.to(args.device).eval()
# Freeze teacher
for p in teacher.parameters():
p.requires_grad = False
logger.info("Initializing Bee-Nano student")
student_cfg = BeeConfig(
        vocab_size=teacher.config.vocab_size,
hidden_size=512,
num_hidden_layers=8,
num_attention_heads=8,
intermediate_size=1024,
max_position_embeddings=2048,
)
student = BeeForCausalLM(student_cfg).to(args.device)
n_params = sum(p.numel() for p in student.parameters())
logger.info("Student params: %.2fM", n_params / 1e6)
optimizer = torch.optim.AdamW(student.parameters(), lr=args.learning_rate)
scaler = torch.cuda.amp.GradScaler() if args.device == "cuda" else None
logger.info("Loading dataset: %s", args.dataset)
ds = load_dataset(args.dataset, split="train", streaming=True)
def tokenize(ex):
return teacher_tokenizer(ex["text"], truncation=True, max_length=args.max_seq_length, padding="max_length")
ds = ds.map(tokenize, remove_columns=["text"])
def collate_fn(examples):
input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
return {"input_ids": input_ids, "attention_mask": attention_mask}
loader = DataLoader(ds, batch_size=args.batch_size, collate_fn=collate_fn)
logger.info("Starting distillation: %d steps", args.num_steps)
step = 0
losses = []
start_time = time.perf_counter()
for batch in loader:
if step >= args.num_steps:
break
input_ids = batch["input_ids"].to(args.device)
attention_mask = batch["attention_mask"].to(args.device)
optimizer.zero_grad()
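        # Mixed-precision path is only active on CUDA; MPS/CPU run in full precision.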
if scaler:
with torch.cuda.amp.autocast():
loss_dict = distill_step(student, teacher, input_ids, attention_mask, args.temperature, args.alpha)
scaler.scale(loss_dict["loss"]).backward()
scaler.step(optimizer)
scaler.update()
else:
loss_dict = distill_step(student, teacher, input_ids, attention_mask, args.temperature, args.alpha)
loss_dict["loss"].backward()
optimizer.step()
losses.append(loss_dict["loss"].item())
step += 1
if step % 10 == 0:
recent = losses[-10:]
logger.info("Step %d | loss=%.4f | distill=%.4f | ce=%.4f | tok/s=%.1f",
step,
sum(recent) / len(recent),
loss_dict["distill"],
loss_dict["ce"],
(step * args.batch_size * args.max_seq_length) / (time.perf_counter() - start_time),
)
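        # Periodic checkpoint: the student weights plus the (shared) teacher tokenizer.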
if step % args.save_every == 0:
ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{step}")
os.makedirs(ckpt_dir, exist_ok=True)
student.save_pretrained(ckpt_dir)
teacher_tokenizer.save_pretrained(ckpt_dir)
logger.info("Saved checkpoint to %s", ckpt_dir)
# Final save
student.save_pretrained(args.output_dir)
teacher_tokenizer.save_pretrained(args.output_dir)
# Save loss curve
with open(os.path.join(args.output_dir, "loss_curve.json"), "w") as f:
json.dump({"steps": list(range(1, len(losses) + 1)), "losses": losses}, f)
logger.info("Distillation complete. Final avg loss (last 50): %.4f", sum(losses[-50:]) / min(len(losses), 50))
if __name__ == "__main__":
main()