File size: 1,346 Bytes
f0e5200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

from src.quantized_text2sql_engine import QuantizedText2SQLEngine


def main() -> None:
    """Run dev-set questions through a quantized Text2SQL engine and report throughput.

    Reads the first ``--num_samples`` records of ``data/dev.json`` (resolved
    relative to the current working directory). It sends their
    ``(question, db_id)`` pairs to ``QuantizedText2SQLEngine.ask_batch_execute``
    with ``device="cpu"`` and times only that batch call.

    A JSON summary is written to ``--out`` (parent directories are created as
    needed) and echoed to stdout. It contains:

    - the number of results;
    - elapsed seconds;
    - queries per second;
    - the artifact path;
    - ``engine.meta``;
    - the first 10 results.

    Exits with a usage error (via ``argparse``) if ``--num_samples`` is negative.
    """
    p = argparse.ArgumentParser(description="Production-style inference harness for quantized artifacts.")
    p.add_argument("--artifact", required=True, help="Quant artifact dir from scripts/quantize_export.py")
    p.add_argument("--num_samples", type=int, default=128)
    p.add_argument("--out", default="results/task5_quant_infer.json")
    args = p.parse_args()

    # A negative count would silently slice from the end (dev[:-k] drops the
    # last k records) rather than limit the sample size, so reject it up front.
    if args.num_samples < 0:
        p.error(f"--num_samples must be >= 0, got {args.num_samples}")

    root = Path(".")
    # Decode explicitly as UTF-8 (what JSON text is required to use); the
    # platform default encoding can fail on or mangle non-ASCII questions.
    dev = json.loads((root / "data" / "dev.json").read_text(encoding="utf-8"))
    dev = dev[: args.num_samples]

    # Engine construction is deliberately kept outside the timed region so the
    # reported seconds/qps reflect batch inference only.
    engine = QuantizedText2SQLEngine(args.artifact, device="cpu")
    pairs = [(x["question"], x["db_id"]) for x in dev]

    t0 = time.perf_counter()
    results = engine.ask_batch_execute(pairs)
    dt = time.perf_counter() - t0

    out = {
        "n": len(results),
        "seconds": dt,
        # Guard against division by zero for an (almost) instantaneous run.
        "qps": len(results) / max(dt, 1e-9),
        "artifact": args.artifact,
        "meta": engine.meta,
        "results": results[:10],  # sample
    }

    # Serialize once and reuse the same text for the file and stdout.
    report = json.dumps(out, indent=2)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(report, encoding="utf-8")
    print(report)


# Script entry point: run the harness only when executed directly, not when imported.
if __name__ == "__main__":
    main()