| """ |
| DLM-NL2JSON-4B β Evaluation Script (Simplified) |
| |
| Evaluates the model on the provided test set using an OpenAI-compatible API endpoint. |
| Measures per-category exact match accuracy and average latency. |
| |
| Usage: |
| # Against vLLM / TensorRT-LLM served model |
| python eval_example.py \ |
| --data test_data_lite_200.jsonl \ |
| --base-url http://your-server:8006/v1 \ |
| --model qwen3_4b_6th_norag \ |
| --api-key token-abc123 \ |
| --disable-thinking |
| |
| # Against OpenAI API (GPT-4o baseline) |
| export OPENAI_API_KEY="sk-..." |
| python eval_example.py \ |
| --data test_data_lite_200.jsonl \ |
| --model gpt-4o |
| """ |
|
|
| import json, re, time, argparse, os |
| from collections import Counter |
| from typing import Dict, Any, List |
|
|
| |
| |
| from prompts import ( |
| SYS_CSM_DEFAULT, |
| SYS_CREDIT_DEFAULT, |
| SYS_GIS_DEFAULT, |
| SYS_ALP_DEFAULT, |
| SYS_CPI_DEFAULT, |
| ) |
|
|
| |
# Category id -> (task tag, system prompt). The tag is prepended to the user
# turn (see main()) so the served model knows which output schema to emit.
# NOTE(review): ids 7 and 8 have no entry here or in CAT_NAMES — presumably
# unused categories; confirm against the test data.
TASK_MAP = {
    0: ("<TASK_ALP>", SYS_ALP_DEFAULT),
    1: ("<TASK_ALP>", SYS_ALP_DEFAULT),
    2: ("<TASK_CSM>", SYS_CSM_DEFAULT),
    3: ("<TASK_CREDIT>", SYS_CREDIT_DEFAULT),
    4: ("<TASK_CREDIT>", SYS_CREDIT_DEFAULT),
    5: ("<TASK_CREDIT>", SYS_CREDIT_DEFAULT),
    6: ("<TASK_CPI>", SYS_CPI_DEFAULT),
    9: ("<TASK_GIS>", SYS_GIS_DEFAULT),
    10: ("<TASK_GIS>", SYS_GIS_DEFAULT),
    11: ("<TASK_GIS>", SYS_GIS_DEFAULT),
}
|
|
# Human-readable category labels; used only for per-sample log lines and the
# final summary table.
CAT_NAMES = {
    0: "ALP-A(ptrn)", 1: "ALP-B(flow)", 2: "CSM",
    3: "CREDIT-Income", 4: "CREDIT-Spending", 5: "CREDIT-Loan",
    6: "CPI", 9: "GIS-Inflow", 10: "GIS-Outflow", 11: "GIS-Consumption",
}
|
|
| |
# Per-category keys that must match exactly between prediction and gold for a
# sample to count as correct (see compare()); keys not listed here are ignored
# during scoring.
REQUIRED_KEYS = {
    0: ["base_ym", "region_nm", "ptrn", "sex_cd", "age_cd", "category"],
    1: ["base_ym", "region_nm", "flow_cd", "sex_cd", "age_cd", "category"],
    2: ["base_ym", "region_nm", "industry_select", "sex_cd", "age_cd", "category"],
    3: ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    4: ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    5: ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    6: ["base_ym", "region_nm", "bzc_cd", "cp_cd", "enp_cd", "category"],
    9: ["region_nm", "base_ym", "region_count", "category"],
    10: ["region_nm", "base_ym", "region_count", "category"],
    11: ["region_nm", "base_ym", "industry_category", "category"],
}
|
|
|
|
| |
def norm_int_list(v):
    """Coerce a list of mixed scalars to a sorted, de-duplicated int list.

    Non-list inputs are returned unchanged. Elements that cannot be parsed
    as a number (via str -> float -> int) are silently dropped, matching the
    best-effort normalization used throughout this script.
    """
    if not isinstance(v, list):
        return v
    parsed = set()
    for item in v:
        try:
            parsed.add(int(float(str(item).strip())))
        except Exception:
            pass  # unparseable element: drop it
    return sorted(parsed)
|
|
|
|
def norm_dict_of_lists(d):
    """Normalize industry_select / bzc_cd style mappings: {str_key: [int, ...]}.

    Keys are stringified; a single alphabetic character key is upper-cased.
    List values are passed through norm_int_list; anything else is kept as-is.
    Non-dict inputs are returned unchanged.
    """
    if not isinstance(d, dict):
        return d
    result = {}
    for key, value in d.items():
        k = str(key)
        if len(k) == 1 and k.isalpha():
            k = k.upper()
        result[k] = norm_int_list(value) if isinstance(value, list) else value
    return result
|
|
|
|
def normalize(obj: Dict[str, Any], cat: int) -> Dict[str, Any]:
    """Normalize prediction/gold for fair comparison (summary excluded).

    Works on a shallow copy; the *cat* parameter is accepted for interface
    symmetry with compare() but is not consulted here.
    """
    out = dict(obj)
    out.pop("summary", None)

    # Scalar numeric fields sometimes arrive as strings; coerce where possible.
    for key in ("base_ym", "region_count", "category"):
        if isinstance(out.get(key), str):
            try:
                out[key] = int(out[key])
            except ValueError:
                pass  # leave non-numeric strings untouched

    # Code-list fields become sorted unique int lists.
    for key in ("sex_cd", "age_cd", "job_cd", "perc_cd", "ptrn",
                "industry_category", "cp_cd", "enp_cd"):
        if key in out:
            out[key] = norm_int_list(out[key])

    # flow_cd is only normalized when it is actually a list.
    if isinstance(out.get("flow_cd"), list):
        out["flow_cd"] = norm_int_list(out["flow_cd"])

    # Nested {key: [ints]} mappings.
    for key in ("industry_select", "bzc_cd"):
        if key in out:
            out[key] = norm_dict_of_lists(out[key])

    # region_count is clamped into the valid 1..10 range when numeric.
    if "region_count" in out:
        try:
            out["region_count"] = max(1, min(10, int(out["region_count"])))
        except (ValueError, TypeError):
            pass

    return out
|
|
|
|
def extract_first_json(text: str):
    """Return the first balanced ``{...}`` substring of *text*, or None.

    Tracks JSON string literals (including backslash escapes) so that braces
    appearing inside quoted values do not unbalance the depth counter — the
    previous version returned a truncated span for e.g. ``{"a": "}"}``.
    Returns None when *text* contains no ``{`` or the object never closes.
    """
    start = text.find("{")
    if start == -1:
        return None
    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            # Inside a string literal: only watch for the closing quote,
            # honoring backslash escapes.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None
|
|
|
|
def compare(pred: Dict, gold: Dict, cat: int):
    """Compare prediction vs gold on the required keys for category *cat*.

    Returns ``(exact_match, diff)`` where diff maps each mismatched key to
    ``{"pred": ..., "gold": ...}``. A key missing on only one side counts as
    a mismatch; categories without a REQUIRED_KEYS entry always match.
    """
    sentinel = "<MISSING>"
    mismatches = {}
    for key in REQUIRED_KEYS.get(cat, []):
        if pred.get(key, sentinel) != gold.get(key, sentinel):
            mismatches[key] = {"pred": pred.get(key), "gold": gold.get(key)}
    return not mismatches, mismatches
|
|
|
|
| |
def main() -> None:
    """CLI entry point: evaluate the model over a JSONL test set and print
    per-category exact-match accuracy and average latency."""
    ap = argparse.ArgumentParser(description="DLM-NL2JSON-4B Evaluation")
    ap.add_argument("--data", required=True, help="Test JSONL file path")
    ap.add_argument("--base-url", default=None, help="OpenAI-compatible base URL")
    ap.add_argument("--model", required=True, help="Model name")
    ap.add_argument("--api-key", default=os.environ.get("OPENAI_API_KEY", ""), help="API key")
    ap.add_argument("--disable-thinking", action="store_true",
                    help="Pass chat_template_kwargs to disable Qwen3 thinking mode")
    ap.add_argument("--max-tokens", type=int, default=512)
    ap.add_argument("--per-cat", type=int, default=999, help="Max samples per category")
    args = ap.parse_args()

    # Function-scope import: the openai SDK is only needed when actually running.
    import openai
    client = openai.OpenAI(
        base_url=args.base_url or None,
        api_key=args.api_key or "dummy",  # client requires a non-empty key string
        timeout=60.0,
    )

    # Load the test set: one JSON object per line (JSONL).
    with open(args.data, encoding="utf-8") as f:
        raw = [json.loads(line) for line in f]

    # Group samples by category. "output" (the gold JSON) may be stored either
    # as a dict or as a JSON-encoded string.
    from collections import defaultdict
    by_cat = defaultdict(list)
    for item in raw:
        out = item["output"] if isinstance(item["output"], dict) else json.loads(item["output"])
        cat = out["category"]
        by_cat[cat].append({"input": item["input"], "gold": out})

    # Cap each category at --per-cat samples, preserving file order within a
    # category and iterating categories in sorted order.
    samples = []
    for cat in sorted(by_cat):
        items = by_cat[cat][:args.per_cat]
        samples.extend([(cat, ex) for ex in items])

    print(f"[INFO] Evaluating {len(samples)} samples across {len(by_cat)} categories\n")

    # Per-category accumulators; Counter also accumulates the float latencies.
    ok_counts, total_counts = Counter(), Counter()
    latency_sums = Counter()

    for idx, (cat, ex) in enumerate(samples, 1):
        user_in = ex["input"].strip()
        gold_norm = normalize(ex["gold"], cat)

        # Build the chat request: the category's task tag is prepended to the
        # user turn so the model knows which schema to produce.
        tag, sys_prompt = TASK_MAP[cat]
        messages = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": f"{tag}\n{user_in}"},
        ]

        kwargs = dict(model=args.model, messages=messages,
                      max_tokens=args.max_tokens, temperature=0.0)
        if args.disable_thinking:
            # Server-side chat-template switch to turn off Qwen3 thinking mode
            # (ignored by providers that don't support extra_body).
            kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}

        t0 = time.perf_counter()
        try:
            resp = client.chat.completions.create(**kwargs)
            gen = resp.choices[0].message.content
        except Exception as e:
            # API failures still count toward totals and latency, so accuracy
            # reflects them as misses rather than silently shrinking the set.
            dt = time.perf_counter() - t0
            total_counts[cat] += 1
            latency_sums[cat] += dt
            print(f"[{idx:04d}] {CAT_NAMES.get(cat, cat)} | ERROR: {e}")
            continue

        dt = time.perf_counter() - t0
        total_counts[cat] += 1
        latency_sums[cat] += dt

        # Extract the first balanced JSON object from the generation; fall back
        # to the stripped raw text if no object is found.
        json_str = extract_first_json(gen) or gen.strip()
        try:
            pred_obj = json.loads(json_str)
        except json.JSONDecodeError:
            # Unparseable output counts as a miss (total was incremented above).
            print(f"[{idx:04d}] {CAT_NAMES.get(cat, cat)} | PARSE_FAIL | {dt:.2f}s")
            continue

        pred_norm = normalize(pred_obj, cat)
        ok, diff = compare(pred_norm, gold_norm, cat)
        if ok:
            ok_counts[cat] += 1

        status = "OK" if ok else f"FAIL {list(diff.keys())}"
        print(f"[{idx:04d}] {CAT_NAMES.get(cat, cat)} | {status} | {dt:.2f}s")

    # ---- summary table: per-category accuracy / mean latency, then overall ----
    print("\n" + "=" * 50)
    print("EVALUATION SUMMARY")
    print("=" * 50)
    total_ok = total_all = 0
    for c in sorted(total_counts):
        ok = ok_counts[c]
        tot = total_counts[c]
        acc = ok / tot if tot else 0
        avg_lat = latency_sums[c] / tot if tot else 0
        total_ok += ok
        total_all += tot
        print(f" {CAT_NAMES.get(c, c):20s}: {ok:4d}/{tot:4d} acc={acc:.1%} avg={avg_lat:.3f}s")

    overall_acc = total_ok / total_all if total_all else 0
    overall_lat = sum(latency_sums.values()) / total_all if total_all else 0
    print(f" {'OVERALL':20s}: {total_ok:4d}/{total_all:4d} acc={overall_acc:.1%} avg={overall_lat:.3f}s")
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|