| | import logging |
| | import os |
| | import pickle |
| | import time |
| | import json |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, NewType, Tuple |
| | from tqdm import tqdm |
| |
|
| |
|
| | import torch |
| | from torch.utils.data.dataset import Dataset |
| | from transformers.tokenization_utils import PreTrainedTokenizer |
| | from transformers.data.data_collator import DataCollator |
| | from transformers.tokenization_bart import BartTokenizer |
| | from transformers.tokenization_roberta import RobertaTokenizer |
| | from relogic.pretrainkit.datasets.utils import pad_and_tensorize_sequence |
| | logger = logging.getLogger(__name__) |
| |
|
# Keyword/label vocabulary shared by the dataset and collator below.
# Use a context manager so the file handle is closed promptly instead of
# relying on GC (the original `json.load(open(...))` leaked the handle).
with open("data/preprocessed_data/bart_parser_label_mapping_2.json", encoding="utf-8") as _label_mapping_file:
    label_mapping = json.load(_label_mapping_file)
| |
|
class QuerySchemaRelation2SQLDataset(Dataset):
    """
    Dataset for relation-aware text-to-SQL: query + schema + relation -> SQL.

    Reads a JSONL file where each line holds one example containing the
    normalized question, schema (column/table names and their surface text),
    the target SQL, schema-linking (`sc_link`) and cell-value-linking
    (`cv_link`) annotations. An example is skipped when its tokenized input
    exceeds ``block_size`` or its SQL contains a token outside the
    keyword + column vocabulary.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, local_rank=-1):
        """
        Args:
            tokenizer: tokenizer applied to both question and schema text.
            file_path: JSONL file with one example per line.
            block_size: maximum number of input sub-tokens; longer examples
                are dropped.
            local_rank: kept for interface compatibility; unused here.
        """
        self.examples = []
        self.keywords = label_mapping["keyword"]
        self.label_eos_id = self.keywords.index(label_mapping["label_eos_token"])
        self.label_bos_id = self.keywords.index(label_mapping["label_bos_token"])
        # BPE-based tokenizers (BART/RoBERTa) need a leading space so that a
        # word tokenized in isolation matches its in-sentence sub-tokens.
        add_prefix_space = isinstance(tokenizer, (BartTokenizer, RobertaTokenizer))
        total, valid = 0, 0
        with open(file_path, encoding="utf-8") as f:
            for line in tqdm(f):
                total += 1
                example = json.loads(line)
                text = example["normalized_question"]
                columns = example["columns"]
                tables = example["tables"]
                columns_text = example["column_text"]
                tables_text = example["table_text"]
                sql = example["sql"]

                # Tokenize the question word by word, recording each word's
                # first sub-token offset so linking indices can be remapped.
                # NOTE(review): offsets start at 0 even though position 0 of
                # `text_tokens` is the CLS token — downstream consumers appear
                # to expect this convention; confirm before changing.
                token_idx_to_sub_token_start_idx = {}
                text_tokens = [tokenizer.cls_token]
                start_idx = 0
                for idx, token in enumerate(text.split()):
                    sub_tokens = tokenizer.tokenize(token, add_prefix_space=add_prefix_space)
                    token_idx_to_sub_token_start_idx[idx] = start_idx
                    text_tokens.extend(sub_tokens)
                    start_idx += len(sub_tokens)
                text_tokens.append(tokenizer.sep_token)
                question_start, question_end = 1, len(text_tokens) - 1

                # Append each column's text followed by SEP, recording the
                # [start, end) sub-token span of every column.
                column_spans = []
                start_idx = len(text_tokens)
                for column_words in columns_text:
                    column_str = " ".join(column_words)
                    column_sub_tokens = tokenizer.tokenize(column_str, add_prefix_space=add_prefix_space)
                    text_tokens.extend(column_sub_tokens)
                    text_tokens.append(tokenizer.sep_token)
                    end_idx = start_idx + len(column_sub_tokens)
                    column_spans.append((start_idx, end_idx))
                    start_idx = end_idx + 1  # +1 skips the SEP token

                column_start = [span[0] for span in column_spans]
                column_end = [span[1] for span in column_spans]

                # Same layout for table names.
                table_spans = []
                start_idx = len(text_tokens)
                for table_words in tables_text:
                    table_str = " ".join(table_words)
                    table_sub_tokens = tokenizer.tokenize(table_str, add_prefix_space=add_prefix_space)
                    text_tokens.extend(table_sub_tokens)
                    text_tokens.append(tokenizer.sep_token)
                    end_idx = start_idx + len(table_sub_tokens)
                    table_spans.append((start_idx, end_idx))
                    start_idx = end_idx + 1  # +1 skips the SEP token

                table_start = [span[0] for span in table_spans]
                table_end = [span[1] for span in table_spans]

                input_ids = tokenizer.convert_tokens_to_ids(text_tokens)

                # Drop examples that do not fit in the model's block size.
                if len(input_ids) > block_size:
                    continue

                # Map every SQL token to either a column id (offset past the
                # keyword vocabulary) or a keyword id; drop the example when a
                # token belongs to neither vocabulary. Only `list.index` can
                # raise here, so catch ValueError — the original bare `except`
                # also swallowed KeyboardInterrupt/SystemExit.
                label_ids = []
                try:
                    for token in sql.split():
                        if token in columns:
                            label_ids.append(columns.index(token) + len(self.keywords))
                        else:
                            label_ids.append(self.keywords.index(token))
                except ValueError:
                    continue

                label_ids = [self.label_bos_id] + label_ids + [self.label_eos_id]

                primary_key = [int(x) for x in example["sc_struct"]["primary_key"]]
                foreign_key = {x.split(",")[0]: int(x.split(",")[1]) for x in example["sc_struct"]["foreign_key"]}
                # Column 0 maps to no table (presumably the "*" column —
                # the loop below skips "*" and never overwrites this entry).
                column_to_table = {"0": None}

                # Remap question-word indices in the linking annotations to
                # sub-token offsets ("<word_idx>,<col_or_tab_idx>" keys).
                sc_link = {"q_col_match": {}, "q_tab_match": {}}
                for k, v in example["sc_link"]["q_col_match"].items():
                    new_k = str(token_idx_to_sub_token_start_idx[int(k.split(",")[0])]) + "," + k.split(",")[1]
                    sc_link["q_col_match"][new_k] = v

                for k, v in example["sc_link"]["q_tab_match"].items():
                    new_k = str(token_idx_to_sub_token_start_idx[int(k.split(",")[0])]) + "," + k.split(",")[1]
                    sc_link["q_tab_match"][new_k] = v

                cv_link = {"num_date_match": {}, "cell_match": {}}
                for k, v in example["cv_link"]["num_date_match"].items():
                    new_k = str(token_idx_to_sub_token_start_idx[int(k.split(",")[0])]) + "," + k.split(",")[1]
                    cv_link["num_date_match"][new_k] = v
                for k, v in example["cv_link"]["cell_match"].items():
                    new_k = str(token_idx_to_sub_token_start_idx[int(k.split(",")[0])]) + "," + k.split(",")[1]
                    cv_link["cell_match"][new_k] = v

                # Columns are named "table.column"; map each to its table id.
                for idx, column in enumerate(columns):
                    if column == "*":
                        continue
                    t = column.split(".")[0]
                    column_to_table[str(idx)] = tables.index(t)

                # Lift column-level foreign keys to table-level adjacency.
                foreign_keys_tables = {}
                for k, v in foreign_key.items():
                    t_k = str(column_to_table[str(k)])
                    t_v = str(column_to_table[str(v)])
                    if t_k not in foreign_keys_tables:
                        foreign_keys_tables[t_k] = []
                    if int(t_v) not in foreign_keys_tables[t_k]:
                        foreign_keys_tables[t_k].append(int(t_v))

                self.examples.append({
                    "input_ids": input_ids,
                    "example_info": {
                        "normalized_question": text,
                        "columns": columns,
                        "tables": tables,
                        "tokens": text_tokens,
                        "question_start": question_start,
                        "question_end": question_end,
                        "column_start": torch.LongTensor(column_start),
                        "column_end": torch.LongTensor(column_end),
                        "table_start": torch.LongTensor(table_start),
                        "table_end": torch.LongTensor(table_end),
                        "sc_link": sc_link,
                        "cv_link": cv_link,
                        "primary_keys": primary_key,
                        "foreign_keys": foreign_key,
                        "column_to_table": column_to_table,
                        "foreign_keys_tables": foreign_keys_tables
                    },
                    "column_spans": column_spans,
                    "label_ids": label_ids})
                valid += 1
        # Use the module logger (lazy %-args) instead of print.
        logger.info("Valid Example %d; Invalid Example %d", valid, total - valid)

    def __len__(self):
        """Number of retained (valid) examples."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return the pre-built example dict at index ``i``."""
        return self.examples[i]
| |
|
| |
|
@dataclass
class DataCollatorForQuerySchemaRelation2SQL:
    """
    Data collator used for query + schema -> sql modeling.

    Pads ``input_ids``, ``column_spans`` and ``label_ids`` across a batch
    and passes per-example metadata through untouched.
    """
    tokenizer: PreTrainedTokenizer
    # Deliberately unannotated: these are shared class-level constants
    # computed once from `label_mapping`, not dataclass fields.
    label_padding_id = label_mapping["keyword"].index(label_mapping["label_padding_token"])
    label_eos_id = label_mapping["keyword"].index(label_mapping["label_eos_token"])
    label_bos_id = label_mapping["keyword"].index(label_mapping["label_bos_token"])

    def collate_batch(self, examples) -> Dict[str, torch.Tensor]:
        """Collate a list of dataset example dicts into a padded batch.

        Args:
            examples: items produced by ``QuerySchemaRelation2SQLDataset``,
                each with ``input_ids``, ``column_spans``, ``label_ids`` and
                ``example_info`` entries.

        Returns:
            Dict with padded ``input_ids``/``column_spans``/``labels``
            tensors, the raw ``example_info_list``, and the padding/BOS/EOS
            ids the model needs for masking and decoding.
        """
        input_ids_sequences = [example["input_ids"] for example in examples]
        column_spans_sequences = [example["column_spans"] for example in examples]
        label_ids_sequences = [example["label_ids"] for example in examples]
        padded_input_ids_tensor = pad_and_tensorize_sequence(
            input_ids_sequences, padding_value=self.tokenizer.pad_token_id)
        # (0, 1) pads span lists — presumably a dummy span over the first
        # token; confirm against the span-pooling consumer.
        padded_column_spans_tensor = pad_and_tensorize_sequence(
            column_spans_sequences, padding_value=(0, 1))

        # Metadata is forwarded as-is (comprehension replaces the original
        # append loop; behavior unchanged).
        example_info_list = [example["example_info"] for example in examples]
        label_ids_tensor = pad_and_tensorize_sequence(
            label_ids_sequences, padding_value=self.label_padding_id)
        return {
            "input_ids": padded_input_ids_tensor,
            "column_spans": padded_column_spans_tensor,
            "labels": label_ids_tensor,
            "example_info_list": example_info_list,
            "input_padding_id": self.tokenizer.pad_token_id,
            "label_padding_id": self.label_padding_id,
            "label_eos_id": self.label_eos_id,
            "label_bos_id": self.label_bos_id
        }
| |
|
| |
|