# File-viewer residue: Spaces (status: Sleeping); file size: 1,346 bytes; revision f0e5200.
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
from src.quantized_text2sql_engine import QuantizedText2SQLEngine
def main() -> None:
    """Run batch inference with a quantized artifact and report throughput.

    Parses the command-line options, loads ``data/dev.json`` (resolved against
    the current working directory) and keeps its first ``--num_samples``
    entries. Their ``(question, db_id)`` pairs are sent to the engine's
    ``ask_batch_execute`` in a single call, and only that call is timed;
    engine construction is excluded from the measurement.

    A JSON report is written to ``--out`` and echoed to stdout. It holds the
    result count, elapsed seconds, queries per second, the artifact path, the
    engine's ``meta`` attribute, and the first ten results as a sample.
    """
    parser = argparse.ArgumentParser(
        description="Production-style inference harness for quantized artifacts."
    )
    parser.add_argument(
        "--artifact",
        required=True,
        help="Quant artifact dir from scripts/quantize_export.py",
    )
    parser.add_argument("--num_samples", type=int, default=128)
    parser.add_argument("--out", default="results/task5_quant_infer.json")
    args = parser.parse_args()

    root = Path(".")
    # Decode explicitly as UTF-8: JSON files are UTF-8 by convention, while the
    # default for read_text() depends on the platform locale.
    dev_path = root / "data" / "dev.json"
    dev = json.loads(dev_path.read_text(encoding="utf-8"))
    dev = dev[: args.num_samples]

    engine = QuantizedText2SQLEngine(args.artifact, device="cpu")
    pairs = [(x["question"], x["db_id"]) for x in dev]

    t0 = time.perf_counter()
    results = engine.ask_batch_execute(pairs)
    dt = time.perf_counter() - t0

    out = {
        "n": len(results),
        "seconds": dt,
        # Clamp the divisor so a zero elapsed time cannot divide by zero.
        "qps": len(results) / max(dt, 1e-9),
        "artifact": args.artifact,
        "meta": engine.meta,
        "results": results[:10],  # sample
    }

    # Serialize once and reuse the text for both the file and stdout.
    report_text = json.dumps(out, indent=2)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(report_text, encoding="utf-8")
    print(report_text)
# Run the harness only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|