|
|
| """
|
| Script to calculate accuracy of a trained GRPO model on arithmetic countdown problems.
|
|
|
| This script loads a CSV file with problem data, performs inference using the trained model,
|
| and calculates the accuracy by comparing predicted answers with correct answers.
|
| """
|
# Make the project root importable so `src.*` modules resolve when this
# script is run directly (rather than as an installed package).
import sys, os

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "../.."))
sys.path.append(project_root)

import argparse
import logging
import re
import sys  # NOTE(review): duplicate of the `import sys` above — harmless but removable

from pathlib import Path

import pandas as pd
import torch
from tqdm import tqdm

# NOTE(review): this second sys.path append targets the script's grandparent
# directory; together with the project-root append above, one of the two is
# likely redundant — confirm before removing either.
sys.path.append(str(Path(__file__).parent.parent))

from src.utils.inference import GRPOModelInference
|
|
|
|
|
# Module-wide logging: timestamped INFO-level messages for progress visibility.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("calculate_accuracy")
|
|
|
|
|
def load_csv_data(csv_path: str) -> pd.DataFrame:
    """
    Load the evaluation CSV and verify it has the expected schema.

    Expected columns: id, problem_description, correct_answer, num1, num2, num3, num4

    Args:
        csv_path: Path to the CSV file

    Returns:
        DataFrame with the loaded data

    Raises:
        ValueError: If any required column is absent from the file.
    """
    frame = pd.read_csv(csv_path)

    expected = (
        "id",
        "problem_description",
        "correct_answer",
        "num1",
        "num2",
        "num3",
        "num4",
    )
    # Collect every expected column that is not present, so the error
    # message reports all of them at once.
    absent = [name for name in expected if name not in frame.columns]
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    logger.info(f"Loaded {len(frame)} problems from {csv_path}")
    return frame
|
|
|
|
|
| def safe_eval_expression(expression: str) -> tuple[float | None, bool]:
|
| """
|
| Safely evaluate an arithmetic expression.
|
|
|
| Args:
|
| expression: The arithmetic expression to evaluate
|
|
|
| Returns:
|
| Tuple of (result, is_valid)
|
| """
|
| if not expression or not expression.strip():
|
| return None, False
|
|
|
|
|
| normalized = expression.replace("x", "*").replace("X", "*")
|
|
|
|
|
| allowed_chars = set("0123456789+-*/.() ")
|
| if not all(c in allowed_chars for c in normalized):
|
| return None, False
|
|
|
| try:
|
| result = eval(normalized)
|
| return result, True
|
| except (SyntaxError, ValueError, ZeroDivisionError, NameError):
|
| return None, False
|
|
|
|
|
def check_numbers_usage(expression: str, required_numbers: list[int]) -> bool:
    """
    Check if the expression uses exactly the required numbers.

    Args:
        expression: The arithmetic expression to check
        required_numbers: List of numbers that should be used exactly once each

    Returns:
        True if expression uses all required numbers exactly once, False otherwise
    """
    if not expression or not expression.strip():
        return False

    # Pull out every standalone integer literal in the expression.
    tokens = re.findall(r"\b\d+\b", expression)

    try:
        used = sorted(int(tok) for tok in tokens)
    except ValueError:
        return False

    # Multiset comparison: each required number exactly once, nothing extra.
    return used == sorted(required_numbers)
|
|
|
|
|
def evaluate_prediction(
    predicted_answer: str, correct_answer: int, nums: list[int]
) -> dict:
    """
    Evaluate a single prediction against the correct answer.

    Args:
        predicted_answer: The model's predicted arithmetic expression
        correct_answer: The correct integer result
        nums: List of four numbers used in the problem

    Returns:
        Dictionary with evaluation results
    """
    # Score the expression first, then assemble the result record.
    predicted_result, is_valid_format = safe_eval_expression(predicted_answer)
    uses_all_numbers = check_numbers_usage(predicted_answer, nums)

    outcome = {
        "predicted_answer": predicted_answer,
        "correct_answer": correct_answer,
        "is_correct": False,
        "is_valid_format": is_valid_format,
        "uses_all_numbers": uses_all_numbers,
        "predicted_result": predicted_result,
        "correct_result": correct_answer,
    }

    logger.info(
        f"Answered: {predicted_answer} - Predicted result: {predicted_result} - Correct result: {correct_answer} - Uses all numbers: {uses_all_numbers}"
    )

    # A prediction counts as correct only when it parses, evaluates to a
    # value, uses every required number exactly once, and matches the target
    # within floating-point tolerance.
    if is_valid_format and predicted_result is not None and uses_all_numbers:
        outcome["is_correct"] = abs(predicted_result - correct_answer) < 1e-6

    return outcome
|
|
|
|
|
def calculate_accuracy(
    csv_path: str,
    sft_model_path: str | None,
    grpo_model_path: str | None,
    base_model_id: str = "Qwen/Qwen2.5-Math-1.5B",
    device: str = "auto",
    dtype: torch.dtype = torch.float16,
    max_new_tokens: int = 4096,
    temperature: float = 0.7,
    max_samples: int | None = None,
    output_path: str | None = None,
) -> dict:
    """
    Calculate accuracy of the model on the given dataset.

    Loads the test CSV, runs model inference on each problem, scores each
    prediction (correctness, format validity, number usage), and returns
    aggregate metrics. Optionally writes a per-sample results CSV.

    Args:
        csv_path: Path to the CSV file with test data
        sft_model_path: Path to the SFT model
        grpo_model_path: Path to the GRPO model
        base_model_id: Base model identifier
        device: Device to run inference on
        dtype: Data type for the model
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        max_samples: Maximum number of samples to evaluate (None for all)
        output_path: Path to save detailed results (optional)

    Returns:
        Dictionary with accuracy metrics (counts and rates for correct
        predictions, valid-format predictions, and all-numbers usage).
    """

    # Load and optionally truncate the evaluation set.
    df = load_csv_data(csv_path)

    if max_samples is not None:
        df = df.head(max_samples)
        logger.info(f"Limiting evaluation to {max_samples} samples")

    logger.info("Loading model...")
    model_inference = GRPOModelInference(
        sft_model_path=sft_model_path,
        grpo_model_path=grpo_model_path,
        base_model_id=base_model_id,
        device=device,
        dtype=dtype,
    )

    # Running tallies used for both the live progress bar and final metrics.
    results = []
    correct_predictions = 0
    valid_format_predictions = 0
    uses_all_numbers_predictions = 0

    logger.info("Starting evaluation...")
    pbar = tqdm(df.iterrows(), total=len(df), desc="Evaluating")

    # Per-sample loop: generate an answer, score it, accumulate tallies.
    # iterrows() yields (index, row); the DataFrame index is discarded and
    # idx from enumerate is used as the 0-based sample counter.
    for idx, (_, row) in enumerate(pbar):

        response, extracted_answer, _ = model_inference.solve_problem(
            problem_description=row["problem_description"],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

        nums = [row["num1"], row["num2"], row["num3"], row["num4"]]
        evaluation = evaluate_prediction(
            predicted_answer=extracted_answer,
            correct_answer=row["correct_answer"],
            nums=nums,
        )

        # Attach the problem metadata and raw model output to the record
        # so the optional results CSV is self-contained.
        evaluation.update(
            {
                "id": row["id"],
                "problem_description": row["problem_description"],
                "full_response": response,
                "nums": nums,
            }
        )

        results.append(evaluation)

        if evaluation["is_correct"]:
            correct_predictions += 1
        if evaluation["is_valid_format"]:
            valid_format_predictions += 1
        if evaluation["uses_all_numbers"]:
            uses_all_numbers_predictions += 1

        # Live metrics on the progress bar. NOTE(review): the `if (idx + 1) > 0`
        # guards are redundant — idx + 1 is always >= 1 inside the loop.
        current_accuracy = correct_predictions / (idx + 1) if (idx + 1) > 0 else 0
        current_valid_rate = (
            valid_format_predictions / (idx + 1) if (idx + 1) > 0 else 0
        )
        pbar.set_postfix(
            {
                "Acc": f"{current_accuracy:.3f}",
                "Valid": f"{current_valid_rate:.3f}",
                "Correct": f"{correct_predictions}/{idx + 1}",
            }
        )

    # Final aggregate rates; guarded against an empty dataset.
    total_samples = len(results)
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0
    valid_format_rate = (
        valid_format_predictions / total_samples if total_samples > 0 else 0
    )
    uses_all_numbers_rate = (
        uses_all_numbers_predictions / total_samples if total_samples > 0 else 0
    )

    metrics = {
        "total_samples": total_samples,
        "correct_predictions": correct_predictions,
        "valid_format_predictions": valid_format_predictions,
        "uses_all_numbers_predictions": uses_all_numbers_predictions,
        "accuracy": accuracy,
        "valid_format_rate": valid_format_rate,
        "uses_all_numbers_rate": uses_all_numbers_rate,
    }

    logger.info("Evaluation completed!")
    logger.info(f"Total samples: {total_samples}")
    logger.info(f"Correct predictions: {correct_predictions}")
    logger.info(f"Valid format predictions: {valid_format_predictions}")
    logger.info(f"Uses all numbers predictions: {uses_all_numbers_predictions}")
    logger.info(f"Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
    logger.info(
        f"Valid format rate: {valid_format_rate:.4f} ({valid_format_rate * 100:.2f}%)"
    )
    logger.info(
        f"Uses all numbers rate: {uses_all_numbers_rate:.4f} ({uses_all_numbers_rate * 100:.2f}%)"
    )

    # Optionally dump the full per-sample records for offline analysis.
    if output_path:
        results_df = pd.DataFrame(results)
        results_df.to_csv(output_path, index=False)
        logger.info(f"Detailed results saved to {output_path}")

    return metrics
|
|
|
|
|
def main():
    """Main function to run the accuracy calculation script.

    Parses command-line arguments, runs the evaluation via
    :func:`calculate_accuracy`, and prints a summary of the metrics.
    """
    parser = argparse.ArgumentParser(
        description="Calculate accuracy of GRPO model on arithmetic countdown problems"
    )

    # Bug fix: this argument previously declared both required=True and a
    # default, which made the default unreachable and contradicted the help
    # text. Dropping required=True lets the documented default take effect,
    # consistent with every other path argument below.
    parser.add_argument(
        "--csv_path",
        type=str,
        default="data/grpo/test.csv",
        help="Path to CSV file with test data",
    )
    parser.add_argument(
        "--sft_model_path",
        type=str,
        default="models/sft/",
        help="Path to SFT model directory",
    )
    parser.add_argument(
        "--grpo_model_path",
        type=str,
        default="models/grpo/",
        help="Path to GRPO model directory",
    )
    parser.add_argument(
        "--base_model_id",
        type=str,
        default="Qwen/Qwen2.5-Math-1.5B",
        help="Base model identifier",
    )
    parser.add_argument(
        "--device", type=str, default="auto", help="Device to run inference on"
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=4096, help="Maximum tokens to generate"
    )
    # NOTE(review): this CLI default (1.0) always overrides calculate_accuracy's
    # own default of 0.7 — confirm which value is intended.
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to save detailed results CSV",
    )
    parser.add_argument(
        "--no_sft",
        action="store_true",
        help="Skip loading the SFT model (use only base model)",
    )
    parser.add_argument(
        "--no_grpo",
        action="store_true",
        help="Skip loading the GRPO model (use only SFT model)",
    )

    args = parser.parse_args()

    # Inference runs in half precision regardless of device.
    dtype = torch.float16

    # --no_sft / --no_grpo translate to None paths, which tells the
    # inference wrapper to skip loading that adapter.
    metrics = calculate_accuracy(
        csv_path=args.csv_path,
        sft_model_path=args.sft_model_path if not args.no_sft else None,
        grpo_model_path=args.grpo_model_path if not args.no_grpo else None,
        base_model_id=args.base_model_id,
        device=args.device,
        dtype=dtype,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        max_samples=args.max_samples,
        output_path=args.output_path,
    )

    # Human-readable summary on stdout (metrics are also logged above).
    print("\n" + "=" * 50)
    print("FINAL RESULTS")
    print("=" * 50)
    print(f"Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy'] * 100:.2f}%)")
    print(
        f"Valid Format Rate: {metrics['valid_format_rate']:.4f} ({metrics['valid_format_rate'] * 100:.2f}%)"
    )
    print(
        f"Uses All Numbers Rate: {metrics['uses_all_numbers_rate']:.4f} ({metrics['uses_all_numbers_rate'] * 100:.2f}%)"
    )
    print(
        f"Correct Predictions: {metrics['correct_predictions']}/{metrics['total_samples']}"
    )
    print(
        f"Valid Format Predictions: {metrics['valid_format_predictions']}/{metrics['total_samples']}"
    )
    print(
        f"Uses All Numbers Predictions: {metrics['uses_all_numbers_predictions']}/{metrics['total_samples']}"
    )
    print("=" * 50)
|
|
|
|
|
# Script entry point: only run the evaluation when executed directly.
if __name__ == "__main__":
    main()
|
|
|