# gap-text2sql/gap-text2sql-main/mrat-sql-gap/seq2struct/datasets/spider_lib/preprocess/parse_raw_json.py
"""Annotate raw examples with the parsed form of their SQL query.

Reads a JSON list of examples (each carrying a ``db_id`` and a ``query``),
parses every query with ``get_sql`` against the schema of its database, stores
the result under the example's ``sql`` key, and writes the annotated list to
the output file as pretty-printed JSON.

Usage:
    python parse_raw_json.py --input RAW.json --tables TABLES.json --output OUT.json
"""
import os
import sys
import json
import sqlite3
import traceback
import argparse

import tqdm

from seq2struct.datasets.spider_lib.process_sql import get_sql
from seq2struct.datasets.spider_lib.preprocess.schema import Schema, get_schemas_from_json


def main():
    """Parse the command line, annotate every example, and write the result.

    Raises:
        Exception: whatever the lookup or ``get_sql`` raised for an example is
            re-raised unchanged, after the offending ``db_id`` and query have
            been printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', required=True)
    parser.add_argument('--tables', required=True)
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    # The second element of the returned triple is not needed here.
    schemas, _db_names, tables = get_schemas_from_json(args.tables)

    with open(args.input, encoding='utf8') as inf:
        sql_data = json.load(inf)

    sql_data_new = []
    for data in tqdm.tqdm(sql_data):
        # Reset per example: the failure report below must never show values
        # left over from the previous iteration, nor hit a NameError (which
        # would mask the real error) when the very first lookup fails.
        db_id = None
        sql = None
        try:
            db_id = data["db_id"]
            schema = Schema(schemas[db_id], tables[db_id])
            sql = data["query"]
            data["sql"] = get_sql(schema, sql)
        except Exception:
            # Identify the offending example, then let the error propagate.
            print("db_id: ", db_id)
            print("sql: ", sql)
            raise
        sql_data_new.append(data)

    with open(args.output, 'wt', encoding='utf8') as out:
        json.dump(
            sql_data_new,
            out,
            sort_keys=True,
            indent=4,
            separators=(',', ': '),
            ensure_ascii=False,
        )


if __name__ == '__main__':
    main()