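"""Evaluate a policy against the CommitGuard environment over a JSONL test set.

Example invocation (script name and adapter path are illustrative; the URL
and test-file defaults match the argparse flags below):

    python evaluate.py --env-url http://localhost:8000 \
        --test-file data/devign_test.jsonl --adapter-path <path/to/lora>
"""
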
import json
import argparse
from typing import Any, Optional

import requests


def run_episode(env_url: str, sample_id: str, model_client: Any = None) -> float:
    """Run one episode (at most 5 steps) for a single sample and return its total reward."""
    # 1. Reset.
    # A real evaluation would need a reset-to-id endpoint (or would loop reset
    # until the returned sample matches sample_id). For now we assume reset
    # gives us a random sample and simply track whatever comes back.
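    # A hedged sketch of what such an endpoint could look like; "reset_to_id"
    # and its payload are assumptions, not part of the current API:
    #
    #   r = requests.post(f"{env_url}/reset_to_id", json={"sample_id": sample_id})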
    r = requests.post(f"{env_url}/reset")
    r.raise_for_status()
    data = r.json()
    obs = data["observation"]

    total_reward = 0.0
    done = False
    step_count = 0
    while not done and step_count < 5:
        # Prompt the model (simplified for this script).
        if model_client:
            action_str = model_client.generate(obs["diff"], obs["available_files"])
        else:
            # Mock policy: jump straight to a verdict as the evaluation baseline.
            action_str = (
                "<action><action_type>verdict</action_type>"
                "<is_vulnerable>true</is_vulnerable></action>"
            )
        r = requests.post(f"{env_url}/step", json={"action": action_str})
        r.raise_for_status()
        res = r.json()

        obs = res["observation"]
        # Accumulate per-step rewards; in CommitGuard the reward at the
        # verdict step includes the episode outcome.
        total_reward += res["reward"]
        done = res["done"]
        step_count += 1

    return total_reward
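

# A minimal sketch of the client interface run_episode expects: the class name
# and prompt format here are illustrative assumptions, not part of CommitGuard.
# Any object exposing a matching generate() method can be passed in.
class TemplateClient:
    """Builds a prompt from the observation and returns an action string."""

    def generate(self, diff: str, available_files: list) -> str:
        # A real client would send this prompt to an LLM and return its reply;
        # this sketch mirrors the mock baseline and issues an immediate verdict.
        _prompt = (
            "Review the following diff for vulnerabilities.\n"
            f"Files available: {available_files}\n\n{diff}"
        )
        return (
            "<action><action_type>verdict</action_type>"
            "<is_vulnerable>true</is_vulnerable></action>"
        )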


def evaluate(env_url: str, test_file: str, adapter_path: Optional[str] = None):
    with open(test_file, "r") as f:
        test_samples = [json.loads(line) for line in f]

    # Load the model if an adapter was provided.
    model_client = None
    if adapter_path:
        print(f"Loading LoRA adapter from {adapter_path}...")
        # (Integration with Unsloth/PEFT would go here.)
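        # A hedged sketch of what that integration could look like; the base
        # model name and client wrapper below are assumptions:
        #
        #   from peft import PeftModel
        #   from transformers import AutoModelForCausalLM, AutoTokenizer
        #
        #   base = AutoModelForCausalLM.from_pretrained("<base-model-name>")
        #   model = PeftModel.from_pretrained(base, adapter_path)
        #   tokenizer = AutoTokenizer.from_pretrained("<base-model-name>")
        #   model_client = HFClientWrapper(model, tokenizer)  # hypothetical wrapper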

    results = []
    print(f"Starting evaluation on {len(test_samples)} samples...")
    for sample in test_samples:
        reward = run_episode(env_url, sample["commit_id"], model_client)
        results.append({
            "commit_id": sample["commit_id"],
            "reward": reward,
            "cwe": sample.get("cwe_type"),
        })

    avg_reward = sum(r["reward"] for r in results) / max(len(results), 1)
    print(f"Evaluation Complete. Average Reward: {avg_reward:.4f}")

    with open("eval_results.json", "w") as f:
        json.dump(results, f, indent=2)
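
    # Optional: a per-CWE breakdown sketch, assuming cwe_type is populated in
    # the test file (samples without it fall under the None bucket).
    by_cwe: dict = {}
    for r in results:
        by_cwe.setdefault(r["cwe"], []).append(r["reward"])
    for cwe, rewards in sorted(by_cwe.items(), key=lambda kv: str(kv[0])):
        print(f"  {cwe}: mean reward {sum(rewards) / len(rewards):.4f} (n={len(rewards)})")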


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env-url", default="http://localhost:8000")
    parser.add_argument("--test-file", default="data/devign_test.jsonl")
    parser.add_argument("--adapter-path", default=None)
    args = parser.parse_args()

    evaluate(args.env_url, args.test_file, args.adapter_path)