Spaces:
Sleeping
Sleeping
File size: 2,325 Bytes
f0e5200 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | from __future__ import annotations
import argparse
import json
import os
import time
from pathlib import Path
from typing import List
import numpy as np
import torch
from src.prompting import encode_prompt
from src.quantization_utils import load_fp32_model, load_quant_artifact
def _load_items(root: Path, n: int, seed: int = 42) -> List[dict]:
data = json.loads((root / "data" / "dev.json").read_text())
if n >= len(data):
return data
rng = np.random.default_rng(seed)
idxs = rng.choice(len(data), size=n, replace=False)
return [data[int(i)] for i in idxs]
def _bench_generate(
    tok,
    model,
    items: List[dict],
    device: str,
    *,
    max_input_tokens: int = 512,
    max_new_tokens: int = 64,
    num_beams: int = 4,
    warmup: int = 0,
) -> float:
    """Return the wall-clock seconds spent generating one output per item.

    Each item's ``"question"`` and ``"db_id"`` are encoded with ``encode_prompt``,
    given a leading batch dimension, and passed to ``model.generate``. The
    generated ids are discarded; only elapsed time is reported.

    Args:
        tok: tokenizer forwarded to ``encode_prompt``.
        model: object exposing ``generate(input_ids=..., max_new_tokens=..., num_beams=...)``.
        items: dicts with at least ``"question"`` and ``"db_id"`` keys.
        device: device string forwarded to ``encode_prompt``.
        max_input_tokens: prompt length limit passed to ``encode_prompt``.
        max_new_tokens: generation length cap passed to ``model.generate``.
        num_beams: beam width passed to ``model.generate``.
        warmup: number of leading items to run once, untimed, before the timed
            pass so first-call overheads are excluded. ``0`` (default) keeps the
            fully-cold measurement.

    Returns:
        Elapsed seconds measured with ``time.perf_counter`` over all ``items``.
    """

    def _run(batch: List[dict]) -> None:
        for it in batch:
            # assumes encode_prompt returns a 1-D id tensor; unsqueeze adds batch dim 1.
            input_ids = encode_prompt(
                tok, it["question"], it["db_id"], device=device, max_input_tokens=max_input_tokens
            ).unsqueeze(0)
            model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens, num_beams=num_beams)

    if warmup > 0:
        _run(items[:warmup])
    t0 = time.perf_counter()
    _run(items)
    return time.perf_counter() - t0
def main() -> None:
p = argparse.ArgumentParser(description="Benchmark rollout generation latency for RL loops.")
p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
p.add_argument("--adapter", default="")
p.add_argument("--artifact", default="", help="Quantized artifact dir (optional).")
p.add_argument("--num_rollouts", type=int, default=128)
p.add_argument("--seed", type=int, default=42)
p.add_argument("--local_only", action="store_true")
args = p.parse_args()
device = "cpu"
root = Path(".")
items = _load_items(root, args.num_rollouts, args.seed)
tok, fp32 = load_fp32_model(
args.base_model,
adapter_path=args.adapter.strip() or None,
device=device,
local_only=args.local_only,
)
t_fp32 = _bench_generate(tok, fp32, items, device)
print(f"fp32: {t_fp32:.2f}s for {len(items)} rollouts ({len(items)/max(t_fp32,1e-9):.2f} rollouts/s)")
if args.artifact:
tokq, mq, meta = load_quant_artifact(args.artifact, device=device, local_only=True)
t_q = _bench_generate(tokq, mq, items, device)
mode = meta.get("mode", "quant")
print(f"{mode}: {t_q:.2f}s for {len(items)} rollouts ({len(items)/max(t_q,1e-9):.2f} rollouts/s)")
if __name__ == "__main__":
    # The benchmark only runs inference, so execute the whole CLI with autograd
    # recording switched off for the calling thread.
    with torch.no_grad():
        main()
|