# Source: text2sql_tani/src/evaluate_without_constraied.py
# Commit cf17729 ("Added full project") by tjhalanigrid.
# *********** code till task 3 ************
# import json
# import subprocess
# import sys
# import argparse
# import random
# import sqlite3
# import time
# import re
# import os
# from pathlib import Path
# import torch
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# from peft import PeftModel
# from prompting import encode_prompt
# # -------------------------------
# # NORMALIZATION
# # -------------------------------
# def normalize_sql(sql):
# sql = sql.replace('"', "'")
# sql = re.sub(r"\s+", " ", sql)
# return sql.strip().lower().rstrip(";")
# # -------------------------------
# # 🔥 SAFE RESULT NORMALIZATION (FIX)
# # -------------------------------
# def normalize_result(res):
# try:
# return sorted([str(r) for r in res])
# except:
# return []
# # -------------------------------
# # EXECUTION CHECK (FIXED)
# # -------------------------------
# def check_execution(pred_sql, gold_sql, db_path):
# try:
# conn = sqlite3.connect(db_path)
# conn.text_factory = lambda b: b.decode(errors='ignore')
# start_time = time.monotonic()
# def timeout_handler():
# return 1 if (time.monotonic() - start_time) > 2.0 else 0
# conn.set_progress_handler(timeout_handler, 10000)
# cursor = conn.cursor()
# cursor.execute(pred_sql)
# pred_res = cursor.fetchall()
# cursor.execute(gold_sql)
# gold_res = cursor.fetchall()
# conn.close()
# # 🔥 FIXED COMPARISON
# return normalize_result(pred_res) == normalize_result(gold_res)
# except Exception:
# return False
# # -------------------------------
# # SPIDER PARSER
# # -------------------------------
# def _parse_spider_accuracy(stdout: str, metric_type: str):
# for line in stdout.splitlines():
# if metric_type == "exec" and line.strip().startswith("execution"):
# try:
# return float(line.split()[-1])
# except:
# pass
# elif metric_type == "match" and line.strip().startswith("exact"):
# try:
# return float(line.split()[-1])
# except:
# pass
# return None
# # -------------------------------
# # MAIN
# # -------------------------------
# def main():
# parser = argparse.ArgumentParser()
# parser.add_argument("--adapter", type=str, required=True)
# parser.add_argument("--num_samples", type=int, default= 500)
# parser.add_argument("--shuffle_dev", action="store_true")
# parser.add_argument("--shuffle_seed", type=int, default=42)
# args = parser.parse_args()
# project_root = Path(__file__).resolve().parents[1]
# adapter_dir = project_root / args.adapter
# db_root = project_root / "data" / "database"
# table_json = project_root / "data" / "tables.json"
# dev_json = project_root / "data" / "dev.json"
# pred_path = project_root / "temp_predictions.txt"
# temp_gold_path = project_root / "temp_gold.sql"
# if not adapter_dir.exists():
# raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
# device = "mps" if torch.backends.mps.is_available() else (
# "cuda" if torch.cuda.is_available() else "cpu"
# )
# print(f"Using device: {device}")
# BASE_MODEL = "Salesforce/codet5-base"
# tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# if tokenizer.pad_token is None:
# tokenizer.pad_token = tokenizer.eos_token
# print(f"\n๐Ÿ“ฆ Loading Model: {args.adapter}")
# base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
# adapter_for_peft = os.path.relpath(adapter_dir, project_root)
# model = PeftModel.from_pretrained(
# base,
# adapter_for_peft,
# local_files_only=True
# ).to(device)
# model = model.merge_and_unload()
# model.eval()
# # -------------------------------
# # LOAD DATA
# # -------------------------------
# with dev_json.open() as f:
# dev = json.load(f)
# if args.shuffle_dev:
# rng = random.Random(args.shuffle_seed)
# rng.shuffle(dev)
# dev = dev[: args.num_samples]
# total = len(dev)
# gen_kwargs = dict(
# max_new_tokens=160,
# num_beams=8,
# length_penalty=0.8,
# do_sample=False,
# early_stopping=True,
# pad_token_id=tokenizer.pad_token_id,
# eos_token_id=tokenizer.eos_token_id,
# )
# print(f"\n๐Ÿš€ Evaluating {total} samples...\n")
# em_correct = 0
# ex_correct = 0
# with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
# for i, ex in enumerate(dev, start=1):
# db_id = ex["db_id"]
# question = ex["question"]
# gold_query = ex["query"]
# db_path = db_root / db_id / f"{db_id}.sqlite"
# # -------------------------------
# # GENERATE SQL
# # -------------------------------
# input_ids = encode_prompt(
# tokenizer,
# question,
# db_id,
# device=device,
# max_input_tokens=512
# )
# input_ids = input_ids.unsqueeze(0).to(device)
# attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
# outputs = model.generate(
# input_ids=input_ids,
# attention_mask=attention_mask,
# **gen_kwargs
# )
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
# # -------------------------------
# # SAVE FOR SPIDER EVAL
# # -------------------------------
# out_pred.write(f"{pred_sql}\n")
# out_gold.write(f"{gold_query}\t{db_id}\n")
# # -------------------------------
# # LIVE METRICS
# # -------------------------------
# if normalize_sql(pred_sql) == normalize_sql(gold_query):
# em_correct += 1
# if check_execution(pred_sql, gold_query, db_path):
# ex_correct += 1
# if i % 20 == 0 or i == total:
# print(
# f"Progress: {i}/{total} | "
# f"EM: {(em_correct/i)*100:.2f}% | "
# f"EX: {(ex_correct/i)*100:.2f}%"
# )
# print("\n๐Ÿš€ Running Official Spider Evaluation...\n")
# eval_script = project_root / "spider_eval" / "evaluation.py"
# # EXACT MATCH
# cmd_match = [
# sys.executable, str(eval_script),
# "--gold", str(temp_gold_path),
# "--pred", str(pred_path),
# "--etype", "match",
# "--db", str(db_root),
# "--table", str(table_json),
# ]
# proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
# exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
# # EXECUTION
# cmd_exec = [
# sys.executable, str(eval_script),
# "--gold", str(temp_gold_path),
# "--pred", str(pred_path),
# "--etype", "exec",
# "--db", str(db_root),
# "--table", str(table_json),
# ]
# proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
# exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
# print("==========================================")
# print(f"๐ŸŽฏ OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
# print("==========================================")
# print(f"Exact Match Accuracy : {exact_acc*100:.2f}%" if exact_acc else "EM parsing failed")
# print(f"Execution Accuracy : {exec_acc*100:.2f}%" if exec_acc else "EX parsing failed")
# print("==========================================\n")
# if __name__ == "__main__":
# main()
# *********** for task 2 ****************************************
import json
import argparse
import random
import sqlite3
import re
import os
from pathlib import Path
from collections import defaultdict
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from prompting import encode_prompt
# -------------------------------
# NORMALIZATION
# -------------------------------
def normalize_sql(sql):
    """Canonicalize a SQL string for exact-match (EM) comparison.

    Double quotes become single quotes, every whitespace run collapses to a
    single space, the text is lower-cased, and any trailing semicolons (and
    spaces around them) are removed.

    Args:
        sql: SQL text to normalize.

    Returns:
        The normalized SQL string.
    """
    sql = sql.replace('"', "'")
    sql = re.sub(r"\s+", " ", sql)
    # Strip ';' and ' ' together so "SELECT 1 ;" and "SELECT 1;" normalize
    # identically; stripping whitespace only before removing ';' would leave
    # a dangling trailing space in the first form.
    return sql.strip().lower().rstrip("; ")
def normalize_result(res):
    """Turn query rows into an order-insensitive, comparable form.

    Each row is converted to its ``str`` representation and the list is
    sorted, so two results compare equal when they contain the same rows
    with the same multiplicities, regardless of row order.

    NOTE: this deliberately ignores row order, so ORDER BY differences are
    not detected.

    Args:
        res: Rows as returned by ``cursor.fetchall()``, or ``None`` when the
            query failed.

    Returns:
        A sorted list of row strings. Returns ``[]`` if ``res`` is ``None``
        or not iterable.
    """
    if res is None:
        return []
    try:
        return sorted(str(row) for row in res)
    except TypeError:
        # Non-iterable input keeps the original "empty result" fallback,
        # without swallowing KeyboardInterrupt/SystemExit the way a bare
        # ``except:`` would.
        return []
# -------------------------------
# STEP 1: EXECUTION
# -------------------------------
def execute_with_error(sql, db_path):
    """Run ``sql`` against the SQLite database at ``db_path``.

    The database is opened read-only, for two reasons:
    - model-generated statements (e.g. ``DROP TABLE``, ``DELETE``) cannot
      modify the evaluation data;
    - a missing database file is reported as an error instead of being
      silently created empty.

    The connection is always closed, even when execution fails.

    Args:
        sql: SQL statement to execute.
        db_path: Path (``str`` or ``pathlib.Path``) to the ``.sqlite`` file.

    Returns:
        ``(rows, None)`` on success, where ``rows`` is the ``fetchall()``
        list. Returns ``(None, message)`` if opening the database or running
        the query raised.
    """
    conn = None
    try:
        # URI form with mode=ro: read-only, and no implicit file creation.
        # as_uri() percent-encodes characters that are special in URIs.
        uri = Path(db_path).resolve().as_uri() + "?mode=ro"
        conn = sqlite3.connect(uri, uri=True)
        # Some databases store text that is not valid UTF-8. Decode it
        # leniently rather than failing the whole query.
        conn.text_factory = lambda b: b.decode(errors="ignore")
        cur = conn.cursor()
        cur.execute(sql)
        return cur.fetchall(), None
    except Exception as e:
        return None, str(e)
    finally:
        if conn is not None:
            conn.close()
# -------------------------------
# STEP 2: ERROR CLASSIFICATION
# -------------------------------
def classify_error(sql, error_msg):
    """Map an execution error message to a coarse error category.

    Args:
        sql: The SQL that was executed. Only the fallback ``missing_where``
            heuristic inspects it.
        error_msg: Error text from execution, or ``None`` if the query ran.

    Returns:
        ``"correct"`` when ``error_msg`` is ``None``. This means only that
        the query executed; its result is not checked here. Otherwise one
        of the labels below, with ``"other"`` as the fallback.
    """
    if error_msg is None:
        return "correct"

    message = error_msg.lower()
    query = sql.lower()

    # Checked in order; the first substring found in the message wins.
    # NOTE(review): SQLite reports "ambiguous column name" when a column
    # exists in several tables that ARE joined, so "missing_join" may be a
    # misleading label. Confirm the intended taxonomy.
    rules = (
        ("syntax", "syntax_error"),
        ("no such table", "wrong_table"),
        ("no such column", "wrong_column"),
        ("ambiguous", "missing_join"),
        ("datatype mismatch", "type_error"),
    )
    for needle, label in rules:
        if needle in message:
            return label

    # Heuristic: a comparison operator appears but there is no WHERE clause.
    has_comparison = any(op in query for op in ("=", ">", "<"))
    if has_comparison and "where" not in query:
        return "missing_where"
    return "other"
# -------------------------------
# STEP 4: HINTS
# -------------------------------
# Repair hint for each error category produced by classify_error().
_ERROR_HINTS = {
    "missing_join": "Try using JOIN between related tables.",
    "wrong_column": "Check column names in schema.",
    "missing_where": "Add WHERE condition.",
    "syntax_error": "Fix SQL syntax.",
    "wrong_table": "Verify table names.",
    "type_error": "Check data types.",
    "other": "Review SQL logic.",
}


def generate_hint(error_type):
    """Return a short repair hint for ``error_type``.

    Unknown categories (including ``"correct"``) yield an empty string.
    """
    return _ERROR_HINTS.get(error_type, "")
# -------------------------------
# STEP 2 EXTRA: LIGHT ATTRIBUTION
# -------------------------------
def extract_keywords(question):
    """Return the lower-cased word tokens of ``question`` longer than 3 chars.

    Tokens are maximal runs of regex word characters. They keep their
    original order, and duplicates are preserved.
    """
    tokens = re.findall(r"\w+", question.lower())
    return list(filter(lambda tok: len(tok) > 3, tokens))
# -------------------------------
# MAIN
# -------------------------------
# Hugging Face id of the base checkpoint the LoRA adapter is applied to.
_BASE_MODEL = "Salesforce/codet5-base"

# Decoding settings shared by the dev-set loop and the adversarial probes, so
# both are reported under the same decoding strategy.
_GEN_KWARGS = dict(max_new_tokens=120, num_beams=8)

# Off-distribution questions, run against the first dev example's schema.
_ADVERSARIAL_QUESTIONS = (
    "Find most expensive product",
    "Top 3 students by marks",
    "Average salary per department",
)

# (substring searched in the lower-cased prediction, counter label)
_SQL_OPERATIONS = (
    ("select", "SELECT"),
    ("where", "WHERE"),
    ("join", "JOIN"),
    ("group by", "GROUP_BY"),
    ("order by", "ORDER_BY"),
)


def _pick_device():
    """Return the preferred torch device name: "mps", then "cuda", else "cpu"."""
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def _load_model(adapter_dir, device):
    """Load the base model, merge the LoRA adapter, and return ``(tokenizer, model)``.

    Raises:
        FileNotFoundError: if ``adapter_dir`` does not exist.
    """
    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
    tokenizer = AutoTokenizer.from_pretrained(_BASE_MODEL)
    base = AutoModelForSeq2SeqLM.from_pretrained(_BASE_MODEL).to(device)
    # Pass the absolute adapter path. A path relative to the project root is
    # resolved against the current working directory, so it only worked when
    # the script was launched from the project root.
    model = PeftModel.from_pretrained(
        base, str(adapter_dir), local_files_only=True
    ).to(device)
    model = model.merge_and_unload()
    model.eval()
    return tokenizer, model


def _generate_sql(tokenizer, model, question, db_id, device):
    """Generate and decode one SQL prediction for ``question`` on ``db_id``."""
    # Presumably encode_prompt returns an unbatched 1-D id tensor, since a
    # batch dimension is added here. TODO confirm against prompting.py.
    input_ids = encode_prompt(tokenizer, question, db_id, device=device)
    input_ids = input_ids.unsqueeze(0).to(device)
    # Pass an explicit attention mask so any padding is ignored and
    # generate() does not have to infer the mask.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = (input_ids != pad_id).long()
    out = model.generate(
        input_ids=input_ids, attention_mask=attention_mask, **_GEN_KWARGS
    )
    return tokenizer.decode(out[0], skip_special_tokens=True).strip()


def _count_operations(sql, counter):
    """Increment ``counter[label]`` for each clause keyword present in ``sql``."""
    lowered = sql.lower()
    for keyword, label in _SQL_OPERATIONS:
        if keyword in lowered:
            counter[label] += 1


def main():
    """Evaluate a LoRA-adapted CodeT5 text-to-SQL model on ``data/dev.json``.

    Prints the following:
    - exact-match (EM) and execution (EX) accuracy;
    - a breakdown of execution errors, with example predictions and hints;
    - question keywords per error type;
    - SQL clause counts;
    - model output for a few adversarial questions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=200)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    db_root = project_root / "data" / "database"
    dev_json = project_root / "data" / "dev.json"

    # JSON is UTF-8. Don't depend on the platform default encoding.
    with open(dev_json, encoding="utf-8") as f:
        dev = json.load(f)
    dev = dev[:args.num_samples]
    total = len(dev)
    if total == 0:
        print("No dev samples to evaluate.")
        return

    device = _pick_device()
    tokenizer, model = _load_model(project_root / args.adapter, device)

    # STORAGE
    error_counter = defaultdict(int)
    error_examples = defaultdict(list)
    success_examples = []
    hint_examples = defaultdict(list)
    operation_counter = defaultdict(int)
    attribution_map = defaultdict(list)
    em, ex = 0, 0

    print(f"\n๐Ÿš€ Evaluating {len(dev)} samples...\n")
    for i, sample in enumerate(dev, 1):
        db_id = sample["db_id"]
        q = sample["question"]
        gold = sample["query"]
        db_path = db_root / db_id / f"{db_id}.sqlite"

        pred = _generate_sql(tokenizer, model, q, db_id, device)
        _count_operations(pred, operation_counter)

        pred_res, err = execute_with_error(pred, db_path)
        gold_res, _ = execute_with_error(gold, db_path)

        # "correct" here means only that the prediction executed without an
        # error; result equality is measured separately by EX below.
        error_type = classify_error(pred, err)
        error_counter[error_type] += 1

        # attribution
        if err is not None:
            attribution_map[error_type].append(extract_keywords(q))

        # examples
        if len(error_examples[error_type]) < 3:
            error_examples[error_type].append(pred)

        # hints
        if error_type != "correct":
            hint = generate_hint(error_type)
            if len(hint_examples[error_type]) < 3:
                hint_examples[error_type].append((pred, hint))

        # metrics
        if normalize_sql(pred) == normalize_sql(gold):
            em += 1
        # An empty result set is a valid answer. Compare against None (query
        # failed) rather than truthiness, so correct predictions that return
        # no rows are counted.
        if (
            pred_res is not None
            and gold_res is not None
            and normalize_result(pred_res) == normalize_result(gold_res)
        ):
            ex += 1
            if len(success_examples) < 5:
                success_examples.append(pred)

        if i % 20 == 0:
            print(f"[{i}] EM: {em/i:.2f} | EX: {ex/i:.2f}")

    # -------------------------------
    # OUTPUT
    # -------------------------------
    print("\n๐ŸŽฏ FINAL RESULTS")
    print(f"EM: {em/total*100:.2f}%")
    print(f"EX: {ex/total*100:.2f}%")

    print("\n๐Ÿ”ฅ ERROR SUMMARY")
    for k, v in error_counter.items():
        print(k, ":", v)

    print("\n๐Ÿ”ฅ ERROR EXAMPLES")
    for k, examples in error_examples.items():
        print("\n", k)
        for e in examples:
            print(" ", e)

    # Previously collected but never reported.
    print("\nSUCCESS EXAMPLES")
    for sql in success_examples:
        print(" ", sql)

    print("\n๐Ÿ”ฅ HINTS")
    for k, pairs in hint_examples.items():
        print("\n", k)
        for sql, h in pairs:
            print(" ", sql)
            print(" โ†’", h)

    print("\n๐Ÿ”ฅ ATTRIBUTION (KEYWORDS)")
    for k, keywords in attribution_map.items():
        print(k, ":", keywords[:3])

    print("\n๐Ÿ”ฅ SQL OPERATIONS")
    for k, v in operation_counter.items():
        print(k, ":", v)

    # -------------------------------
    # ADVERSARIAL
    # -------------------------------
    # Uses the same decoding settings as the evaluation loop.
    print("\n๐Ÿ”ฅ ADVERSARIAL TESTS")
    adv_db_id = dev[0]["db_id"]
    for q in _ADVERSARIAL_QUESTIONS:
        sql = _generate_sql(tokenizer, model, q, adv_db_id, device)
        print("\nQ:", q)
        print("SQL:", sql)


if __name__ == "__main__":
    main()