| | import pandas as pd |
| | import argparse |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | import torch |
| | import random |
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
def convert_to_natural_phrase(val):
    """Turn an underscore-separated token into a space-separated phrase.

    Values that are not strings, and strings containing no underscore, are
    handed back untouched.
    """
    # Guard clause: numbers, None, NaN, etc. pass straight through.
    if not isinstance(val, str):
        return val
    return val.replace("_", " ") if "_" in val else val
| |
|
| |
|
def generate_answer(tokenizer, model, question, correct_value, device, mode="mcq"):
    """Rewrite a short answer as one natural-language sentence using a chat LLM.

    Args:
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``decode`` and
            ``eos_token_id``.
        model: Causal language model exposing ``generate``.
        question: Question text inserted into the user prompt.
        correct_value: Short answer to rewrite; underscores are replaced with
            spaces before it is inserted into the prompt.
        device: Device the tokenized prompt is moved to before generation.
        mode: ``"open_text"`` selects the direct rewrite prompt for provided
            question/answer pairs; any other value (default ``"mcq"``) selects
            the original MCQ-oriented prompt.

    Returns:
        The decoded model continuation (prompt tokens excluded), stripped of
        leading and trailing whitespace.
    """
    correct_value = convert_to_natural_phrase(correct_value)

    if mode == "open_text":
        system_preamble = (
            "You convert (Question, short Answer) into EXACTLY ONE natural English sentence that answers the Question.\n\n"
            "HARD RULES:\n"
            "- Output exactly ONE sentence. No newlines, no bullet points, no labels, no quotes.\n"
            "- Use ONLY the provided Answer content as the factual answer; do not add any new facts.\n"
            "- Be concise and direct.\n"
            "- Do NOT include any numbers unless the question is a COUNT question.\n"
            "- Vary phrasing strongly across items; avoid repeating the same structure.\n\n"
            "VARIABILITY REQUIREMENT (IMPORTANT):\n"
            "- For all questions, you MUST vary sentence structure.\n"
            "- Randomly choose ONE of these patterns each time:\n"
            " (A) Start with the sound name (Answer) -> then the relation.\n"
            " (B) Start with the relation -> then the sound name (Answer).\n"
            " (C) Use an 'it`s...' style clause after the Answer.\n"
            " (D) Use a short, natural rephrase with different verbs (e.g., lasts, continues, stands out, comes through).\n"
            "- Do not always use 'The sound with the ... is ...' — that pattern should be rare.\n\n"
            "TASK HANDLING (infer from the Question):\n"
            "- COUNT questions (how many / count / number):\n"
            " * If Answer is numeric, write it EITHER as digits (e.g., 10) OR as a word (e.g., ten). Do NOT include both.\n"
            "- DURATION questions (longest/shortest):\n"
            " * Clearly state longest vs shortest, and use the Answer as the sound name. Do not include any numbers.\n"
            "- VOLUME questions (minimum/maximum loudness, quietest/loudest):\n"
            " * Match minimum vs maximum loudness and use the Answer as the sound name. No dB values.\n"
            "- ORDER questions (first/second/before/after/second-to-last):\n"
            " * Match the requested relation and use the Answer as the sound name.\n\n"
            "Return only the sentence."
        )
        user_prompt = (
            f"Question: {question}\n"
            f"Answer: {correct_value}\n"
            "Rewrite the answer as a single, natural sentence that directly answers the question."
        )
    else:
        system_preamble = (
            "You are a helpful assistant that converts multiple-choice QA pairs into natural language answers.\n"
            "CRITICAL RULES:\n"
            "1. Write as a human would naturally speak - vary sentence structure and avoid repetitive patterns\n"
            "2. Keep responses concise but natural and affirmative avoiding words like 'might/may' or 'could' - one clear sentence\n"
            "3. Do not mention 'among the options/among the following' even if the question mentions it. This natural language statement is supposed to be a direct answer.\n"
            "4. Do NOT invent sounds.\n"
            "5. Do not reason to answer the question, you're just supposed to provide the correct mcq answer as a natural language answer in a single sentence.\n"
            "Return only the natural language answer, nothing else."
        )
        user_prompt = (
            f"Now, given the question: '{question}' and the correct answer: '{correct_value}', "
            f"write one natural-language answer as you would expect from a human."
        )

    messages = [
        {"role": "system", "content": system_preamble},
        {"role": "user", "content": user_prompt},
    ]

    # Ask explicitly for a dict-style encoding so that both input_ids and
    # attention_mask are available. Because pad_token_id is set to the EOS id
    # below, generate() cannot infer the mask on its own and would warn.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)

    # Number of prompt tokens; used to slice the prompt off the output.
    prompt_length = encoded["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **encoded,
            max_new_tokens=64,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            repetition_penalty=1.05,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens of the single sequence in the batch.
    generated_ids = output[0, prompt_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    print(f"Model response: {response}")
    return response
| |
|
| |
|
def detect_csv_format(df):
    """Identify which CSV layout ``df`` follows and describe its columns.

    Layouts are checked in this order; the first match wins:
      * ``original``  - has ``correct``, ``id`` and ``audio_path``
      * ``perturbed`` - has ``answer``, ``idx`` and ``new_audio_path``
      * ``open_text`` - has ``answer`` and ``question``

    Returns a dict with keys ``id_col``, ``audio_path_col``, ``answer_col``,
    ``question_col`` and ``format_type``. For ``open_text`` the id and audio
    path entries are ``None`` when those columns are absent.

    Raises ``ValueError`` when no layout matches.
    """
    columns = df.columns.tolist()
    available = set(columns)

    def has_all(*names):
        return available.issuperset(names)

    if has_all("correct", "id", "audio_path"):
        layout = ("id", "audio_path", "correct", "original")
    elif has_all("answer", "idx", "new_audio_path"):
        layout = ("idx", "new_audio_path", "answer", "perturbed")
    elif has_all("answer", "question"):
        layout = (
            "id" if "id" in available else None,
            "audio_path" if "audio_path" in available else None,
            "answer",
            "open_text",
        )
    else:
        raise ValueError(f"Unknown CSV format. Columns found: {columns}")

    id_col, audio_path_col, answer_col, format_type = layout
    return {
        "id_col": id_col,
        "audio_path_col": audio_path_col,
        "answer_col": answer_col,
        "question_col": "question",
        "format_type": format_type,
    }
| |
|
| |
|
def main():
    """CLI entry point: add an ``open_text_answer`` column to a QA CSV.

    Reads the CSV given by ``--input``, detects its layout, and for every row
    produces a natural-language answer. A fraction of rows
    (``--one_word_ratio``, sampled with ``--seed``) keeps the normalized short
    answer verbatim; the rest are rewritten by ``generate_answer``. The
    augmented table is written to ``--output``, or back to ``--input`` when no
    output path is given.

    Raises:
        ValueError: if the CSV layout is unknown, if ``--mode`` does not match
            the detected layout, or if an MCQ row's correct option is not one
            of A/B/C/D.
    """
    parser = argparse.ArgumentParser(
        description="Convert CSV to NL answers (MCQ or open-text) using meta-llama/Llama-3.1-8B-Instruct"
    )
    parser.add_argument("--input", required=True, help="Input CSV file")
    parser.add_argument("--output", required=False, help="Output CSV file (defaults to input for in-place append)")
    parser.add_argument(
        "--mode",
        required=True,
        choices=["mcq", "open_text"],
        help="Conversion mode: mcq -> convert MCQ correct option to natural answer; open_text -> rewrite provided short answer to a natural sentence",
    )
    parser.add_argument(
        "--task",
        required=True,
        choices=["count", "duration", "order", "volume"],
        help="Task type this CSV belongs to (used for bookkeeping/logging)",
    )
    parser.add_argument(
        "--one_word_ratio",
        type=float,
        default=0.2,
        help="Fraction of samples to output as just the normalized one-word/phrase answer (no LLM forward pass). Default 0.2",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducible one_word sampling.",
    )

    args = parser.parse_args()
    random.seed(args.seed)

    # Read and validate the CSV *before* loading the model, so that a wrong
    # file or a mode/format mismatch fails immediately instead of after the
    # multi-gigabyte model load.
    df = pd.read_csv(args.input)
    format_info = detect_csv_format(df)
    print(f"Detected CSV format: {format_info['format_type']}")
    print(f"Task: {args.task}")

    is_open_text = format_info["format_type"] == "open_text"
    if args.mode == "mcq" and is_open_text:
        raise ValueError(
            "Requested mode=mcq but input appears to be open_text format. Use --mode open_text or supply an MCQ CSV."
        )
    if args.mode == "open_text" and not is_open_text:
        raise ValueError(
            "Requested mode=open_text but input does not appear to be open_text format. Use --mode mcq or supply an open_text CSV."
        )

    model_name = "meta-llama/Llama-3.1-8B-Instruct"
    print(f"Loading {model_name} tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
    )
    model.eval()

    output_path = args.output if args.output else args.input
    device = model.device

    question_col = format_info["question_col"]
    answer_col = format_info["answer_col"]
    prompt_mode = "open_text" if is_open_text else "mcq"
    option_map = {"A": "optionA", "B": "optionB", "C": "optionC", "D": "optionD"}
    total = len(df)

    answers = []
    # enumerate() gives a true 1-based progress counter even when the frame's
    # index labels are not 0..n-1.
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        question = row[question_col]

        if is_open_text:
            correct_value = row[answer_col]
        else:
            # Tolerate lowercase letters and stray whitespace in the answer
            # key; anything else is reported with its row position.
            letter = str(row[answer_col]).strip().upper()
            if letter not in option_map:
                raise ValueError(
                    f"Row {position}: unsupported correct option {row[answer_col]!r}; expected one of A, B, C, D."
                )
            correct_value = row[option_map[letter]]

        correct_value = convert_to_natural_phrase(correct_value)
        print(f"[{position}/{total}] Q: {question} | Ans: {correct_value}")

        # One random draw per row keeps the one-word sampling reproducible
        # for a given seed.
        if random.random() < args.one_word_ratio:
            nl_answer = correct_value
            print(f"Skipped LLM (one_word_ratio). Output: {nl_answer}")
        else:
            nl_answer = generate_answer(
                tokenizer,
                model,
                question,
                correct_value,
                device,
                mode=prompt_mode,
            )

        answers.append(nl_answer)

    # Assign positionally from a plain list: this works for an empty CSV and
    # does not depend on index alignment between two DataFrames.
    df["open_text_answer"] = answers
    df.to_csv(output_path, index=False)
    print(f"Appended natural language answers to {output_path}")
| |
|
| |
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
| |
|