# PPO fine-tuning of google/flan-t5-small for text-to-SQL (Spider) with LoRA
# adapters and an execution-based reward. Written to run on Apple MPS or CPU.

from execution_reward import execution_reward
import os, gc, json, random, torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from peft import LoraConfig, get_peft_model
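
# `execution_reward` is expected to run both the predicted and the gold query
# against the example's SQLite database and score the match. A minimal sketch
# of that contract (illustrative only; the real implementation lives in
# execution_reward.py):
#
#   def execution_reward(pred_sql, gold_sql, db_path):
#       import sqlite3
#       try:
#           with sqlite3.connect(db_path) as conn:
#               pred = conn.execute(pred_sql).fetchall()
#               gold = conn.execute(gold_sql).fetchall()
#           return 1.0 if set(pred) == set(gold) else 0.0
#       except Exception:
#           return -1.0  # penalize SQL that fails to execute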

# Environment setup: fall back to CPU for ops MPS doesn't implement, and
# silence the tokenizers fork-parallelism warning
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Prefer Apple's Metal backend when available
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

# Output directory for the LoRA adapter checkpoint
os.makedirs("rlhf_text2sql_lora", exist_ok=True)

# Load tokenizer and base model
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# T5 tokenizers already define <pad>; only fall back to EOS if one is missing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# LoRA adapters on T5's attention projections (T5 names these modules "q" and
# "v" inside each attention block)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)

base_model = get_peft_model(base_model, lora_config)
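
# Sanity check: LoRA should leave only a small fraction of parameters
# trainable (print_trainable_parameters comes with peft's PeftModel)
base_model.print_trainable_parameters()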

# Wrap the LoRA policy with a value head for PPO. The reference model is a
# frozen copy of the original weights, used for the KL penalty.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(base_model).to(device)
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name).to(device)

# Disable KV caching during training
model.config.use_cache = False
ref_model.config.use_cache = False

# Load the Spider training split. Each record provides at least "question"
# (natural language), "query" (gold SQL), and "db_id", which names the SQLite
# database used for the execution reward below.
with open("data/train_spider.json") as f:
    dataset = json.load(f)


def build_prompt(example):
    return f"Translate to SQL: {example['question']}"
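
# For example (illustrative input, not taken from the dataset):
#   build_prompt({"question": "How many singers do we have?"})
#   -> "Translate to SQL: How many singers do we have?"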

# PPO configuration: single-example batches keep memory manageable on MPS.
# adap_kl_ctrl / init_kl_coef control the adaptive KL penalty that keeps the
# policy near the reference model; target_kl is trl's early-stopping
# threshold (only consulted when early_stopping is enabled).
ppo_config = PPOConfig(
    batch_size=1,
    mini_batch_size=1,
    learning_rate=2e-6,
    target_kl=0.05,
    adap_kl_ctrl=True,
    init_kl_coef=0.2,
    max_grad_norm=1.0,  # gradient clipping happens inside step()
)

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)
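
# Note: when the policy is a PEFT model, trl can also take ref_model=None and
# recover the reference behavior by disabling the adapters; the explicit
# ref_model above trades extra memory for clarity. (Per trl's PPOTrainer
# docs; verify against your installed version.)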

def generate_sql(query_tensors):
    # Generate candidate SQL without tracking gradients; PPOTrainer recomputes
    # the log-probabilities it needs inside step()
    with torch.no_grad():
        response_tensors = ppo_trainer.generate(
            query_tensors,
            max_new_tokens=64,
            # Greedy decoding keeps rollouts deterministic and cheap
            # (early_stopping was dropped: it only affects beam search)
            do_sample=False,
            num_beams=1,
            # Pad finished sequences with EOS
            pad_token_id=tokenizer.eos_token_id,
        )

    # Defensive cleanup: a no-op for integer token ids, but guards against
    # non-finite values if generation ever yields floats on MPS
    cleaned = []
    for t in response_tensors:
        cleaned.append(torch.nan_to_num(t, nan=0, posinf=0, neginf=0))

    return cleaned

# Training loop: one example per PPO step
MAX_STEPS = 1200

for step in range(MAX_STEPS):
    # Pick a random Spider example for this step
    example = random.choice(dataset)

    question = example["question"]
    gold_sql = example["query"]
    db_id = example["db_id"]
    db_path = f"data/database/{db_id}/{db_id}.sqlite"

    # Tokenize the prompt and move it to the training device
    enc = tokenizer(build_prompt(example), return_tensors="pt")
    query_tensor = enc.input_ids.to(device)
    query_tensors = [query_tensor[0]]

    # Roll out a candidate SQL query with the current policy
    response_tensors = generate_sql(query_tensors)
    pred_sql = tokenizer.decode(response_tensors[0], skip_special_tokens=True)

    # Execution-based reward: run predicted and gold SQL against this
    # example's database and compare the results
    reward = execution_reward(pred_sql, gold_sql, db_path)
    reward_tensor = torch.tensor([reward], dtype=torch.float32).to(device)

    # One PPO update on the (query, response, reward) triple; trl accepts a
    # list of per-sample reward tensors
    stats = ppo_trainer.step(query_tensors, response_tensors, [reward_tensor])

    # Gradient clipping is handled by trl via PPOConfig.max_grad_norm above;
    # clipping here, after step() has already applied the optimizer update,
    # was a no-op and has been removed.

    # Release per-step tensors to keep MPS memory in check
    del query_tensor, response_tensors, reward_tensor
    gc.collect()
    if device == "mps":
        torch.mps.empty_cache()

    # Periodic progress logging
    if step % 20 == 0:
        print(f"\nStep {step}/{MAX_STEPS}")
        print("DB:", db_id)
        print("Q:", question)
        print("Pred:", pred_sql)
        print("Gold:", gold_sql)
        print("Reward:", reward)
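        # Optional: surface the KL penalty trl tracks. The stats key name
        # ("objective/kl") is taken from trl's PPOTrainer logs and may vary
        # across versions, hence the defensive .get().
        print("KL:", stats.get("objective/kl"))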

# Save the LoRA adapter and tokenizer
model.save_pretrained("rlhf_text2sql_lora")
tokenizer.save_pretrained("rlhf_text2sql_lora")

print("\nTraining complete. Model saved!")
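
# Quick smoke test with the freshly trained policy (the question below is a
# made-up example, not drawn from Spider)
demo = {"question": "How many departments are there?"}
demo_enc = tokenizer(build_prompt(demo), return_tensors="pt").to(device)
with torch.no_grad():
    demo_out = model.generate(input_ids=demo_enc.input_ids, max_new_tokens=64)
print("Sample prediction:", tokenizer.decode(demo_out[0], skip_special_tokens=True))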