| |
| |
| |
|
|
| import torch |
| import torch.nn.functional as F |
| from datasets import load_dataset |
| from transformers import AutoTokenizer |
| from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList |
| from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead |
| from peft import PeftModel |
| import os, sys, sqlite3, re, random |
|
|
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
| from execution_reward import execution_reward, extract_tables, extract_columns |
| try: |
| import sqlparse |
| except Exception: |
| sqlparse = None |
|
|
|
|
| |
| |
| |
# Silence the HF tokenizers fork-parallelism warning before any tokenization happens.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Device preference: Apple MPS, then CUDA, then CPU.
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------
# Training hyperparameters.
# ---------------------------------------------------------------
NUM_EPOCHS = 5              # full passes of the rollout loop
LOG_EVERY = 20              # steps between sample logs
USE_SCHEMA = True           # include the DB schema text in prompts
SCHEMA_WARMUP_EPOCHS = 0    # epochs to run without schema first (0 = schema from the start)
MAX_SCHEMA_CHARS = 1500     # character cap applied to schema text before tokenization
MAX_OUTPUT_TOKENS = 80      # hard cap on the generated response length (tokens)
ROLLOUTS_PER_EPOCH = 2048   # sampled examples ("steps") per epoch
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------
# Project paths, all derived from this file's location.
# All paths are built from individual components (no hard-coded "/"
# inside strings) so they resolve correctly on every OS.
# ---------------------------------------------------------------
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# PPO checkpoints are written here; doubles as the "best model" directory.
RL_MODEL_PATH = os.path.join(PROJECT_ROOT, "checkpoints", "rlhf_t5_best")
output_dir = RL_MODEL_PATH

# Root folder holding one sub-directory per Spider SQLite database.
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")

# Preferred SFT LoRA adapter checkpoint; two fallbacks are probed later.
ADAPTER_PATH = os.path.join(PROJECT_ROOT, "checkpoints", "sft_t5")

FALLBACK_ADAPTER_PATH = os.path.join(PROJECT_ROOT, "models", "t5_spider_sft_lora")
FALLBACK_ADAPTER_PATH_2 = os.path.join(PROJECT_ROOT, "outputs", "sft_text2sql")

# Base HF model name; override via the BASE_MODEL environment variable.
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
|
|
|
|
| |
| |
| |
| print("Loading base:", BASE_MODEL) |
| if not os.path.isdir(ADAPTER_PATH): |
| if os.path.isdir(FALLBACK_ADAPTER_PATH): |
| ADAPTER_PATH = FALLBACK_ADAPTER_PATH |
| elif os.path.isdir(FALLBACK_ADAPTER_PATH_2): |
| ADAPTER_PATH = FALLBACK_ADAPTER_PATH_2 |
| print("Loading adapters:", ADAPTER_PATH) |
|
|
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) |
|
|
| model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device) |
| model.pretrained_model = PeftModel.from_pretrained(model.pretrained_model, ADAPTER_PATH) |
|
|
| ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device) |
| ref_model.pretrained_model = PeftModel.from_pretrained(ref_model.pretrained_model, ADAPTER_PATH) |
|
|
| ref_model.eval() |
| for p in ref_model.parameters(): |
| p.requires_grad_(False) |
|
|
| |
| for name, p in model.named_parameters(): |
| |
| if name.startswith("v_head"): |
| p.requires_grad = True |
| |
| elif "lora_" in name: |
| p.requires_grad = True |
| |
| else: |
| p.requires_grad = False |
|
|
| trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| total = sum(p.numel() for p in model.parameters()) |
| print(f"Trainable params: {trainable}/{total} ({100*trainable/total:.2f}%)") |
|
|
| model.config.use_cache = False |
| ref_model.config.use_cache = False |
|
|
| if tokenizer.pad_token_id is None: |
| tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
| |
| |
| |
| print("Loading Spider subset...") |
| random.seed(0) |
|
|
| |
| TRAIN_DBS = [ |
| "flight_1", |
| "student_assessment", |
| "store_1", |
| "bike_1", |
| "book_2", |
| "chinook_1", |
| ] |
|
|
| dataset = load_dataset("spider", split="train") |
| _TRAIN_DBS_SET = set(TRAIN_DBS) |
| dataset = dataset.filter(lambda x: x["db_id"] in _TRAIN_DBS_SET) |
| dataset = dataset.select(range(min(800, len(dataset)))) |
|
|
| print("Using RLHF DBs:", TRAIN_DBS) |
| print("Filtered size:", len(dataset)) |
|
|
| total_steps = ROLLOUTS_PER_EPOCH |
|
|
| |
| |
| |
def get_db_path(db_id):
    """Resolve the on-disk SQLite file for a Spider database id."""
    sqlite_name = db_id + ".sqlite"
    return os.path.join(DB_ROOT, db_id, sqlite_name)
|
|
|
|
def get_db_schema(db_path):
    """Return a compact one-line schema string for a SQLite database.

    Produces one ``"table(col1, col2) "`` chunk per table (note the trailing
    space), in ``sqlite_master`` order. Returns ``""`` (or whatever was
    collected so far) if the database cannot be read; the connection is
    always closed.
    """
    schema_text = ""
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        for (table_name,) in tables:
            # Quote the identifier so unusual table names cannot break the PRAGMA.
            columns = cursor.execute(f'PRAGMA table_info("{table_name}");').fetchall()
            col_names = [col[1] for col in columns]
            schema_text += f"{table_name}({', '.join(col_names)}) "
    except Exception:
        # Missing/corrupt database: best effort, fall through with what we have.
        pass
    finally:
        # The original closed the connection only on the success path, leaking
        # it whenever an execute raised mid-way; always close here.
        if conn is not None:
            conn.close()

    return schema_text
|
|
|
|
| |
| |
| |
| PREFIX = "translate English to SQL:" |
|
|
|
|
| def trim_schema(schema: str, max_chars: int = 1200) -> str: |
| if schema is None: |
| return "" |
| schema = str(schema) |
| if len(schema) <= max_chars: |
| return schema |
| return schema[:max_chars] |
| def build_prompt(question: str, schema: str, use_schema: bool) -> str: |
| if not use_schema: |
| return f"{PREFIX}\n\nQuestion:\n{question}\n\nSQL:" |
|
|
| schema = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS) |
| return f"{PREFIX}\n\nSchema:\n{schema}\n\nQuestion:\n{question}\n\nSQL:" |
|
|
def encode_prompt(question, schema, use_schema):
    """Tokenize the prompt, truncating the schema (never the question) to fit.

    Without a schema the plain prompt is tokenized with standard truncation.
    With a schema, the prefix / question / suffix segments are encoded
    separately and only the schema segment is clipped, so the full prompt
    stays within the tokenizer's model_max_length while one slot is reserved
    for EOS. Returns a 1-D LongTensor on the module-level ``device``.
    """
    if not use_schema:
        prompt = build_prompt(question, schema, use_schema=False)
        return tokenizer(prompt, return_tensors="pt", truncation=True).input_ids[0].to(device)


    schema = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
    prefix_schema = f"{PREFIX}\n\nSchema:\n"
    mid = "\n\nQuestion:\n"
    suffix = f"{question}\n\nSQL:"


    # Encode each segment separately so truncation can target the schema only.
    prefix_ids = tokenizer.encode(prefix_schema, add_special_tokens=False)
    schema_ids = tokenizer.encode(schema, add_special_tokens=False)
    mid_ids = tokenizer.encode(mid, add_special_tokens=False)
    suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)


    max_len = getattr(tokenizer, "model_max_length", 512)
    eos_id = tokenizer.eos_token_id
    # Reserve one position for the EOS token appended at the end.
    max_without_eos = max_len - (1 if eos_id is not None else 0)


    # If even the fixed parts overflow, trim the question/suffix itself.
    fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
    if fixed_len > max_without_eos:

        keep = max(0, max_without_eos - (len(prefix_ids) + len(mid_ids)))
        suffix_ids = suffix_ids[:keep]
        fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)


    # Whatever budget remains goes to the schema.
    remaining_for_schema = max_without_eos - fixed_len
    if remaining_for_schema < 0:
        remaining_for_schema = 0
    schema_ids = schema_ids[:remaining_for_schema]


    ids = prefix_ids + schema_ids + mid_ids + suffix_ids
    ids = ids[:max_without_eos]  # defensive final clamp
    if eos_id is not None:
        ids = ids + [eos_id]


    return torch.tensor(ids, dtype=torch.long).to(device)
|
|
|
|
| |
| |
| |
# SQL keywords whose tokenizations are always allowed during constrained decoding.
SQL_KEYWORDS = [
    "select", "from", "where", "join", "inner", "left", "right",
    "full", "outer", "on", "group", "by", "order", "having",
    "limit", "distinct", "as", "and", "or", "not", "in", "is",
    "null", "like", "between", "asc", "desc", "union",
    "intersect", "except",
]


# Operator/punctuation characters; vocab pieces built solely from these are allowed.
SQL_OPERATORS = ["*", ",", ".", "(", ")", "=", "<", ">", "!", "+", "-", "/", "%", "_"]
|
|
|
|
| def _piece_token_str(tok: str) -> str: |
| |
| return tok.lstrip("▁") |
|
|
|
|
def _precompute_always_allowed_token_ids():
    """Build the set of token ids that are always legal during constrained decoding.

    Includes: special tokens (pad/eos/unk), whitespace pieces, vocab pieces
    made only of SQL operator characters, digit-only pieces, and every
    tokenization variant of the SQL keywords (lower/UPPER/Capitalized, with
    and without a leading space). Reads the module-level ``tokenizer``.
    """
    vocab_size = len(tokenizer)
    allowed = set()

    # Special tokens must never be masked out.
    for tid in [tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.unk_token_id]:
        if tid is not None and tid >= 0:
            allowed.add(int(tid))

    # Plain whitespace.
    for s in [" ", "\n", "\t"]:
        allowed.update(tokenizer.encode(s, add_special_tokens=False))

    # Single pass over the vocabulary for operator-only and digit-only pieces.
    op_chars = set("".join(SQL_OPERATORS))
    for tid in range(vocab_size):
        tok = tokenizer.convert_ids_to_tokens(tid)
        if not isinstance(tok, str) or not tok:
            continue
        piece = _piece_token_str(tok)
        if not piece:
            continue
        if all((ch in op_chars) for ch in piece):
            allowed.add(tid)
        elif piece.isdigit():
            # str.isdigit() on a non-empty string already means every character
            # is a digit, so the original's extra all(ch.isdigit()) re-check
            # was dead code and has been removed.
            allowed.add(tid)

    # SQL keywords in the casings the model is likely to emit.
    for kw in SQL_KEYWORDS:
        for variant in {kw, kw.upper(), kw.capitalize()}:
            allowed.update(tokenizer.encode(" " + variant, add_special_tokens=False))
            allowed.update(tokenizer.encode(variant, add_special_tokens=False))

    return allowed


# Computed once at import time; per-database identifier ids are layered on top.
ALWAYS_ALLOWED_TOKEN_IDS = _precompute_always_allowed_token_ids()
|
|
|
|
def _schema_allowed_token_ids(table_names, column_names):
    """Extend the always-allowed token ids with one database's identifiers.

    Every table/column name contributes its original, lower- and upper-case
    forms plus its underscore/whitespace-separated parts, each encoded both
    with and without a leading space.
    """
    allowed = set(ALWAYS_ALLOWED_TOKEN_IDS)

    def _collect(identifier: str):
        if not identifier:
            return
        # Casing variants plus the individual name fragments.
        forms = {identifier, identifier.lower(), identifier.upper()}
        forms.update(part for part in re.split(r"[_\s]+", identifier) if part)
        for form in forms:
            for text in (" " + form, form):
                allowed.update(tokenizer.encode(text, add_special_tokens=False))

    for name in list(table_names) + list(column_names):
        _collect(name)

    return allowed
|
|
|
|
class SQLVocabularyLogitsProcessor(LogitsProcessor):
    """Masks generation logits so only whitelisted token ids stay finite.

    An additive bias (0 for allowed ids, -inf otherwise) is built lazily for
    the first scores tensor seen and rebuilt whenever the device, dtype, or
    vocabulary size changes.
    """

    def __init__(self, allowed_token_ids):
        # Keep only non-negative ids, normalized to int.
        self.allowed_token_ids = {int(i) for i in allowed_token_ids if int(i) >= 0}
        self._bias = None
        self._bias_vocab_size = None

    def _get_bias(self, scores: torch.Tensor) -> torch.Tensor:
        """Return the cached additive bias, rebuilding it if stale."""
        vocab_size = int(scores.shape[-1])
        cache_stale = (
            self._bias is None
            or self._bias.device != scores.device
            or self._bias.dtype != scores.dtype
            or self._bias_vocab_size != vocab_size
        )
        if cache_stale:
            # new_full inherits device and dtype from `scores`.
            bias = scores.new_full((vocab_size,), float("-inf"))
            for tid in self.allowed_token_ids:
                if tid < vocab_size:
                    bias[tid] = 0.0
            self._bias = bias
            self._bias_vocab_size = vocab_size
        return self._bias

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Allowed ids get +0, everything else -inf.
        return scores + self._get_bias(scores)
|
|
|
|
# db_path -> (table_names, column_names), filled lazily.
_DB_VOCAB_CACHE = {}


def get_db_tables_columns(db_path: str):
    """Return ``(table_names, column_names)`` for a SQLite database.

    Results are cached per path in ``_DB_VOCAB_CACHE``. Internal
    ``sqlite_*`` tables are skipped. On any database error the (possibly
    partial or empty) lists collected so far are returned; the connection
    is always closed.
    """
    if db_path in _DB_VOCAB_CACHE:
        return _DB_VOCAB_CACHE[db_path]
    tables, cols = [], []
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        for (tname,) in cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        ).fetchall():
            if not tname:
                continue
            tables.append(tname)
            try:
                # Quoted identifier so odd table names cannot break the PRAGMA.
                for row in cur.execute(f'PRAGMA table_info("{tname}")').fetchall():
                    if row and isinstance(row[1], str):
                        cols.append(row[1])
            except Exception:
                continue
    except Exception:
        # Best effort: unreadable DB yields whatever was collected so far.
        pass
    finally:
        # The original only closed on the success path, leaking the
        # connection whenever an execute raised; always close here.
        if conn is not None:
            conn.close()
    _DB_VOCAB_CACHE[db_path] = (tables, cols)
    return tables, cols
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------
# PPO configuration and trainer.
# ---------------------------------------------------------------
ppo_config = PPOConfig(
    learning_rate=2e-5,
    batch_size=8,                    # rollouts accumulated per PPO update
    mini_batch_size=2,
    gradient_accumulation_steps=2,
    ppo_epochs=1,


    # Adaptive KL control keeps the policy close to the frozen reference.
    init_kl_coef=0.05,
    target_kl=0.15,
    adap_kl_ctrl=True,


    # Conservative clipping plus reward whitening for stability.
    cliprange=0.25,
    cliprange_value=0.25,
    whiten_rewards=True,
    kl_penalty="kl",
    max_grad_norm=1.0,
)
trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)
# Reused below to apply the supervised anchor gradients after each PPO step.
optimizer = trainer.optimizer


# Workaround: pin `model.device` for code paths that read it directly.
# NOTE(review): assigning over an nn.Module attribute named `device` may be a
# no-op or raise depending on the class — confirm this is still needed for
# the TRL version in use; the try/except swallows any failure.
try:
    model.device = torch.device(device)
except Exception:
    pass
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------
# Sampling settings for rollouts.
# ---------------------------------------------------------------
generation_kwargs = dict(
    max_new_tokens=64,


    # Moderately exploratory nucleus/top-k sampling.
    do_sample=True,
    temperature=0.9,
    top_p=0.95,
    top_k=100,


    # Discourage token/ngram repetition in the generated SQL.
    repetition_penalty=1.1,
    no_repeat_ngram_size=3,


    num_beams=1,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)


print("Starting RL training 🚀")


# Rollout buffers, flushed into PPOTrainer.step() once batch_size is reached.
query_buffer, response_buffer, reward_buffer, gold_buffer = [], [], [], []
query_text_buffer = []


# Best-epoch tracking for checkpointing.
best_reward = -999999
best_epoch = -1
|
|
| def _is_parsable_sql(sql: str) -> bool: |
| s = (sql or "").strip() |
| if not s: |
| return False |
| up = s.upper() |
| if "SELECT" not in up or "FROM" not in up: |
| return False |
| if sqlparse is None: |
| return True |
| try: |
| return bool(sqlparse.parse(s)) |
| except Exception: |
| return False |
|
|
|
|
def _pad_2d(seqs, pad_id: int):
    """Right-pad 1-D LongTensors to a common length.

    Returns ``(padded_ids, attention_mask)`` on the module-level ``device``;
    the mask holds 1 over real tokens and 0 over padding.
    """
    width = max(int(t.numel()) for t in seqs)
    ids = torch.full((len(seqs), width), int(pad_id), dtype=torch.long, device=device)
    mask = torch.zeros((len(seqs), width), dtype=torch.long, device=device)
    for row, t in enumerate(seqs):
        length = int(t.numel())
        ids[row, :length] = t.to(device)
        mask[row, :length] = 1
    return ids, mask
|
|
|
|
| def _shift_right(labels: torch.Tensor, start_id: int) -> torch.Tensor: |
| dec = labels.clone() |
| dec[:, 1:] = labels[:, :-1] |
| dec[:, 0] = int(start_id) |
| return dec |
|
|
|
|
def safe_get_kl(stats):
    """Best-effort extraction of the first KL-like entry from a PPO stats dict.

    Returns the value as a float, or None when stats is not a dict, no key
    contains 'kl', or the matched value cannot be converted.
    """
    if not isinstance(stats, dict):
        return None
    for key, value in stats.items():
        if "kl" not in str(key).lower():
            continue
        try:
            return float(value.item() if hasattr(value, "item") else value)
        except Exception:
            return None
    return None
|
|
def supervised_anchor_step(model, tokenizer, queries, gold_sqls, weight=0.05):
    """Accumulate a small teacher-forcing gradient on the gold SQL.

    For each (prompt, gold SQL) pair, runs one forward pass and calls
    ``backward()`` on ``weight * cross_entropy``, so gradients pile up on the
    model's trainable parameters; the caller is responsible for the following
    optimizer step/zero_grad. Returns the sum of the UNSCALED per-example
    losses.

    NOTE(review): the decoder inputs are shifted manually (labels[:-1] as
    input, labels[1:] as targets) rather than via the model's own
    shift-right with decoder_start_token_id — confirm this matches the SFT
    objective the adapters were trained with.
    """
    model.train()
    total_loss = 0.0


    for q, gold in zip(queries, gold_sqls):

        enc = tokenizer(q, return_tensors="pt", truncation=True).to(model.device)
        dec = tokenizer(text_target=gold, return_tensors="pt", truncation=True)


        labels = dec.input_ids.to(model.device)


        # Manual shift: decoder sees tokens 0..n-2 and predicts tokens 1..n-1.
        decoder_input_ids = labels[:, :-1].contiguous()
        target_ids = labels[:, 1:].contiguous()


        outputs = model(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            decoder_input_ids=decoder_input_ids,
        )


        # First element of the value-head model's output tuple is the LM logits.
        logits = outputs[0]


        vocab_size = logits.size(-1)
        loss = F.cross_entropy(
            logits.view(-1, vocab_size),
            target_ids.view(-1),
            ignore_index=tokenizer.pad_token_id,  # don't learn on padding
        )


        # Scale only the gradient, not the reported loss.
        (loss * weight).backward()
        total_loss += loss.item()


    return total_loss
|
|
|
|
@torch.no_grad()
def _estimate_policy_entropy(query_tensors, response_tensors) -> torch.Tensor:
    """
    Returns per-sample average token entropy of the policy on the sampled response tokens.
    Used as a small bonus to reduce repetition collapse.
    """
    # NOTE(review): not referenced anywhere else in this file — possibly kept
    # for experiments; confirm before deleting.
    pad_id = int(tokenizer.pad_token_id)
    enc_ids, enc_attn = _pad_2d(query_tensors, pad_id)
    dec_ids, dec_attn = _pad_2d(response_tensors, pad_id)


    # Teacher-force the already-sampled responses back through the policy.
    start_id = int(getattr(model.pretrained_model.config, "decoder_start_token_id", pad_id))
    dec_inp = _shift_right(dec_ids, start_id)


    out = model.pretrained_model(
        input_ids=enc_ids,
        attention_mask=enc_attn,
        decoder_input_ids=dec_inp,
        use_cache=False,
    )
    logp = torch.log_softmax(out.logits, dim=-1)
    p = torch.exp(logp)
    ent = -(p * logp).sum(dim=-1)  # per-position entropy over the vocabulary
    # Average over real (non-pad) response tokens only; clamp avoids /0.
    denom = dec_attn.sum(dim=-1).clamp_min(1)
    return (ent * dec_attn).sum(dim=-1) / denom
|
|
|
|
| def _repeat_penalty(response_tensor: torch.Tensor) -> float: |
| """ |
| Penalize repetition to avoid 'SELECT SELECT SELECT' collapse. |
| Simple heuristic: consecutive duplicate token ratio + low-unique-token ratio. |
| """ |
| ids = response_tensor.detach().tolist() |
| n = len(ids) |
| if n <= 1: |
| return 0.0 |
| consec_dup = 0 |
| for i in range(1, n): |
| if ids[i] == ids[i - 1]: |
| consec_dup += 1 |
| unique_ratio = len(set(ids)) / n |
| consec_ratio = consec_dup / (n - 1) |
| |
| return float(0.5 * consec_ratio + 0.5 * (1.0 - unique_ratio)) |
|
|
|
|
def _supervised_anchor_step(query_tensors, gold_sql_texts, weight: float = 0.05) -> None:
    """
    Small teacher-forcing step on gold SQL to anchor grammar during PPO.
    Runs only if PPOTrainer exposes (accelerator, optimizer).
    """
    # NOTE(review): not called anywhere in this file — the training loop uses
    # supervised_anchor_step (no leading underscore) instead; confirm which
    # variant is intended before removing either.
    if not gold_sql_texts:
        return
    accelerator = getattr(trainer, "accelerator", None)
    optimizer = getattr(trainer, "optimizer", None)
    if accelerator is None or optimizer is None:
        return


    pad_id = int(tokenizer.pad_token_id)
    enc_ids, enc_attn = _pad_2d(query_tensors, pad_id)


    # Tokenize the gold SQL (capped at 256 tokens) and append EOS.
    gold_ids = []
    for s in gold_sql_texts:
        g = (s or "").strip()
        if not g:
            g = "SELECT 1"  # harmless placeholder for an empty gold query
        ids = tokenizer.encode(g, add_special_tokens=False)[:256]
        if tokenizer.eos_token_id is not None:
            ids = ids + [int(tokenizer.eos_token_id)]
        gold_ids.append(torch.tensor(ids, dtype=torch.long))


    dec_ids, dec_attn = _pad_2d(gold_ids, pad_id)
    labels = dec_ids.clone()
    labels[dec_attn == 0] = -100  # ignore padding positions in the CE loss


    # Let the HF model build decoder inputs and compute the loss from `labels`.
    out = model.pretrained_model(
        input_ids=enc_ids,
        attention_mask=enc_attn,
        labels=labels,
        use_cache=False,
    )
    loss = out.loss * float(weight)


    # Zero, backward (through accelerate), then step: one full mini update.
    optimizer.zero_grad(set_to_none=True) if hasattr(optimizer, "zero_grad") else None
    accelerator.backward(loss)
    optimizer.step()
|
|
|
def _curriculum_allows(gold_sql: str, epoch_num: int) -> bool:
    """Curriculum filter over gold queries.

    Epoch 1: single-table queries only, no set operations.
    Epoch 2: single-table or JOIN queries, still no set operations.
    Epoch 3+: everything is allowed.
    """
    upper_sql = (gold_sql or "").upper()
    uses_join = "JOIN" in upper_sql
    uses_set_op = any(kw in upper_sql for kw in ("UNION", "INTERSECT", "EXCEPT"))
    referenced = extract_tables(gold_sql)
    is_single_table = (not uses_join) and len(referenced) <= 1

    if epoch_num == 1:
        return is_single_table and not uses_set_op
    if epoch_num == 2:
        return (is_single_table or uses_join) and not uses_set_op
    return True
|
|
|
|
# ===============================================================
# Main PPO loop: sample a Spider example, roll out one SQL query under
# vocabulary-constrained decoding, score it by execution against the
# database, and update the policy in batches of ppo_config.batch_size.
# ===============================================================
for epoch in range(1, NUM_EPOCHS + 1):

    # Schema goes into the prompt only after the warmup epochs (0 => always).
    use_schema_this_epoch = USE_SCHEMA and (epoch > SCHEMA_WARMUP_EPOCHS)

    # Per-epoch reward bookkeeping.
    epoch_reward_sum = 0
    negative_rewards = 0
    partial_rewards = 0
    correct_rewards = 0

    # Per-epoch diagnostic counters.
    total_considered = 0
    valid_sql_count = 0
    exec_correct_count = 0
    table_overlap_sum = 0.0
    column_overlap_sum = 0.0
    kl_values = []

    for step in range(1, total_steps + 1):

        # Sample one training example uniformly at random.
        example = dataset[random.randrange(len(dataset))]

        question = example["question"]
        gold_sql = example["query"]
        db_id = example["db_id"]
        db_path = get_db_path(db_id)

        # Build both the human-readable prompt (used by the anchor step) and
        # the token-budgeted encoded prompt (used for generation).
        schema = get_db_schema(db_path)
        question_text = build_prompt(question, schema, use_schema_this_epoch)
        query_tensor = encode_prompt(question, schema, use_schema_this_epoch)

        # Constrain decoding to SQL keywords + this database's identifiers.
        table_names, column_names = get_db_tables_columns(db_path)
        allowed_ids = _schema_allowed_token_ids(table_names, column_names)
        logits_processor = LogitsProcessorList([SQLVocabularyLogitsProcessor(allowed_ids)])

        response = trainer.generate([query_tensor], logits_processor=logits_processor, **generation_kwargs)[0]
        response_tensor = response.squeeze(0)[:MAX_OUTPUT_TOKENS]

        pred_sql = tokenizer.decode(response_tensor.cpu(), skip_special_tokens=True)

        total_considered += 1

        # Gate 1: reject outputs that are clearly not SQL.
        if not _is_parsable_sql(pred_sql):
            negative_rewards += 1
            continue

        # Gate 2: reject degenerate, near-empty generations.
        if int(response_tensor.numel()) < 6:
            negative_rewards += 1
            continue

        # Execution-based reward vs. the gold query.
        reward_value = execution_reward(pred_sql, db_path, gold_sql)

        # None => reward could not be computed; skip, with a periodic health log.
        if reward_value is None:
            if step % 100 == 0:
                ratio = valid_sql_count / max(total_considered, 1)
                print(f"\nLearning ratio: {valid_sql_count}/{total_considered} ({ratio:.3f})")
                if ratio < 0.15:
                    print("MODEL COLLAPSING")
            continue

        # Clamp to [-1, 1], then subtract a small repetition penalty and re-clamp.
        reward_value = float(max(-1.0, min(1.0, reward_value)))
        reward_value = float(max(-1.0, min(1.0, reward_value - 0.2 * _repeat_penalty(response_tensor))))
        reward_tensor = torch.tensor(reward_value, dtype=torch.float32)

        epoch_reward_sum += reward_value

        # Everything past the two gates counts as "valid SQL".
        valid_sql_count += 1

        # Table/column overlap diagnostics (the extract_* helpers are used
        # with `&`, so they evidently return set-like objects).
        pred_tables = extract_tables(pred_sql)
        gold_tables = extract_tables(gold_sql)
        pred_cols = extract_columns(pred_sql)
        gold_cols = extract_columns(gold_sql)

        if len(gold_tables) > 0:
            table_overlap_sum += len(pred_tables & gold_tables) / max(len(gold_tables), 1)
        if len(gold_cols) > 0:
            column_overlap_sum += len(pred_cols & gold_cols) / max(len(gold_cols), 1)

        # reward >= 1.0 is treated as execution-correct.
        if reward_value >= 1.0:
            exec_correct_count += 1

        if reward_value <= -1.0:
            negative_rewards += 1
        elif reward_value >= 1.0:
            correct_rewards += 1
        else:
            partial_rewards += 1

        # Near-zero rewards carry little learning signal; skip the PPO buffer.
        if abs(reward_value) < 0.1:
            continue

        query_buffer.append(query_tensor)
        response_buffer.append(response_tensor)
        reward_buffer.append(reward_tensor)
        gold_buffer.append(gold_sql)
        query_text_buffer.append(question_text)

        # Flush a full batch into PPO.
        if len(query_buffer) == ppo_config.batch_size:
            # PPO expects rewards on the model's device.
            reward_buffer = [r.to(device) for r in reward_buffer]

            stats = trainer.step(query_buffer, response_buffer, reward_buffer)

            kl = safe_get_kl(stats)
            if kl is not None:
                kl_values.append(kl)

            # Small supervised pull toward the gold SQL, then apply it.
            # NOTE(review): this optimizer.step() applies only the anchor
            # gradients (PPO's own update already happened inside
            # trainer.step) — confirm this interaction is intended.
            supervised_anchor_step(model, tokenizer, query_text_buffer, gold_buffer, weight=0.05)
            optimizer.step()
            optimizer.zero_grad()

            # Reset the rollout buffers for the next batch.
            query_buffer, response_buffer, reward_buffer, gold_buffer = [], [], [], []
            query_text_buffer = []

        # Periodic health check: how often generations pass the SQL gates.
        if step % 100 == 0:
            ratio = valid_sql_count / max(total_considered, 1)
            print(f"\nLearning ratio: {valid_sql_count}/{total_considered} ({ratio:.3f})")
            if ratio < 0.15:
                print("MODEL COLLAPSING")
                # Tighten the KL leash to pull the policy back to the reference.
                try:
                    if hasattr(trainer, "kl_ctl") and hasattr(trainer.kl_ctl, "value"):
                        trainer.kl_ctl.value *= 1.5
                        print(f"Increasing KL coef -> {trainer.kl_ctl.value:.4f}")
                except Exception:
                    pass

        # Periodic sample log.
        # NOTE(review): avg_reward divides by `step`, not by the number of
        # scored samples, so skipped/gated steps drag the shown average down.
        if step % LOG_EVERY == 0:
            avg_reward = epoch_reward_sum / step
            print("\n---------------------------")
            print(f"Epoch {epoch}/{NUM_EPOCHS} | Step {step}/{total_steps} | Avg Reward {avg_reward:.3f}")
            print("DB:", db_id)
            print("Q:", question)
            print("SQL:", pred_sql)
            print("Reward:", reward_value)

    # ---- end of epoch: reward breakdown ----
    print(f"\nEpoch {epoch} stats:")
    print("negative:", negative_rewards)
    print("partial:", partial_rewards)
    print("correct:", correct_rewards)

    denom = max(total_considered, 1)
    print("\nEpoch metrics:")
    print(f"execution_accuracy: {exec_correct_count/denom:.3f}")
    print(f"valid_sql_rate: {valid_sql_count/denom:.3f}")
    print(f"table_match_rate: {table_overlap_sum/denom:.3f}")
    print(f"column_match_rate: {column_overlap_sum/denom:.3f}")
    print(f"avg_reward: {epoch_reward_sum/max(denom,1):.3f}")
    if kl_values:
        avg_kl = sum(kl_values) / max(len(kl_values), 1)
        print(f"avg_kl: {avg_kl:.3f}")
        # A strongly negative average KL is treated as policy collapse.
        # Note this `break` exits the epoch loop and skips the best-model
        # checkpoint check below for this (aborted) epoch.
        if avg_kl < -8:
            print("\nKL collapse guard triggered (avg_kl < -8). Stopping early.")
            break

    # Checkpoint whenever this epoch's average reward beats the best so far.
    epoch_avg_reward = epoch_reward_sum / max(denom, 1)
    if epoch_avg_reward > best_reward:
        best_reward = epoch_avg_reward
        best_epoch = epoch

        print(f"\nNew best model at epoch {epoch} with reward {best_reward:.4f}")

        os.makedirs(output_dir, exist_ok=True)

        trainer.model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
|
|
|
|
| print(f"\nTraining finished.") |
| print(f"Best epoch: {best_epoch}") |
| print(f"Best reward: {best_reward:.4f}") |
| print(f"Best model saved at: {output_dir}") |