# data-centric-env / eval_data_centric.py
# Aswini-Kumar's picture
# Update eval_data_centric.py
# 7af30cc verified
"""
eval_data_centric.py — Evaluation script for DataCentricEnv.
Runs two agents on identical seeds for fair comparison:
- Random Agent: picks valid commands at random (baseline)
- Trained Agent: uses the fine-tuned model from ./data-centric-adapter
Produces eval_results.json for use by plot_rewards.py.
"""
import json
import os
import random
import signal
import subprocess
import sys
import time
from typing import Optional
import requests  # lightweight — always available
# WebSocket client for stateful episode rollouts
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from client import DataCentricEnv
from models import DataCentricAction
from agent_utils import (
VALID_COMMANDS, SYSTEM_PROMPT, build_user_prompt,
start_server, stop_server,
)
# ════════════════════════════════════════════════════════
# CONSTANTS
# ════════════════════════════════════════════════════════
BASE_URL = os.environ.get("ENV_URL", "http://localhost:8000")
ADAPTER_PATH = "./data-centric-adapter"
MAX_SEQ_LENGTH = 1024
EPISODES_PER_TASK = 10
TASKS = ["task_0_tutorial", "task_1_easy", "task_2_medium", "task_3_hard"]
# ════════════════════════════════════════════════════════
# MODEL LOADING
# ════════════════════════════════════════════════════════
def load_trained_model():
    """Load the fine-tuned model and tokenizer from ADAPTER_PATH.

    The adapter path is checked *before* the heavy ``torch`` / ``unsloth``
    imports, so a missing adapter is reported immediately with an actionable
    message rather than after a slow import (or as an unrelated ImportError
    on a machine where unsloth is not installed).

    Returns:
        The ``(model, tokenizer)`` pair returned by
        ``FastLanguageModel.from_pretrained``, after the model has been
        passed through ``FastLanguageModel.for_inference``.

    Raises:
        FileNotFoundError: if ADAPTER_PATH does not exist.
    """
    if not os.path.exists(ADAPTER_PATH):
        raise FileNotFoundError(
            f"Adapter not found at {ADAPTER_PATH}. "
            "Run train_data_centric.py (or train_colab.ipynb) first."
        )

    # Lazy imports: only pay the import cost once the adapter is known to exist.
    import torch  # noqa: F401
    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=ADAPTER_PATH,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=True,
        dtype=None,
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
# ════════════════════════════════════════════════════════
# EPISODE METRICS
# ════════════════════════════════════════════════════════
def episode_metrics(
    task: str,
    seed: int,
    final_obs: dict,
    actions: list,
    baseline_accuracy: float,
    max_steps: int,
) -> dict:
    """Compute per-episode metrics for a single completed episode.

    Args:
        task: Task identifier, copied into the result.
        seed: Episode seed, copied into the result.
        final_obs: Final observation as a dict. Reads ``current_accuracy``
            (default: ``baseline_accuracy``), ``budget_remaining``
            (default: 0) and ``target_accuracy`` (default: 1.0).
        actions: Action strings issued during the episode, in order.
        baseline_accuracy: Accuracy reported at reset time.
        max_steps: Step budget of the episode.

    Returns:
        A dict with the keys ``task``, ``seed``, ``final_accuracy``,
        ``baseline_accuracy``, ``target_accuracy``, ``accuracy_improvement``,
        ``target_hit``, ``budget_used``, ``budget_efficiency``,
        ``format_rate``, ``strategy_rate`` and ``n_actions``.
    """
    final_accuracy = final_obs.get("current_accuracy", baseline_accuracy)
    budget_remaining = final_obs.get("budget_remaining", 0)
    target_accuracy = final_obs.get("target_accuracy", 1.0)

    budget_used = max_steps - budget_remaining
    accuracy_improvement = final_accuracy - baseline_accuracy
    # bool() keeps the value a plain Python bool (JSON-serialisable) even if
    # the accuracies are not built-in floats.
    target_hit = bool(final_accuracy >= target_accuracy)
    # max(..., 1) avoids dividing by zero when no budget was spent.
    budget_efficiency = accuracy_improvement / max(budget_used, 1)

    # Normalise once so format rate and strategy rate judge the same text.
    cleaned = [a.strip() for a in actions]
    n_actions = len(cleaned)

    # Leading keyword of every known command, computed once per call.
    # Blank entries are skipped: "".split()[0] would raise IndexError.
    command_prefixes = tuple(
        cmd.split()[0] for cmd in VALID_COMMANDS if cmd.strip()
    )

    # Format rate: share of actions that start with a known command keyword.
    valid_count = sum(1 for a in cleaned if a.startswith(command_prefixes))
    format_rate = valid_count / max(n_actions, 1)

    # Strategy rate: share of consecutive pairs that are query -> apply.
    strategy_hits = sum(
        1
        for previous, current in zip(cleaned, cleaned[1:])
        if previous.startswith("query_") and current.startswith("apply")
    )
    strategy_rate = strategy_hits / max(n_actions - 1, 1)

    return {
        "task": task,
        "seed": seed,
        "final_accuracy": round(final_accuracy, 4),
        "baseline_accuracy": round(baseline_accuracy, 4),
        "target_accuracy": round(target_accuracy, 4),
        "accuracy_improvement": round(accuracy_improvement, 4),
        "target_hit": target_hit,
        "budget_used": budget_used,
        "budget_efficiency": round(budget_efficiency, 6),
        "format_rate": round(format_rate, 4),
        "strategy_rate": round(strategy_rate, 4),
        "n_actions": n_actions,
    }
# ════════════════════════════════════════════════════════
# RANDOM AGENT
# ════════════════════════════════════════════════════════
def run_random_episode(task: str, seed: int) -> Optional[dict]:
    """Run one episode with a random agent using the WebSocket client.

    The agent picks uniformly from VALID_COMMANDS until the environment
    reports ``done`` or a step call fails. Choices come from a dedicated RNG
    keyed on ``(task, seed)``, so the baseline is reproducible across runs
    and does not depend on the global ``random`` state.

    Args:
        task: Task identifier passed to ``env.reset``.
        seed: Seed passed to ``env.reset`` and used to key the agent's RNG.

    Returns:
        The dict produced by ``episode_metrics``, or ``None`` if the episode
        could not be run (the error is printed).
    """
    # String seeds are hashed deterministically by random.Random, so the same
    # (task, seed) pair always yields the same action sequence.
    rng = random.Random(f"{task}:{seed}")
    try:
        with DataCentricEnv(base_url=BASE_URL).sync() as env:
            obs = env.reset(task=task, seed=seed).observation
            baseline_accuracy = obs.baseline_accuracy
            max_steps = obs.max_steps
            actions = []
            while not obs.done:
                action = rng.choice(VALID_COMMANDS)
                actions.append(action)
                try:
                    obs = env.step(DataCentricAction(message=action)).observation
                except Exception as step_err:
                    # Best effort: score the episode on the last good
                    # observation, but say why it ended early.
                    print(f" [random] Step {len(actions)} failed: {step_err}")
                    break
            return episode_metrics(
                task,
                seed,
                {
                    "current_accuracy": obs.current_accuracy,
                    "budget_remaining": obs.budget_remaining,
                    "target_accuracy": obs.target_accuracy,
                    "done": obs.done,
                },
                actions,
                baseline_accuracy,
                max_steps,
            )
    except Exception as e:
        print(f" [random] Episode failed: {e}")
        return None
# ════════════════════════════════════════════════════════
# TRAINED AGENT
# ════════════════════════════════════════════════════════
def run_trained_episode(
    model, tokenizer, task: str, seed: int
) -> Optional[dict]:
    """Run one episode with the trained model using the WebSocket client.

    At each step the current observation is rendered into a chat prompt
    (SYSTEM_PROMPT plus ``build_user_prompt``), the model generates up to 50
    new tokens, and the first line of the decoded text (capped at 200
    characters) is sent to the environment as the action.

    Args:
        model: Model returned by ``load_trained_model``.
        tokenizer: Tokenizer returned by ``load_trained_model``.
        task: Task identifier passed to ``env.reset``.
        seed: Seed passed to ``env.reset``.

    Returns:
        The dict produced by ``episode_metrics``, or ``None`` if the episode
        could not be run (the error is printed).
    """
    try:
        # torch is only imported locally inside load_trained_model(), so it is
        # not a module-level name. Without this import, torch.no_grad() below
        # raises NameError and every trained episode is reported as failed.
        import torch

        with DataCentricEnv(base_url=BASE_URL).sync() as env:
            obs = env.reset(task=task, seed=seed).observation
            baseline_accuracy = obs.baseline_accuracy
            max_steps = obs.max_steps
            actions = []
            while not obs.done:
                obs_dict = {
                    "current_accuracy": obs.current_accuracy,
                    "target_accuracy": obs.target_accuracy,
                    "estimated_quality": obs.estimated_quality,
                    "rows_preserved_pct": obs.rows_preserved_pct,
                    "budget_remaining": obs.budget_remaining,
                    "validate_calls_remaining": obs.validate_calls_remaining,
                    "active_session": obs.active_session,
                    "response": obs.response,
                }
                messages = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": build_user_prompt(obs_dict)},
                ]
                # Prompt is truncated below MAX_SEQ_LENGTH; presumably the
                # 60-token margin leaves room for the 50 generated tokens.
                input_ids = tokenizer.apply_chat_template(
                    messages,
                    return_tensors="pt",
                    max_length=MAX_SEQ_LENGTH - 60,
                    truncation=True,
                    add_generation_prompt=True,
                ).to(model.device)
                with torch.no_grad():
                    output_ids = model.generate(
                        input_ids,
                        max_new_tokens=50,
                        temperature=0.1,
                        do_sample=True,
                        pad_token_id=tokenizer.eos_token_id,
                    )
                # Decode only the newly generated tokens, then keep the first
                # line, capped at 200 characters.
                completion = tokenizer.decode(
                    output_ids[0][input_ids.shape[1]:],
                    skip_special_tokens=True,
                )
                action = completion.strip().split("\n")[0].strip()[:200]
                actions.append(action)
                try:
                    obs = env.step(DataCentricAction(message=action)).observation
                except Exception as step_err:
                    # Best effort: score the episode on the last good
                    # observation, but say why it ended early.
                    print(f" [trained] Step {len(actions)} failed: {step_err}")
                    break
            return episode_metrics(
                task,
                seed,
                {
                    "current_accuracy": obs.current_accuracy,
                    "budget_remaining": obs.budget_remaining,
                    "target_accuracy": obs.target_accuracy,
                    "done": obs.done,
                },
                actions,
                baseline_accuracy,
                max_steps,
            )
    except Exception as e:
        print(f" [trained] Episode failed: {e}")
        return None
# ════════════════════════════════════════════════════════
# AGGREGATION
# ════════════════════════════════════════════════════════
def aggregate(episodes: list) -> dict:
    """Average the summary metrics over a list of episode result dicts.

    Returns an empty dict when ``episodes`` is empty. Otherwise returns one
    entry per summary metric, holding its arithmetic mean rounded to four
    decimals.
    """
    if not episodes:
        return {}

    count = len(episodes)
    averaged = {}
    for metric in (
        "accuracy_improvement",
        "target_hit",
        "budget_used",
        "budget_efficiency",
        "format_rate",
        "strategy_rate",
    ):
        total = sum(episode[metric] for episode in episodes)
        averaged[metric] = round(total / count, 4)
    return averaged
def print_comparison_table(random_agg: dict, trained_agg: dict):
    """Print a formatted random-vs-trained comparison table to stdout.

    Every cell is derived from the aggregate dicts. Missing keys are treated
    as 0, so empty aggregates still print a complete table.

    Args:
        random_agg: Aggregated metrics of the random agent.
        trained_agg: Aggregated metrics of the trained agent.
    """
    def pct_change(r, t):
        # A relative change is undefined for a zero baseline.
        if r == 0:
            return "—"
        return f"{(t - r) / abs(r) * 100:+.0f}%"

    def pp_change(r, t):
        # Difference between two rates, in percentage points.
        return f"{(t - r) * 100:+.0f}pp"

    def pair(key):
        return random_agg.get(key, 0), trained_agg.get(key, 0)

    acc_r, acc_t = pair("accuracy_improvement")
    hit_r, hit_t = pair("target_hit")
    eff_r, eff_t = pair("budget_efficiency")
    fmt_r, fmt_t = pair("format_rate")
    strat_r, strat_t = pair("strategy_rate")

    rows = [
        ("Accuracy gain", f"{acc_r:.3f}", f"{acc_t:.3f}",
         pct_change(acc_r, acc_t)),
        ("Target hit rate", f"{hit_r:.0%}", f"{hit_t:.0%}",
         pp_change(hit_r, hit_t)),
        ("Budget efficiency", f"{eff_r:.4f}", f"{eff_t:.4f}",
         pct_change(eff_r, eff_t)),
        # Measured baseline values, not hard-coded placeholders: a random
        # agent can produce query -> apply pairs by chance.
        ("Format rate", f"{fmt_r:.0%}", f"{fmt_t:.0%}",
         pp_change(fmt_r, fmt_t)),
        ("Strategy rate", f"{strat_r:.0%}", f"{strat_t:.0%}",
         pp_change(strat_r, strat_t)),
    ]

    header = f"{'Metric':<20} {'Random':>12} {'Trained':>13} {'Improvement':>14}"
    sep = "─" * len(header)
    print(f"\n{sep}")
    print(header)
    print(sep)
    for metric, rand, trained, improvement in rows:
        print(f" {metric:<18} {rand:>12} {trained:>13} {improvement:>14}")
    print(sep + "\n")
# ════════════════════════════════════════════════════════
# MAIN
# ════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Script entry point: start the environment server, evaluate both agents
    # on the same (task, seed) grid, print summaries, write eval_results.json,
    # and always stop the server on the way out.
    server_proc = start_server()
    try:
        print(f"\nLoading trained model from {ADAPTER_PATH}...")
        model, tokenizer = load_trained_model()

        # Use fixed seeds so both agents see identical tasks
        seeds = list(range(EPISODES_PER_TASK))
        results = {
            "random": {"all_episodes": [], "by_task": {}},
            "trained": {"all_episodes": [], "by_task": {}},
        }

        for task in TASKS:
            print(f"\n{'='*50}")
            print(f"Task: {task}")
            print("─" * 50)
            random_eps, trained_eps = [], []
            for seed in seeds:
                print(f" Seed {seed:2d}:", end=" ")

                # Random agent
                sys.stdout.write("[random] ")
                sys.stdout.flush()
                r_ep = run_random_episode(task, seed)
                if r_ep is not None:
                    random_eps.append(r_ep)
                    sys.stdout.write(
                        f"acc={r_ep['final_accuracy']:.3f} "
                        f"hit={'✓' if r_ep['target_hit'] else '✗'} "
                    )

                # Trained agent (same seed)
                sys.stdout.write("[trained] ")
                sys.stdout.flush()
                t_ep = run_trained_episode(model, tokenizer, task, seed)
                if t_ep is not None:
                    trained_eps.append(t_ep)
                    sys.stdout.write(
                        f"acc={t_ep['final_accuracy']:.3f} "
                        f"hit={'✓' if t_ep['target_hit'] else '✗'}"
                    )
                print()

            results["random"]["by_task"][task] = aggregate(random_eps)
            results["trained"]["by_task"][task] = aggregate(trained_eps)
            results["random"]["all_episodes"].extend(random_eps)
            results["trained"]["all_episodes"].extend(trained_eps)

        # Overall aggregates
        results["random"]["overall"] = aggregate(results["random"]["all_episodes"])
        results["trained"]["overall"] = aggregate(results["trained"]["all_episodes"])

        # Print comparison table
        print_comparison_table(
            results["random"]["overall"],
            results["trained"]["overall"],
        )

        # Print per-task breakdown
        print("Per-task summary:")
        for task in TASKS:
            r = results["random"]["by_task"].get(task, {})
            t = results["trained"]["by_task"].get(task, {})
            print(
                f" {task:<22} "
                f"random: acc+{r.get('accuracy_improvement',0):.3f} "
                f"hit={r.get('target_hit',0):.0%} | "
                f"trained: acc+{t.get('accuracy_improvement',0):.3f} "
                f"hit={t.get('target_hit',0):.0%}"
            )

        # Save full results. The context manager guarantees the file is
        # flushed and closed even if serialisation fails midway.
        with open("eval_results.json", "w", encoding="utf-8") as results_file:
            json.dump(results, results_file, indent=2)
        print("\nResults saved to eval_results.json")
        print("Run plot_rewards.py to visualise.")
    finally:
        stop_server(server_proc)