---
license: mit
datasets:
- b-mc2/sql-create-context
- gretelai/synthetic_text_to_sql
language:
- en
base_model: google-t5/t5-base
metrics:
- exact_match
model-index:
- name: juanfra218/text2sql
  results:
  - task:
      type: text-to-sql
    metrics:
    - name: exact_match
      type: exact_match
      value: 0.4326836917562724
    - name: bleu
      type: bleu
      value: 0.6687
tags:
- sql
library_name: transformers
---

# Fine-Tuned Google T5 Model for Text to SQL Translation

A fine-tuned version of the Google T5 model, trained to translate natural language queries into SQL statements.

## Model Details

- **Architecture**: Google T5 Base (Text-to-Text Transfer Transformer)
- **Task**: Text to SQL Translation
- **Fine-Tuning Datasets**:
  - [sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context)
  - [synthetic_text_to_sql](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)
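
Both datasets pair a natural language question with a schema (`CREATE TABLE` context) and a target SQL query. As a rough illustration, the model input is built by prefixing the question and schema with a task string, as in the Usage snippet further down. The record below is only in the style of `sql-create-context` (fields `question`, `context`, `answer`); treat the exact values as an assumption, not a quoted row:

```python
# Hypothetical record in the style of b-mc2/sql-create-context
example = {
    "question": "How many heads of the departments are older than 56?",
    "context": "CREATE TABLE head (age INTEGER)",
    "answer": "SELECT COUNT(*) FROM head WHERE age > 56",
}

# Model input: task prefix + question + schema (the format the Usage code uses)
input_text = "translate English to SQL: " + example["question"] + " " + example["context"]
print(input_text)
```

The `answer` field is what the model is trained to emit for this input.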

## Training Parameters

```python
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
)
```

## Usage

```python
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model
model_path = 'juanfra218/text2sql'
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.to(device)

# Function to generate SQL queries
def generate_sql(prompt, schema):
    input_text = "translate English to SQL: " + prompt + " " + schema
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
    inputs = {key: value.to(device) for key, value in inputs.items()}

    outputs = model.generate(**inputs, max_length=1024)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Interactive loop
print("Enter 'quit' to exit.")
while True:
    prompt = input("Insert prompt: ")
    if prompt.lower() == 'quit':
        break
    schema = input("Insert schema: ")

    sql_query = generate_sql(prompt, schema)
    print(f"Generated SQL query: {sql_query}")
    print()
```

## Files

- `optimizer.pt`: State of the optimizer.
- `training_args.bin`: Training arguments and hyperparameters.
- `tokenizer.json`: Tokenizer vocabulary and settings.
- `spiece.model`: SentencePiece model file.
- `special_tokens_map.json`: Special tokens mapping.
- `tokenizer_config.json`: Tokenizer configuration settings.
- `model.safetensors`: Trained model weights.
- `generation_config.json`: Configuration for text generation.
- `config.json`: Model architecture configuration.
- `test_results.csv`: Results on the test set, with columns: prompt, context, true_answer, predicted_answer, exact_match.
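
As a sketch of how per-row results in `test_results.csv` can be aggregated into an exact-match score, assuming a simple whitespace-normalised string comparison (the exact normalisation behind the reported 0.4327 is not documented here, so this is an approximation):

```python
import csv

def exact_match(pred: str, truth: str) -> bool:
    # Whitespace-normalised comparison; this is an assumption, the actual
    # scoring may normalise case or punctuation differently.
    return " ".join(pred.split()) == " ".join(truth.split())

def score_file(path: str) -> float:
    # Expects the test_results.csv columns listed above.
    with open(path, newline="") as f:
        rows = list(csv.DictReader(f))
    return sum(exact_match(r["predicted_answer"], r["true_answer"]) for r in rows) / len(rows)
```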