import json
import os
import sqlite3

from datasets import Dataset
from transformers import T5Tokenizer

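# Resolve the project root: the script is assumed to live one directory
# below it, with the Spider data under <root>/data/.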
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

TRAIN_JSON = os.path.join(BASE_DIR, "data", "train_spider.json")
DEV_JSON = os.path.join(BASE_DIR, "data", "dev.json")
DB_FOLDER = os.path.join(BASE_DIR, "data", "database")

SAVE_TRAIN = os.path.join(BASE_DIR, "data", "tokenized", "train")
SAVE_DEV = os.path.join(BASE_DIR, "data", "tokenized", "validation")

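# Make sure the shared output directory exists before either split is saved.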
os.makedirs(os.path.dirname(SAVE_TRAIN), exist_ok=True)

print("Project root:", BASE_DIR)
print("Train file:", TRAIN_JSON)
print("Database folder:", DB_FOLDER)

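# One tokenizer is shared by the input prompts and the SQL targets.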
tokenizer = T5Tokenizer.from_pretrained("t5-small")


def get_schema(db_path):
    """Serialize every table in a SQLite database as `name(col1, col2, ...)`."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Only user tables; SQLite's internal sqlite_* bookkeeping tables
    # (e.g. sqlite_sequence) would otherwise leak into the prompt.
    tables = cursor.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
    ).fetchall()

    schema_text = []

    for (table,) in tables:
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk);
        # index 1 is the column name. The table name is quoted in case it
        # collides with an SQL keyword.
        columns = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
        col_names = [c[1] for c in columns]

        schema_text.append(f"{table}({', '.join(col_names)})")

    conn.close()
    return "\n".join(schema_text)
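
# get_schema produces one line per table, for example (illustrative names,
# not a specific Spider database):
#   singer(Singer_ID, Name, Country)
#   concert(Concert_ID, Name, Year)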


def build_examples(spider_json):
    """Pair each Spider question (with its database schema) with the gold SQL."""
    print(f"\nBuilding dataset from: {spider_json}")

    with open(spider_json) as f:
        data = json.load(f)

    inputs = []
    outputs = []

    for ex in data:
        question = ex["question"]
        sql = ex["query"]
        db_id = ex["db_id"]

        # Each Spider example references a database stored as
        # data/database/<db_id>/<db_id>.sqlite.
        db_path = os.path.join(DB_FOLDER, db_id, f"{db_id}.sqlite")

        # Skip examples whose database file is missing locally.
        if not os.path.exists(db_path):
            continue

        schema = get_schema(db_path)

        # Prompt layout: schema first, then the question, then a trailing
        # "SQL:" cue for the decoder.
        input_text = f"""Database Schema:
{schema}

Translate English to SQL:
{question}
SQL:
"""

        inputs.append(input_text)
        outputs.append(sql)

    return Dataset.from_dict({"input": inputs, "output": outputs})
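
# Note: the full schema is embedded in every prompt, so wide databases can
# push an input past the 512-token limit used below, silently truncating it.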


def tokenize(example):
    """Tokenize one (input, output) pair into T5 features."""
    model_input = tokenizer(
        example["input"],
        max_length=512,
        padding="max_length",
        truncation=True,
    )

    label = tokenizer(
        example["output"],
        max_length=256,
        padding="max_length",
        truncation=True,
    )

    # Mask the padding in the labels with -100 so the cross-entropy loss
    # ignores it; otherwise the model is trained to emit pad tokens.
    model_input["labels"] = [
        (tok if tok != tokenizer.pad_token_id else -100)
        for tok in label["input_ids"]
    ]
    return model_input
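
# tokenizer() also accepts lists of strings, so tokenize could be rewritten
# for batched=True mapping (the label masking would then loop per sequence);
# batched=False is kept here for simplicity.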


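# Build, tokenize, and save each split.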
| print("\nBuilding TRAIN dataset...") |
| train_dataset = build_examples(TRAIN_JSON) |
|
|
| print("Tokenizing TRAIN dataset...") |
| tokenized_train = train_dataset.map(tokenize, batched=False) |
|
|
| print("Saving TRAIN dataset...") |
| tokenized_train.save_to_disk(SAVE_TRAIN) |
|
|
|
|
| print("\nBuilding VALIDATION dataset...") |
| val_dataset = build_examples(DEV_JSON) |
|
|
| print("Tokenizing VALIDATION dataset...") |
| tokenized_val = val_dataset.map(tokenize, batched=False) |
|
|
| print("Saving VALIDATION dataset...") |
| tokenized_val.save_to_disk(SAVE_DEV) |
|
|
| print("\nDONE ✔ Dataset prepared successfully!") |
| print("Train saved at:", SAVE_TRAIN) |
| print("Validation saved at:", SAVE_DEV) |