File size: 2,325 Bytes
f0e5200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import annotations

import argparse
import json
import os
import time
from pathlib import Path
from typing import List

import numpy as np
import torch

from src.prompting import encode_prompt
from src.quantization_utils import load_fp32_model, load_quant_artifact


def _load_items(root: Path, n: int, seed: int = 42) -> List[dict]:
    data = json.loads((root / "data" / "dev.json").read_text())
    if n >= len(data):
        return data
    rng = np.random.default_rng(seed)
    idxs = rng.choice(len(data), size=n, replace=False)
    return [data[int(i)] for i in idxs]


def _bench_generate(tok, model, items: List[dict], device: str) -> float:
    """Run beam-search generation over *items* and return elapsed seconds.

    Each item is encoded (question + db_id, capped at 512 input tokens) and
    decoded greedily with 4 beams for up to 64 new tokens; only wall-clock
    time is measured, outputs are discarded.
    """
    start = time.perf_counter()
    for sample in items:
        prompt = encode_prompt(
            tok,
            sample["question"],
            sample["db_id"],
            device=device,
            max_input_tokens=512,
        )
        # Batch of one: add the leading batch dimension expected by generate().
        _ = model.generate(input_ids=prompt.unsqueeze(0), max_new_tokens=64, num_beams=4)
    return time.perf_counter() - start


def main() -> None:
    """CLI entry point: benchmark fp32 rollout latency, optionally vs. a quantized artifact.

    Loads a fixed, seeded subset of the dev set, times generation on the fp32
    model, and — when ``--artifact`` is given — repeats the timing on the
    quantized model so the two throughputs can be compared side by side.
    """
    parser = argparse.ArgumentParser(description="Benchmark rollout generation latency for RL loops.")
    parser.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
    parser.add_argument("--adapter", default="")
    parser.add_argument("--artifact", default="", help="Quantized artifact dir (optional).")
    parser.add_argument("--num_rollouts", type=int, default=128)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_only", action="store_true")
    args = parser.parse_args()

    device = "cpu"
    items = _load_items(Path("."), args.num_rollouts, args.seed)

    # An empty/whitespace --adapter means "no adapter".
    adapter = args.adapter.strip() or None
    tok, fp32 = load_fp32_model(
        args.base_model,
        adapter_path=adapter,
        device=device,
        local_only=args.local_only,
    )
    elapsed = _bench_generate(tok, fp32, items, device)
    # max(..., 1e-9) guards against division by zero on a degenerate timing.
    print(f"fp32: {elapsed:.2f}s for {len(items)} rollouts ({len(items)/max(elapsed,1e-9):.2f} rollouts/s)")

    if args.artifact:
        tok_q, model_q, meta = load_quant_artifact(args.artifact, device=device, local_only=True)
        elapsed_q = _bench_generate(tok_q, model_q, items, device)
        label = meta.get("mode", "quant")
        print(f"{label}: {elapsed_q:.2f}s for {len(items)} rollouts ({len(items)/max(elapsed_q,1e-9):.2f} rollouts/s)")


if __name__ == "__main__":
    # Benchmark is inference-only: disabling autograd globally avoids the
    # overhead/memory of tracking gradients during generate() calls.
    torch.set_grad_enabled(False)
    main()