| """Benchmark Bee against real, publicly available small LLMs. |
| |
| Measures: |
| - Perplexity on TinyStories (lower = better) |
| - Forward latency (ms per token) |
| - Generation throughput (tok/s) |
| - Memory footprint |
| |
| Models compared: |
| - Bee-Nano (random init) |
| - Bee-Nano (distilled, if available) |
| - GPT-2 124M |
| - SmolLM2-135M |
| - Qwen2.5-0.5B (if fits) |
| """ |
|
|
import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make the repository root importable so the local `bee` package resolves.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM

register()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("bee.benchmark")

def count_params(model):
    """Return the total number of parameters in the model."""
    return sum(p.numel() for p in model.parameters())

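
# The module docstring lists "Memory footprint" among the metrics, but the script does
# not compute it; the helper below is a minimal sketch (name is illustrative) counting
# parameter and buffer bytes only, ignoring activations and any KV cache.
def estimate_weight_memory_mb(model):
    """Rough weight/buffer footprint of the model in MiB (excludes activations)."""
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / (1024 ** 2)
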
def measure_perplexity(model, tokenizer, device, max_samples=100, max_length=256):
    """Measure perplexity on the TinyStories validation split."""
    ds = load_dataset("roneneldan/TinyStories", split="validation", streaming=True)
    ds = ds.take(max_samples)

    total_nll = 0.0
    total_tokens = 0
    model = model.to(device).eval()

    for ex in ds:
        text = ex["text"]
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
        with torch.no_grad():
            out = model(**inputs)
        logits = out.logits if hasattr(out, "logits") else out[0]
        # Next-token prediction: the logits at position t are scored against the token at t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = inputs["input_ids"][:, 1:].contiguous()
        nll = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction="sum",
        )
        total_nll += nll.item()
        total_tokens += shift_labels.numel()

    # Perplexity is exp of the mean per-token negative log-likelihood.
    perplexity = torch.exp(torch.tensor(total_nll / total_tokens)).item()
    return perplexity


def measure_generation_speed(model, tokenizer, device, prompt="Once upon a time", max_new_tokens=64):
    """Measure greedy-decoding throughput in tokens per second."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    model = model.to(device).eval()

    def sync():
        # CUDA and MPS execute asynchronously; synchronize before reading the clock.
        if device == "cuda":
            torch.cuda.synchronize()
        elif device == "mps":
            torch.mps.synchronize()

    # Warm-up pass so lazy initialization and kernel compilation are excluded from timing.
    with torch.no_grad():
        _ = model.generate(**inputs, max_new_tokens=4, do_sample=False)

    sync()
    t0 = time.perf_counter()
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    sync()
    t1 = time.perf_counter()

    gen_time = t1 - t0
    # Generation may stop early at EOS, so count the tokens actually produced.
    new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
    tok_per_sec = new_tokens / gen_time
    return tok_per_sec, gen_time, out.shape[1]

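
# "Forward latency (ms per token)" is listed in the module docstring, but only full
# generation is timed above; this helper is a minimal sketch of a per-token forward
# latency measurement (name and defaults are illustrative) and is not wired into main().
def measure_forward_latency_ms(model, tokenizer, device, prompt="Once upon a time", n_runs=10):
    """Average forward-pass latency over the prompt, in ms per input token."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    model = model.to(device).eval()
    with torch.no_grad():
        _ = model(**inputs)  # warm-up pass
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps":
        torch.mps.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        for _ in range(n_runs):
            _ = model(**inputs)
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps":
        torch.mps.synchronize()
    elapsed = time.perf_counter() - t0
    n_tokens = inputs["input_ids"].shape[1]
    return (elapsed / n_runs) / n_tokens * 1000.0
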
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str,
                        default="mps" if torch.backends.mps.is_available() else "cpu")
    parser.add_argument("--bee_checkpoint", type=str, default=None, help="Distilled Bee checkpoint")
    parser.add_argument("--max_samples", type=int, default=50)
    parser.add_argument("--output", type=str, default="benchmark_results.json")
    args = parser.parse_args()

    results = []
    device = args.device

    # (display name, model, tokenizer or None) triples to benchmark.
    models_to_test = []

    # Bee-Nano with randomly initialized weights (untrained baseline).
    logger.info("Preparing Bee-Nano (random init)")
    bee_cfg = BeeConfig(vocab_size=49152, hidden_size=512, num_hidden_layers=8,
                        num_attention_heads=8, intermediate_size=1024, max_position_embeddings=2048)
    bee_random = BeeForCausalLM(bee_cfg)
    models_to_test.append(("Bee-Nano (random)", bee_random, None))

    # Bee-Nano distilled checkpoint, if one was provided on the command line.
    if args.bee_checkpoint and os.path.exists(args.bee_checkpoint):
        logger.info("Loading distilled Bee from %s", args.bee_checkpoint)
        bee_distilled = BeeForCausalLM.from_pretrained(args.bee_checkpoint)
        tok = AutoTokenizer.from_pretrained(args.bee_checkpoint)
        models_to_test.append(("Bee-Nano (distilled)", bee_distilled, tok))

    # GPT-2 124M reference model.
    try:
        logger.info("Loading GPT-2")
        gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
        gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
        models_to_test.append(("GPT-2 124M", gpt2, gpt2_tok))
    except Exception as e:
        logger.warning("Failed to load GPT-2: %s", e)

    # SmolLM2-135M reference model.
    try:
        logger.info("Loading SmolLM2-135M")
        smol = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", trust_remote_code=True)
        smol_tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M", trust_remote_code=True)
        models_to_test.append(("SmolLM2-135M", smol, smol_tok))
    except Exception as e:
        logger.warning("Failed to load SmolLM2: %s", e)

    # Benchmark every model: parameter count, perplexity, generation speed.
    for name, model, tok in models_to_test:
        logger.info("=" * 50)
        logger.info("Benchmarking: %s", name)
        logger.info("=" * 50)

        params = count_params(model)
        logger.info("Parameters: %.2fM", params / 1e6)

        # Bee-Nano is built with vocab_size=49152, matching the SmolLM2 tokenizer,
        # so fall back to it when no tokenizer was bundled with the model.
        if tok is None:
            tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M", trust_remote_code=True)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        try:
            ppl = measure_perplexity(model, tok, device, max_samples=args.max_samples)
            logger.info("Perplexity: %.2f", ppl)
        except Exception as e:
            logger.error("Perplexity failed: %s", e)
            ppl = None

        try:
            tps, gen_time, out_len = measure_generation_speed(model, tok, device, max_new_tokens=32)
            logger.info("Generation: %.2f tok/s (%.2f ms for 32 tok)", tps, gen_time * 1000)
        except Exception as e:
            logger.error("Generation speed failed: %s", e)
            tps = gen_time = out_len = None

        results.append({
            "model": name,
            "params_M": params / 1e6,
            "perplexity": ppl,
            "gen_tok_per_sec": tps,
            "gen_time_ms": gen_time * 1000 if gen_time is not None else None,
            "output_tokens": out_len,
        })

    # Persist results and print a compact summary table.
    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)

    logger.info("\n" + "=" * 50)
    logger.info("SUMMARY")
    logger.info("=" * 50)
    for r in results:
        ppl_str = f"{r['perplexity']:.2f}" if r["perplexity"] is not None else "N/A"
        tps_str = f"{r['gen_tok_per_sec']:.1f}" if r["gen_tok_per_sec"] is not None else "N/A"
        logger.info("%-25s | %.1fM params | PPL: %s | Gen: %s tok/s",
                    r["model"], r["params_M"], ppl_str, tps_str)

    logger.info("Results saved to %s", args.output)


if __name__ == "__main__":
    main()