r"""DLM-NL2JSON-4B — Evaluation Script (Simplified)

Evaluates the model on the provided test set using an OpenAI-compatible API
endpoint. Measures per-category exact-match accuracy and average latency.

Usage:
    # Against a vLLM / TensorRT-LLM served model
    python eval_example.py \
        --data test_data_lite_200.jsonl \
        --base-url http://your-server:8006/v1 \
        --model qwen3_4b_6th_norag \
        --api-key token-abc123 \
        --disable-thinking

    # Against the OpenAI API (GPT-4o baseline)
    export OPENAI_API_KEY="sk-..."
    python eval_example.py \
        --data test_data_lite_200.jsonl \
        --model gpt-4o
"""

import argparse
import json
import os
import re
import time
from collections import Counter, defaultdict
from typing import Any, Dict, List

# ── Prompts ──────────────────────────────────────────────
# Imported from prompts.py (must be in the same directory).
from prompts import (
    SYS_ALP_DEFAULT,
    SYS_CPI_DEFAULT,
    SYS_CREDIT_DEFAULT,
    SYS_CSM_DEFAULT,
    SYS_GIS_DEFAULT,
)

# ── Category → (special_token, system_prompt) ────────────
# NOTE(review): every special token below is the empty string, so the user
# message is sent as "\n<input>". If the served model was fine-tuned with
# per-task tokens, fill them in here — confirm against the training format.
TASK_MAP = {
    0: ("", SYS_ALP_DEFAULT),      # ALP-A (pattern)
    1: ("", SYS_ALP_DEFAULT),      # ALP-B (flow)
    2: ("", SYS_CSM_DEFAULT),      # CSM (consumer spending)
    3: ("", SYS_CREDIT_DEFAULT),   # CREDIT-Income
    4: ("", SYS_CREDIT_DEFAULT),   # CREDIT-Spending
    5: ("", SYS_CREDIT_DEFAULT),   # CREDIT-Loan/Default
    6: ("", SYS_CPI_DEFAULT),      # CPI (business status)
    9: ("", SYS_GIS_DEFAULT),      # GIS-Inflow
    10: ("", SYS_GIS_DEFAULT),     # GIS-Outflow
    11: ("", SYS_GIS_DEFAULT),     # GIS-Consumption
}

CAT_NAMES = {
    0: "ALP-A(ptrn)",
    1: "ALP-B(flow)",
    2: "CSM",
    3: "CREDIT-Income",
    4: "CREDIT-Spending",
    5: "CREDIT-Loan",
    6: "CPI",
    9: "GIS-Inflow",
    10: "GIS-Outflow",
    11: "GIS-Consumption",
}

# ── Required keys per category (for comparison) ─────────
# Only these keys are compared; anything else in the prediction is ignored.
REQUIRED_KEYS = {
    0: ["base_ym", "region_nm", "ptrn", "sex_cd", "age_cd", "category"],
    1: ["base_ym", "region_nm", "flow_cd", "sex_cd", "age_cd", "category"],
    2: ["base_ym", "region_nm", "industry_select", "sex_cd", "age_cd", "category"],
    3: ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    4: ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    5: ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    6: ["base_ym", "region_nm", "bzc_cd", "cp_cd", "enp_cd", "category"],
    9: ["region_nm", "base_ym", "region_count", "category"],
    10: ["region_nm", "base_ym", "region_count", "category"],
    11: ["region_nm", "base_ym", "industry_category", "category"],
}

# Keys whose string values are coerced to int before comparison.
_SCALAR_INT_KEYS = ("base_ym", "region_count", "category")
# Keys holding lists of codes; compared as sorted, de-duplicated int lists.
_INT_LIST_KEYS = (
    "sex_cd", "age_cd", "job_cd", "perc_cd", "ptrn",
    "industry_category", "cp_cd", "enp_cd", "flow_cd",
)
# Keys holding {group_key: [codes...]} mappings.
_DICT_OF_LIST_KEYS = ("industry_select", "bzc_cd")


# ── Normalization helpers ────────────────────────────────
def norm_int_list(v):
    """Return *v* as a sorted, de-duplicated list of ints.

    Each element is parsed with ``int(float(str(x).strip()))``, so ``"3"``,
    ``3.0`` and ``" 3 "`` all become ``3`` (non-integral floats are
    truncated). Elements that cannot be parsed (``"abc"``, ``"nan"``,
    ``"inf"``, ``None``, ``True``) are dropped. Non-list inputs are returned
    unchanged so that a type mismatch still shows up in the comparison.
    """
    if not isinstance(v, list):
        return v
    out = set()
    for x in v:
        try:
            out.add(int(float(str(x).strip())))
        except (ValueError, OverflowError):
            # OverflowError: int(float("inf")); ValueError: everything else.
            continue
    return sorted(out)


def norm_dict_of_lists(d):
    """Normalize an ``industry_select`` / ``bzc_cd`` mapping.

    Keys are stringified, and single-letter alphabetic keys are upper-cased
    (``"a"`` → ``"A"``). List values go through :func:`norm_int_list`; other
    values are kept as-is. Non-dict inputs are returned unchanged.
    """
    if not isinstance(d, dict):
        return d
    out = {}
    for k, arr in d.items():
        key = str(k)
        if len(key) == 1 and key.isalpha():
            key = key.upper()
        out[key] = norm_int_list(arr)  # no-op for non-list values
    return out


def normalize(obj: Dict[str, Any], cat: int) -> Dict[str, Any]:
    """Return a normalized shallow copy of *obj* for fair comparison.

    ``summary`` is dropped because it is free text and is not scored.
    *cat* is currently unused; it is kept for interface stability.
    """
    o = dict(obj)
    o.pop("summary", None)

    for k in _SCALAR_INT_KEYS:
        if k in o and isinstance(o[k], str):
            try:
                o[k] = int(o[k])
            except ValueError:
                pass  # leave unparsable strings so they simply fail to match

    for k in _INT_LIST_KEYS:
        if k in o:
            o[k] = norm_int_list(o[k])

    for k in _DICT_OF_LIST_KEYS:
        if k in o:
            o[k] = norm_dict_of_lists(o[k])

    if "region_count" in o:
        # Clamped to [1, 10]; presumably the valid range for this field —
        # TODO confirm against the output schema.
        try:
            o["region_count"] = max(1, min(10, int(o["region_count"])))
        except (ValueError, TypeError, OverflowError):
            # OverflowError: json.loads accepts "Infinity" → int(inf) fails.
            pass
    return o


# Qwen3 emits its reasoning inside <think>...</think> when thinking mode is on;
# braces in that text must not be mistaken for the JSON answer.
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)


def extract_first_json(text: str):
    """Return the first balanced ``{...}`` substring of *text*, or None.

    ``<think>...</think>`` blocks are removed first. Braces inside JSON
    string literals (including escaped quotes) are ignored, so
    ``{"summary": "a } b"}`` is returned whole. The result is not validated;
    callers must still ``json.loads`` it. Returns None for empty/None input
    or when no opening brace is balanced.
    """
    if not text:
        return None
    text = _THINK_RE.sub("", text)
    start = text.find("{")
    if start == -1:
        return None
    depth = 0
    in_str = False
    escaped = False
    for i, ch in enumerate(text[start:], start):
        if in_str:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_str = False
        elif ch == '"':
            in_str = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None


def compare(pred: Dict, gold: Dict, cat: int):
    """Compare *pred* with *gold* on the required keys for *cat*.

    A key missing from a dict is treated as ``""`` for the equality check.

    Returns:
        ``(is_exact_match, diff)`` where *diff* maps each mismatched key to
        ``{"pred": ..., "gold": ...}`` (missing values reported as None).
    """
    diff = {}
    for k in REQUIRED_KEYS.get(cat, []):
        if pred.get(k, "") != gold.get(k, ""):
            diff[k] = {"pred": pred.get(k), "gold": gold.get(k)}
    return len(diff) == 0, diff


# ── Main helpers ─────────────────────────────────────────
def _cat_name(cat) -> str:
    """Display name for *cat*, falling back to its number as a string.

    Always returns ``str`` so format specs such as ``:20s`` cannot fail.
    """
    return CAT_NAMES.get(cat, str(cat))


def _load_samples(path: str, per_cat: int):
    """Load the JSONL test set.

    Blank lines are skipped. ``output`` may be a dict or a JSON string, and
    its ``category`` is coerced to int. Examples whose category has no
    TASK_MAP entry are skipped with a warning instead of crashing mid-run.

    Returns:
        ``(samples, n_categories)`` where *samples* is a list of
        ``(category, {"input": ..., "gold": ...})`` ordered by ascending
        category, keeping at most *per_cat* examples per category in file
        order.
    """
    by_cat = defaultdict(list)
    skipped = Counter()
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            out = item["output"]
            if not isinstance(out, dict):
                out = json.loads(out)
            cat = int(out["category"])
            if cat not in TASK_MAP:
                skipped[cat] += 1
                continue
            by_cat[cat].append({"input": item["input"], "gold": out})

    for cat, n in sorted(skipped.items()):
        print(f"[WARN] Skipping {n} sample(s) with unknown category {cat}")

    samples = [(cat, ex) for cat in sorted(by_cat) for ex in by_cat[cat][:per_cat]]
    return samples, len(by_cat)


def _build_messages(cat: int, user_in: str) -> List[Dict[str, str]]:
    """Build the chat messages (system prompt + tagged user input) for *cat*."""
    tag, sys_prompt = TASK_MAP[cat]
    return [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": f"{tag}\n{user_in}"},
    ]


def _parse_prediction(gen: str):
    """Parse the model reply into a dict, or return None if that fails.

    The first balanced ``{...}`` is tried. If there is none, the whole
    stripped reply is tried. A reply that parses to a non-object (list,
    number, ``null``) counts as a parse failure.
    """
    json_str = extract_first_json(gen) or gen.strip()
    try:
        obj = json.loads(json_str)
    except json.JSONDecodeError:
        return None
    return obj if isinstance(obj, dict) else None


def _print_summary(ok_counts, total_counts, latency_sums) -> None:
    """Print per-category and overall exact-match accuracy and mean latency."""
    print("\n" + "=" * 50)
    print("EVALUATION SUMMARY")
    print("=" * 50)
    total_ok = total_all = 0
    for c in sorted(total_counts):
        ok = ok_counts[c]
        tot = total_counts[c]
        acc = ok / tot if tot else 0
        avg_lat = latency_sums[c] / tot if tot else 0
        total_ok += ok
        total_all += tot
        print(f" {_cat_name(c):20s}: {ok:4d}/{tot:4d} acc={acc:.1%} avg={avg_lat:.3f}s")
    overall_acc = total_ok / total_all if total_all else 0
    overall_lat = sum(latency_sums.values()) / total_all if total_all else 0
    print(f" {'OVERALL':20s}: {total_ok:4d}/{total_all:4d} acc={overall_acc:.1%} avg={overall_lat:.3f}s")


# ── Main ─────────────────────────────────────────────────
def main():
    """Parse CLI args, query the model once per sample, and print results.

    API errors and unparsable replies both count as misses. Their elapsed
    time (possibly a full timeout) is included in the latency averages.
    """
    ap = argparse.ArgumentParser(description="DLM-NL2JSON-4B Evaluation")
    ap.add_argument("--data", required=True, help="Test JSONL file path")
    ap.add_argument("--base-url", default=None, help="OpenAI-compatible base URL")
    ap.add_argument("--model", required=True, help="Model name")
    ap.add_argument("--api-key", default=os.environ.get("OPENAI_API_KEY", ""), help="API key")
    ap.add_argument("--disable-thinking", action="store_true",
                    help="Pass chat_template_kwargs to disable Qwen3 thinking mode")
    ap.add_argument("--max-tokens", type=int, default=512)
    ap.add_argument("--per-cat", type=int, default=999, help="Max samples per category")
    ap.add_argument("--timeout", type=float, default=60.0,
                    help="Per-request timeout in seconds")
    args = ap.parse_args()

    # Imported after argument parsing, so --help works even without the package.
    import openai

    client = openai.OpenAI(
        base_url=args.base_url or None,
        api_key=args.api_key or "dummy",
        timeout=args.timeout,
    )

    samples, n_cats = _load_samples(args.data, args.per_cat)
    print(f"[INFO] Evaluating {len(samples)} samples across {n_cats} categories\n")

    ok_counts, total_counts = Counter(), Counter()
    latency_sums = Counter()

    for idx, (cat, ex) in enumerate(samples, 1):
        name = _cat_name(cat)
        gold_norm = normalize(ex["gold"], cat)

        kwargs = dict(
            model=args.model,
            messages=_build_messages(cat, ex["input"].strip()),
            max_tokens=args.max_tokens,
            temperature=0.0,
        )
        if args.disable_thinking:
            kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}

        t0 = time.perf_counter()
        error = None
        try:
            resp = client.chat.completions.create(**kwargs)
            # content may be None (e.g. an empty completion); treat it as "".
            gen = resp.choices[0].message.content or ""
        except Exception as e:  # broad on purpose: one bad request must not abort the run
            error = e
        dt = time.perf_counter() - t0
        total_counts[cat] += 1
        latency_sums[cat] += dt

        if error is not None:
            print(f"[{idx:04d}] {name} | ERROR: {error}")
            continue

        pred_obj = _parse_prediction(gen)
        if pred_obj is None:
            print(f"[{idx:04d}] {name} | PARSE_FAIL | {dt:.2f}s")
            continue

        ok, diff = compare(normalize(pred_obj, cat), gold_norm, cat)
        if ok:
            ok_counts[cat] += 1
        status = "OK" if ok else f"FAIL {list(diff.keys())}"
        print(f"[{idx:04d}] {name} | {status} | {dt:.2f}s")

    _print_summary(ok_counts, total_counts, latency_sums)


if __name__ == "__main__":
    main()