Spaces:
Sleeping
Sleeping
refactor: extract agent_utils.py (shared prompt/commands/server utils), simplify reward to env+format, add audit.py
51a79ee | """ | |
| Shared utilities for training and evaluation scripts. | |
| Centralises the system prompt, user-prompt builder, and server lifecycle | |
| helpers so that train_data_centric.py and eval_data_centric.py stay in sync | |
| automatically. | |
| """ | |
| import subprocess | |
| import time | |
| import requests | |
| # ββ Valid commands (used for format-reward checking) βββββββββββββββββββββββββ | |
| VALID_COMMANDS = [ | |
| "inspect_dataset", "inspect_model", "query_analyst", | |
| "query_cleaner", "query_augmenter", "query_balancer", "query_validator", | |
| "apply", "reject", "undo", "validate", "submit", | |
| ] | |
| # ββ System prompt βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| SYSTEM_PROMPT = """You are a Data-Centric AI agent. Your job is to improve a \ | |
| Machine learning dataset so a fixed classifier achieves higher accuracy. | |
| STRATEGY β use this order: | |
| 1. query_analyst to get a prioritised action plan (costs 2 budget total, worth it) | |
| 2. inspect_dataset to understand the data | |
| 3. query the recommended specialist (query_cleaner, query_augmenter, query_balancer) | |
| 4. apply the best recommendation by number (e.g. apply 1) | |
| 5. validate to check if accuracy improved | |
| 6. repeat until you hit the target or run low on budget | |
| 7. submit to finalize | |
| IMPORTANT RULES: | |
| - Start with query_analyst β it tells you the biggest problem to fix first. | |
| - Always query a specialist before applying. Never apply without querying first. | |
| - Check the Agreement signal after validate: DISAGREE means possible overfitting. | |
| - Validate after every 2-3 applies to track progress. | |
| - Do not spam validate β it costs budget after 3 uses. | |
| - query_validator costs 2 budget total β use only when suspicious of data quality. | |
| - submit when accuracy >= target or budget < 5. | |
| Reply with exactly ONE command per message. No explanation. Just the command.""" | |
| # ββ User prompt builder βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def build_user_prompt(obs: dict) -> str: | |
| """Build the per-step user prompt from an observation dict.""" | |
| improvement_needed = obs.get("target_accuracy", 0) - obs.get("current_accuracy", 0) | |
| return ( | |
| f"Current situation:\n" | |
| f"Accuracy: {obs.get('current_accuracy', 0):.1%} β " | |
| f"Target: {obs.get('target_accuracy', 0):.1%}\n" | |
| f"Still need: {max(0, improvement_needed):.1%} improvement\n" | |
| f"Quality score: {obs.get('estimated_quality', 0):.2f}/1.00\n" | |
| f"Rows preserved: {obs.get('rows_preserved_pct', 1.0):.1%}\n" | |
| f"Budget remaining: {obs.get('budget_remaining', 0)} steps\n" | |
| f"Free validates left: {obs.get('validate_calls_remaining', 0)}\n" | |
| f"Active query session: {obs.get('active_session', 'none')}\n\n" | |
| f"Last result:\n{str(obs.get('response', ''))[:400]}\n\n" | |
| f"What is your next action? (one command only)" | |
| ) | |
| # ββ Server lifecycle helpers ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def start_server(base_url: str = "http://localhost:8000") -> subprocess.Popen: | |
| """Start the environment server as a subprocess and wait until healthy.""" | |
| proc = subprocess.Popen( | |
| ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"], | |
| stdout=subprocess.DEVNULL, | |
| stderr=subprocess.DEVNULL, | |
| ) | |
| # Poll until ready (max 30 seconds) | |
| for i in range(30): | |
| try: | |
| r = requests.get(f"{base_url}/health", timeout=1) | |
| if r.status_code == 200: | |
| print(f"Server ready after {i + 1}s") | |
| return proc | |
| except Exception: | |
| pass | |
| time.sleep(1) | |
| proc.terminate() | |
| raise RuntimeError("Environment server failed to start in 30 seconds") | |
| def stop_server(proc: subprocess.Popen) -> None: | |
| """Terminate the environment server subprocess.""" | |
| proc.terminate() # cross-platform (SIGTERM on Linux, TerminateProcess on Windows) | |
| proc.wait() | |
| print("Server stopped.") | |