"""Benchmark Bee against real, publicly available small LLMs.

Measures:
  - Perplexity on TinyStories (lower = better)
  - Forward latency (ms per token)
  - Generation throughput (tok/s)
  - Memory footprint

Models compared:
  - Bee-Nano (random init)
  - Bee-Nano (distilled, if available)
  - GPT-2 124M
  - SmolLM2-135M
  - Qwen2.5-0.5B (if fits)
"""

import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM

# Register the Bee architecture with the transformers Auto classes (assumed behavior of bee.register).
register()

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.benchmark")


def count_params(model):
    return sum(p.numel() for p in model.parameters())
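

# Sketch for the "Memory footprint" item in the module docstring: parameter bytes
# only; activations and the KV cache are not counted.
def param_memory_mb(model):
    return sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)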


def measure_perplexity(model, tokenizer, device, max_samples=100, max_length=256):
    """Measure perplexity on TinyStories validation."""
    ds = load_dataset("roneneldan/TinyStories", split="validation", streaming=True)
    ds = ds.take(max_samples)

    total_nll = 0.0
    total_tokens = 0
    model = model.to(device).eval()

    for ex in ds:
        text = ex["text"]
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
        with torch.no_grad():
            out = model(**inputs)
            logits = out.logits if hasattr(out, "logits") else out[0]
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = inputs["input_ids"][:, 1:].contiguous()
            nll = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                reduction="sum",
            )
            total_nll += nll.item()
            total_tokens += shift_labels.numel()

    perplexity = torch.exp(torch.tensor(total_nll / total_tokens)).item()
    return perplexity
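

# Sketch for the "Forward latency (ms per token)" item in the module docstring:
# average wall-clock time per input token for a single forward pass (no generation).
def measure_forward_latency(model, tokenizer, device, prompt="Once upon a time", n_iters=10):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    model = model.to(device).eval()
    with torch.no_grad():
        _ = model(**inputs)  # warmup
        if device == "cuda":
            torch.cuda.synchronize()
        elif device == "mps":
            torch.mps.synchronize()
        t0 = time.perf_counter()
        for _ in range(n_iters):
            _ = model(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
        elif device == "mps":
            torch.mps.synchronize()
        t1 = time.perf_counter()
    n_tokens = inputs["input_ids"].numel()
    return (t1 - t0) * 1000 / (n_iters * n_tokens)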


def measure_generation_speed(model, tokenizer, device, prompt="Once upon a time", max_new_tokens=64):
    """Measure generation throughput."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    model = model.to(device).eval()

    # Warmup
    with torch.no_grad():
        _ = model.generate(**inputs, max_new_tokens=4, do_sample=False)

    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps":
        torch.mps.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps":
        torch.mps.synchronize()
    t1 = time.perf_counter()

    gen_time = t1 - t0
    # Count the tokens actually generated (greedy decoding may stop early at EOS).
    new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
    tok_per_sec = new_tokens / gen_time
    return tok_per_sec, gen_time, out.shape[1]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="mps" if torch.backends.mps.is_available() else "cpu")
    parser.add_argument("--bee_checkpoint", type=str, default=None, help="Distilled Bee checkpoint")
    parser.add_argument("--max_samples", type=int, default=50)
    parser.add_argument("--output", type=str, default="benchmark_results.json")
    args = parser.parse_args()

    results = []
    device = args.device

    # Models to benchmark
    models_to_test = []

    # Bee-Nano (random init)
    logger.info("Preparing Bee-Nano (random init)")
    bee_cfg = BeeConfig(vocab_size=49152, hidden_size=512, num_hidden_layers=8,
                        num_attention_heads=8, intermediate_size=1024, max_position_embeddings=2048)
    bee_random = BeeForCausalLM(bee_cfg)
    models_to_test.append(("Bee-Nano (random)", bee_random, None))

    # Bee-Nano (distilled, if exists)
    if args.bee_checkpoint and os.path.exists(args.bee_checkpoint):
        logger.info("Loading distilled Bee from %s", args.bee_checkpoint)
        bee_distilled = BeeForCausalLM.from_pretrained(args.bee_checkpoint)
        tok = AutoTokenizer.from_pretrained(args.bee_checkpoint)
        models_to_test.append(("Bee-Nano (distilled)", bee_distilled, tok))

    # GPT-2
    try:
        logger.info("Loading GPT-2")
        gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
        gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
        models_to_test.append(("GPT-2 124M", gpt2, gpt2_tok))
    except Exception as e:
        logger.warning("Failed to load GPT-2: %s", e)

    # SmolLM2-135M
    try:
        logger.info("Loading SmolLM2-135M")
        smol = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", trust_remote_code=True)
        smol_tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M", trust_remote_code=True)
        models_to_test.append(("SmolLM2-135M", smol, smol_tok))
    except Exception as e:
        logger.warning("Failed to load SmolLM2: %s", e)

    # Run benchmarks
    for name, model, tok in models_to_test:
        logger.info("=" * 50)
        logger.info("Benchmarking: %s", name)
        logger.info("=" * 50)

        params = count_params(model)
        logger.info("Parameters: %.2fM", params / 1e6)

        # Models without a paired tokenizer (the random-init Bee) fall back to the
        # SmolLM2 tokenizer, whose 49152-token vocab matches the BeeConfig above.
        if tok is None:
            tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M", trust_remote_code=True)
            if tok.pad_token is None:
                tok.pad_token = tok.eos_token

        try:
            ppl = measure_perplexity(model, tok, device, max_samples=args.max_samples)
            logger.info("Perplexity: %.2f", ppl)
        except Exception as e:
            logger.error("Perplexity failed: %s", e)
            ppl = None
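
        # Forward latency (sketch for the docstring's "ms per token" item; logged
        # only, not written to the results dict).
        try:
            fwd_ms = measure_forward_latency(model, tok, device)
            logger.info("Forward latency: %.2f ms/token", fwd_ms)
        except Exception as e:
            logger.error("Forward latency failed: %s", e)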

        try:
            tps, gen_time, out_len = measure_generation_speed(model, tok, device, max_new_tokens=32)
            logger.info("Generation: %.2f tok/s (%.2f ms for 32 tok)", tps, gen_time * 1000)
        except Exception as e:
            logger.error("Generation speed failed: %s", e)
            tps = gen_time = out_len = None

        results.append({
            "model": name,
            "params_M": params / 1e6,
            "perplexity": ppl,
            "gen_tok_per_sec": tps,
            "gen_time_ms": gen_time * 1000 if gen_time else None,
            "output_tokens": out_len,
        })

    # Save and print summary
    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)

    logger.info("\n" + "=" * 50)
    logger.info("SUMMARY")
    logger.info("=" * 50)
    for r in results:
        ppl_str = f"{r['perplexity']:.2f}" if r["perplexity"] is not None else "N/A"
        tps_str = f"{r['gen_tok_per_sec']:.1f}" if r["gen_tok_per_sec"] is not None else "N/A"
        logger.info("%-25s | %.1fM params | PPL: %s | Gen: %s tok/s",
                    r["model"], r["params_M"], ppl_str, tps_str)

    logger.info("Results saved to %s", args.output)


if __name__ == "__main__":
    main()