Spaces:
Sleeping
Sleeping
Redesign reward for discrimination: efficiency multiplier, strict penalties, stretch bonus, start at level 1
46f0850 | """Main DataCentric RL Environment.""" | |
| import logging | |
| import time | |
| from copy import deepcopy | |
| from typing import Any, Dict, List, Optional | |
| from uuid import uuid4 | |
| import pandas as pd | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..models import DataCentricAction, DataCentricObservation | |
| except ImportError: | |
| try: | |
| from models import DataCentricAction, DataCentricObservation | |
| except ImportError: | |
| import sys as _sys, os as _os | |
| _sys.path.insert(0, _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__)))) | |
| from models import DataCentricAction, DataCentricObservation | |
| try: | |
| from .anti_exploit import ( | |
| AntiExploitState, assert_ground_truth_intact, | |
| check_and_truncate_input, check_apply_allowed, | |
| check_catastrophic_data_loss, check_episode_timeout, | |
| check_validate_cooldown, get_validate_reward, record_apply, | |
| record_non_validate_step, record_validate, reset_session_apply_state, | |
| validate_calls_remaining, | |
| ) | |
| from .dataset_generator import TASK_CONFIGS, generate_dataset | |
| from .grader import ( | |
| compute_accuracy_reward, compute_efficiency_reward, | |
| compute_lightweight_score, compute_preservation_reward, | |
| compute_process_reward, compute_step_reward, compute_total_reward, | |
| ) | |
| from .model_evaluator import ModelEvaluator | |
| from .specialist_agents import ( | |
| AugmenterAgent, AnalystAgent, BalancerAgent, CleanerAgent, | |
| SessionRegistry, ValidatorAgent, compute_drift, format_drift_summary, | |
| ) | |
| except ImportError: | |
| from server.anti_exploit import ( | |
| AntiExploitState, assert_ground_truth_intact, | |
| check_and_truncate_input, check_apply_allowed, | |
| check_catastrophic_data_loss, check_episode_timeout, | |
| check_validate_cooldown, get_validate_reward, record_apply, | |
| record_non_validate_step, record_validate, reset_session_apply_state, | |
| validate_calls_remaining, | |
| ) | |
| from server.dataset_generator import TASK_CONFIGS, generate_dataset | |
| from server.grader import ( | |
| compute_accuracy_reward, compute_efficiency_reward, | |
| compute_lightweight_score, compute_preservation_reward, | |
| compute_process_reward, compute_step_reward, compute_total_reward, | |
| ) | |
| from server.model_evaluator import ModelEvaluator | |
| from server.specialist_agents import ( | |
| AugmenterAgent, AnalystAgent, BalancerAgent, CleanerAgent, | |
| SessionRegistry, ValidatorAgent, compute_drift, format_drift_summary, | |
| ) | |
logger = logging.getLogger(__name__)

# Help text shown in the reset observation and in replies to unknown commands.
# The budget notes must agree with the handlers: query_analyst and
# query_validator each charge one extra step on top of the command step.
AVAILABLE_COMMANDS = """Available commands:
inspect_dataset — shape, dtypes, missing, class distribution
inspect_model — accuracy (RF + LR), F1, feature importance
query_analyst — holistic diagnosis + prioritised action plan (costs 2 budget total)
query_cleaner — get cleaning recommendations
query_augmenter [class] — get augmentation suggestions
query_balancer — get resampling recommendations
query_validator — check rule violations (costs 2 budget total)
apply [id] — apply recommendation by ID
reject [id] — reject a recommendation
undo — revert last apply (max 3 levels)
validate — retrain and score (cooldown applies)
submit — finalize episode"""
class DataCentricEnvironment(Environment):
    """Data-Centric AI RL Environment.

    The agent improves a generated tabular dataset (``self._working_copy``)
    by issuing text commands (see ``AVAILABLE_COMMANDS``). Specialist agents
    produce recommendations, ``apply`` executes them on the working copy, and
    the grader helpers turn each step into a reward. The ground-truth frame is
    never written to; its row count is re-checked after every apply.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._ground_truth: Optional[pd.DataFrame] = None
        self._working_copy: Optional[pd.DataFrame] = None
        self._metadata: Dict[str, Any] = {}
        self._action_history: List[str] = []
        self._exploit: Optional[AntiExploitState] = None
        # fast_mode=True: uses n_estimators=20 for training rollouts (~4x faster)
        self._evaluator = ModelEvaluator(fast_mode=True)
        self._session_registry = SessionRegistry()
        self._cleaner = CleanerAgent()
        self._augmenter = AugmenterAgent()
        self._balancer = BalancerAgent()
        self._validator = ValidatorAgent()
        self._analyst = AnalystAgent()
        self._current_accuracy: float = 0.0
        self._previous_accuracy: float = 0.0
        self._active_session: str = "none"
        self._task: str = "task_0_tutorial"
        # Snapshot stack for the undo command (at most _max_history entries).
        self._dataset_history: List[pd.DataFrame] = []
        self._max_history: int = 3
        # Set by _do_submit; further steps are refused until reset().
        self._done: bool = False

    # ── reset ─────────────────────────────────────────────────────────────────
    def reset(self, task: str = "task_0_tutorial", seed: int = 42) -> DataCentricObservation:
        """Start a new episode.

        Args:
            task: Key into ``TASK_CONFIGS``; unknown keys fall back to
                ``"task_0_tutorial"``.
            seed: Forwarded to ``generate_dataset``.

        Returns:
            The initial observation (reward 0, full budget, accuracy at baseline).
        """
        self._task = task if task in TASK_CONFIGS else "task_0_tutorial"
        cfg = TASK_CONFIGS[self._task]
        self._ground_truth, self._working_copy, self._metadata = generate_dataset(
            self._task, seed=seed
        )
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._action_history = []
        self._done = False
        self._exploit = AntiExploitState(
            episode_start_time=time.time(),
            ground_truth_row_count=len(self._ground_truth),
        )
        self._evaluator.invalidate_cache()
        self._session_registry = SessionRegistry()
        self._active_session = "none"
        self._dataset_history = []  # undo snapshots never cross episodes
        reset_session_apply_state(self._exploit)

        # Episode-start missing count: the denominator for the quality score.
        self._metadata["initial_missing"] = int(self._working_copy.isnull().sum().sum())
        self._metadata["baseline_accuracy"] = cfg["baseline_accuracy"]
        baseline = cfg["baseline_accuracy"]
        self._current_accuracy = baseline
        self._previous_accuracy = baseline

        quality = compute_lightweight_score(
            self._working_copy, self._ground_truth,
            self._metadata["original_length"], self._metadata["col_meta"],
            initial_missing=self._metadata["initial_missing"],
        )
        wc = self._working_copy
        return DataCentricObservation(
            response=(
                f"Episode started: {self._task}\n"
                f"Baseline accuracy: {baseline:.4f} | Target: {cfg['target_accuracy']:.4f}\n"
                f"Dataset: {len(wc)} rows × {len(wc.columns)-1} features\n"
                f"Budget: {cfg['budget']} steps\n\n{AVAILABLE_COMMANDS}"
            ),
            current_accuracy=baseline,
            baseline_accuracy=baseline,
            target_accuracy=cfg["target_accuracy"],
            estimated_quality=quality,
            dataset_shape=f"{len(wc)} rows × {len(wc.columns)-1} columns",
            rows_preserved_pct=1.0,
            budget_remaining=cfg["budget"],
            step_number=0,
            max_steps=cfg["budget"],
            active_session="none",
            validate_calls_remaining=validate_calls_remaining(self._exploit),
            done=False,
            reward=0.0,
        )

    # ── step ──────────────────────────────────────────────────────────────────
    def step(self, action: DataCentricAction) -> DataCentricObservation:
        """Execute one agent command and return the resulting observation.

        Every call consumes one budget step; ``query_analyst`` and
        ``query_validator`` charge one more internally. When the budget is
        exceeded or the episode times out, the episode is force-submitted.
        After any submit, further calls return a terminal zero-reward
        observation until ``reset()`` is called.
        """
        if self._working_copy is None:
            return self._error_obs("Call reset() first.")

        cfg = TASK_CONFIGS[self._task]

        # A finished episode must not keep paying out (e.g. repeated submits).
        if self._done:
            return self._make_obs(
                "Episode is done. Call reset() to start a new episode.",
                self._state.step_count,
                cfg["budget"] - self._state.step_count,
                0.0,
                done=True,
            )

        # Episode timeout
        timeout, tmsg = check_episode_timeout(self._exploit)
        if timeout:
            return self._do_submit(penalty=-0.10, extra_msg=tmsg)

        # Over-long input is truncated; the returned penalty is added to the reward.
        msg, trunc_penalty, was_truncated = check_and_truncate_input(action.message)
        if was_truncated:
            logger.warning("Input truncated.")

        self._state.step_count += 1
        step_num = self._state.step_count
        budget_remaining = cfg["budget"] - step_num
        cmd_parts = msg.strip().split()
        cmd = cmd_parts[0].lower() if cmd_parts else ""

        # Out of budget — force submit
        if budget_remaining < 0:
            return self._do_submit(penalty=0.0, extra_msg="Budget exhausted.")

        self._action_history.append(msg)
        # Process reward is judged against the actions that preceded this one.
        r_process = compute_process_reward(self._action_history[:-1], msg)

        # Commands whose handler needs no argument beyond the common four.
        simple_handlers = {
            "inspect_dataset": self._cmd_inspect_dataset,
            "inspect_model": self._cmd_inspect_model,
            "query_cleaner": self._cmd_query_cleaner,
            "query_balancer": self._cmd_query_balancer,
            "query_analyst": self._cmd_query_analyst,
            "query_validator": self._cmd_query_validator,
            "validate": self._cmd_validate,
            "undo": self._cmd_undo,
        }
        if cmd in simple_handlers:
            obs = simple_handlers[cmd](step_num, budget_remaining, r_process, trunc_penalty)
        elif cmd == "query_augmenter":
            cls = cmd_parts[1] if len(cmd_parts) > 1 else None
            obs = self._cmd_query_augmenter(cls, step_num, budget_remaining, r_process, trunc_penalty)
        elif cmd == "apply":
            obs = self._cmd_apply(self._parse_rec_id(cmd_parts), step_num,
                                  budget_remaining, r_process, trunc_penalty)
        elif cmd == "reject":
            obs = self._cmd_reject(self._parse_rec_id(cmd_parts), step_num,
                                   budget_remaining, r_process, trunc_penalty)
        elif cmd == "submit":
            obs = self._do_submit()
        else:
            obs = self._unknown_cmd_obs(msg, step_num, budget_remaining, r_process + trunc_penalty)

        if cmd != "validate":
            record_non_validate_step(self._exploit)
        return obs

    @staticmethod
    def _parse_rec_id(cmd_parts: List[str]) -> int:
        """Return the integer ID in ``cmd_parts[1]``, or -1 if missing or non-numeric."""
        try:
            return int(cmd_parts[1]) if len(cmd_parts) > 1 else -1
        except ValueError:
            return -1

    # ── command handlers ──────────────────────────────────────────────────────
    def _cmd_inspect_dataset(self, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Report shape, duplicates, missing values, class balance and dtypes."""
        wc = self._working_copy
        orig_len = self._metadata["original_length"]
        missing = wc.isnull().sum()
        missing_str = "\n".join(f" {c}: {v}" for c, v in missing.items() if v > 0) or " None"
        vc = wc["target"].value_counts().sort_index()
        class_str = ", ".join(f"class {k}: {v}" for k, v in vc.items())
        rows_pct = len(wc) / orig_len
        response = (
            f"=== Dataset Inspection ===\n"
            f"Shape: {len(wc)} rows × {len(wc.columns)-1} features\n"
            f"Original rows: {orig_len} | Preserved: {rows_pct*100:.1f}%\n"
            f"Duplicates: {wc.duplicated().sum()}\n"
            f"Missing values:\n{missing_str}\n"
            f"Class distribution: {class_str}\n"
            f"Dtypes: {dict(wc.dtypes.astype(str))}"
        )
        reward = compute_total_reward(0.0, r_process, 0.0) + trunc_pen
        return self._make_obs(response, step, budget, reward)

    def _cmd_inspect_model(self, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Report RF/LR accuracy, per-class metrics and feature importance."""
        acc, per_class, from_cache, lr_acc = self._evaluator.evaluate(
            self._working_copy, self._ground_truth
        )
        cache_label = " (cached)" if from_cache else ""
        lines = [f"=== Model Inspection{cache_label} ===",
                 f"RF Accuracy: {acc:.4f}",
                 f"LR Accuracy: {lr_acc:.4f} (secondary — diagnostic only)"]
        for cls, metrics in per_class.items():
            # Only per-class entries are dicts; skip scalar summary entries.
            if isinstance(metrics, dict):
                lines.append(
                    f" Class {cls}: precision={metrics.get('precision',0):.3f} "
                    f"recall={metrics.get('recall',0):.3f} "
                    f"f1={metrics.get('f1-score',0):.3f}"
                )
        feat_text = self._evaluator.feature_importance_text()
        if feat_text:
            lines.append(feat_text)
        response = "\n".join(lines)
        reward = compute_total_reward(0.0, r_process, 0.0) + trunc_pen
        return self._make_obs(response, step, budget, reward)

    @staticmethod
    def _format_recommendations(title: str, recs, empty_msg: str) -> str:
        """Render recommendations as ``[id] description`` entries, or ``empty_msg`` if none."""
        if not recs:
            return empty_msg
        lines = [title]
        for r in recs:
            lines.append(
                f"[{r.id}] {r.description}\n"
                f" type={r.action_type} impact={r.estimated_impact:+.3f} "
                f"confidence={r.confidence:.2f}"
            )
        return "\n".join(lines)

    def _cmd_query_cleaner(self, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Query the CleanerAgent; resets per-session apply state and labels the session."""
        reset_session_apply_state(self._exploit)
        recs = self._cleaner.query(
            self._working_copy, self._session_registry, self._metadata["col_meta"]
        )
        self._active_session = f"cleaner:{self._session_registry.current_session_id[:8]}"
        response = self._format_recommendations(
            "=== Cleaner Recommendations ===", recs, "No cleaning issues detected."
        )
        reward = compute_total_reward(0.0, r_process, 0.0) + trunc_pen
        return self._make_obs(response, step, budget, reward)

    def _cmd_query_augmenter(self, cls, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Query the AugmenterAgent (optionally for one class label given as text)."""
        reset_session_apply_state(self._exploit)
        recs = self._augmenter.query(self._working_copy, self._session_registry, cls)
        self._active_session = f"augmenter:{self._session_registry.current_session_id[:8]}"
        response = self._format_recommendations(
            "=== Augmenter Recommendations ===", recs, "No augmentation needed."
        )
        reward = compute_total_reward(0.0, r_process, 0.0) + trunc_pen
        return self._make_obs(response, step, budget, reward)

    def _cmd_query_balancer(self, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Query the BalancerAgent for resampling recommendations."""
        reset_session_apply_state(self._exploit)
        recs = self._balancer.query(self._working_copy, self._session_registry)
        self._active_session = f"balancer:{self._session_registry.current_session_id[:8]}"
        response = self._format_recommendations(
            "=== Balancer Recommendations ===", recs, "Dataset is already balanced."
        )
        reward = compute_total_reward(0.0, r_process, 0.0) + trunc_pen
        return self._make_obs(response, step, budget, reward)

    def _cmd_query_analyst(self, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Holistic diagnosis + prioritised action plan.

        Costs 2 budget total: the command step already counted by ``step()``
        plus one extra step charged here. ``step`` is superseded by the
        updated step count so the observation reflects the extra charge.
        """
        self._state.step_count += 1  # the extra budget step
        plan = self._analyst.query(
            self._working_copy,
            self._metadata["col_meta"],
            self._current_accuracy,
            TASK_CONFIGS[self._task]["target_accuracy"],
            budget - 1,
        )
        response = f"=== Analyst Report (costs 2 budget) ===\n{plan}"
        reward = compute_total_reward(0.0, r_process + 0.02, 0.0) + trunc_pen  # small bonus for planning
        step_count = self._state.step_count
        budget_remaining = TASK_CONFIGS[self._task]["budget"] - step_count
        return self._make_obs(response, step_count, budget_remaining, reward)

    def _cmd_query_validator(self, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """List rule violations found by the ValidatorAgent.

        Costs 2 budget total (one extra step charged here); the observation
        reports the updated step count rather than the incoming ``step``.
        """
        self._state.step_count += 1  # the extra budget step
        violations = self._validator.query(self._working_copy, self._metadata["col_meta"])
        lines = ["=== Validator Report (costs 2 budget) ==="]
        if violations:
            for v in violations:
                lines.append(
                    f" [{v.severity}] [{v.column}] rule={v.rule} count={v.count}\n {v.description}"
                )
        else:
            lines.append(" No rule violations found.")
        response = "\n".join(lines)
        reward = compute_total_reward(0.0, r_process, 0.0) + trunc_pen
        step_count = self._state.step_count
        budget_remaining = TASK_CONFIGS[self._task]["budget"] - step_count
        return self._make_obs(response, step_count, budget_remaining, reward)

    def _cmd_apply(self, rec_id, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Execute recommendation ``rec_id`` from the current session.

        Rejections (invalid ID, duplicate/session limit, stale ID, execution
        error) return a 0-reward observation and leave the dataset unchanged.
        A successful apply records an undo snapshot and may force an early
        submit (ground-truth tampering: -1.0; catastrophic data loss: -0.40).
        Otherwise it returns a health report with process + preservation +
        step reward.
        """
        if rec_id < 1:
            # Error: return 0 reward (no penalty, no bonus)
            return self._make_obs("Error: invalid recommendation ID.", step, budget, 0.0)

        # Duplicate / session limit — 0 reward on error
        allowed, err = check_apply_allowed(rec_id, self._exploit)
        if not allowed:
            return self._make_obs(f"Error: {err}", step, budget, 0.0)

        # Staleness check — 0 reward, no penalty
        rec = self._session_registry.get(rec_id, self._session_registry.current_session_id)
        if rec is None:
            return self._make_obs(
                f"Error: stale recommendation ID {rec_id}. Please re-query for fresh recommendations.",
                step, budget, 0.0
            )

        orig_len = self._metadata["original_length"]
        initial_missing = self._metadata.get("initial_missing")
        quality_before = compute_lightweight_score(
            self._working_copy, self._ground_truth, orig_len, self._metadata["col_meta"],
            initial_missing=initial_missing,
        )

        payload = rec._payload
        action_type = payload.get("action", "")
        pre_missing = int(self._working_copy.isnull().sum().sum())
        pre_dups = int(self._working_copy.duplicated().sum())

        snapshot = self._working_copy.copy()
        try:
            self._working_copy = self._execute_payload(action_type, payload, self._working_copy)
        except Exception as exc:
            logger.exception("Error executing apply: %s", exc)
            # Some actions assign into the frame in place; restore the snapshot
            # so a failed apply leaves the dataset (and undo stack) untouched.
            self._working_copy = snapshot
            return self._make_obs(f"Error executing recommendation: {exc}", step, budget, 0.0)

        # Success: keep the pre-apply state for undo (bounded stack).
        self._dataset_history.append(snapshot)
        if len(self._dataset_history) > self._max_history:
            self._dataset_history.pop(0)
        # The data changed: drop cached model results before anything below
        # (including a forced _do_submit) evaluates the model.
        self._evaluator.invalidate_cache()
        record_apply(rec_id, self._exploit)

        # Ground truth immutability assertion — must never change
        gt_ok, gt_msg = assert_ground_truth_intact(
            len(self._ground_truth), self._exploit.ground_truth_row_count
        )
        if not gt_ok:
            logger.critical(gt_msg)
            return self._do_submit(penalty=-1.0, extra_msg=gt_msg)

        wc_new = self._working_copy
        post_rows = len(wc_new)
        post_missing = int(wc_new.isnull().sum().sum())
        post_dups = int(wc_new.duplicated().sum())
        rows_pct = post_rows / orig_len

        catastro, cmsg = check_catastrophic_data_loss(post_rows, orig_len)
        if catastro:
            return self._do_submit(penalty=-0.40, extra_msg=cmsg)

        r_preservation = compute_preservation_reward(post_rows, orig_len)
        quality = compute_lightweight_score(
            wc_new, self._ground_truth, orig_len, self._metadata["col_meta"],
            initial_missing=initial_missing,
        )

        missing_status = "OK" if post_missing == 0 else f"{post_missing} remaining"
        dup_status = "OK" if post_dups == 0 else f"{post_dups} remaining"
        drift_summary = format_drift_summary(compute_drift(wc_new, self._ground_truth))
        response = (
            f"Applied: {action_type} [{rec.description[:80]}]\n\n"
            f"Dataset health check:\n"
            f" Missing values: {missing_status} (was {pre_missing})\n"
            f" Duplicates: {dup_status} (was {pre_dups})\n"
            f" Row count: {post_rows}/{orig_len} ({rows_pct*100:.1f}% preserved)\n"
            f" {drift_summary}\n\n"
            f"Estimated quality score: {quality:.4f}\n"
            f"Budget remaining: {budget}"
        )
        reward = compute_total_reward(
            0.0, r_process, r_preservation,
            reward_step=compute_step_reward(
                f"apply {rec_id}", quality_before, quality, rows_pct
            ),
        ) + trunc_pen
        return self._make_obs(response, step, budget, reward, quality=quality,
                              rows_pct=rows_pct)

    @staticmethod
    def _execute_payload(action_type: str, payload: Dict[str, Any],
                         wc: pd.DataFrame) -> pd.DataFrame:
        """Apply one recommendation payload and return the resulting frame.

        ``fill_missing`` and ``fix_type_errors`` assign into ``wc`` in place;
        the other actions build a new frame. An unknown ``action_type``
        returns ``wc`` unchanged. Errors (e.g. a missing payload key) propagate
        to the caller.
        """
        if action_type == "fill_missing":
            col = payload["column"]
            strategy = payload.get("strategy", "mean")  # honour the CleanerAgent's choice
            numeric = pd.to_numeric(wc[col], errors="coerce")
            fill_val = float(numeric.median() if strategy == "median" else numeric.mean())
            wc[col] = numeric.fillna(fill_val)
            return wc

        if action_type == "remove_duplicates":
            return wc.drop_duplicates().reset_index(drop=True)

        if action_type == "fix_type_errors":
            col = payload["column"]
            numeric = pd.to_numeric(wc[col], errors="coerce")
            wc[col] = numeric.fillna(float(numeric.mean()))
            return wc

        if action_type == "augment_class":
            import numpy as np

            cls_rows = wc[wc["target"] == payload["class"]]
            if len(cls_rows) == 0:
                return wc
            synth = cls_rows.sample(n=payload["n_synth"], replace=True, random_state=42)
            # Zero-mean Gaussian jitter (10% of the class std) so synthetic rows
            # spread around their source rows instead of all shifting by one
            # constant. Non-numeric/bool columns are copied as-is rather than
            # coerced to NaN; a NaN or zero std means no jitter.
            rng = np.random.default_rng(42)
            for c in synth.columns:
                if c == "target":
                    continue
                col_vals = cls_rows[c]
                if not pd.api.types.is_numeric_dtype(col_vals) or pd.api.types.is_bool_dtype(col_vals):
                    continue
                std = float(col_vals.std())
                if std > 0:  # also False for NaN
                    synth[c] = synth[c] + rng.normal(0.0, 0.1 * std, size=len(synth))
            return pd.concat([wc, synth], ignore_index=True)

        if action_type == "oversample":
            cls_rows = wc[wc["target"] == payload["class"]]
            n_needed = max(0, payload["target_count"] - len(cls_rows))
            if n_needed == 0:
                return wc
            extra = cls_rows.sample(n=n_needed, replace=True, random_state=42)
            return pd.concat([wc, extra], ignore_index=True)

        if action_type == "undersample":
            cls_int = payload["class"]
            target_count = payload["target_count"]
            cls_rows = wc[wc["target"] == cls_int]
            if len(cls_rows) <= target_count:
                return wc
            keep = cls_rows.sample(n=target_count, random_state=42)
            return pd.concat([keep, wc[wc["target"] != cls_int]], ignore_index=True)

        if action_type == "remove_outlier_rows":
            col = payload["column"]
            pct = payload.get("pct", 5)
            numeric = pd.to_numeric(wc[col], errors="coerce")
            threshold = float(numeric.quantile(pct / 100))
            # Drop only rows strictly below the pct-th percentile. Missing values
            # are not outliers: a bare ``>=`` is False for NaN and would drop them.
            keep = (numeric >= threshold) | numeric.isna()
            return wc[keep].reset_index(drop=True)

        return wc

    def _cmd_reject(self, rec_id, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Acknowledge rejection of recommendation ``rec_id``.

        A valid ID (>= 1) earns a +0.01 process bonus; an invalid ID returns
        0 reward, matching ``_cmd_apply``, so ``reject 0`` cannot be farmed.
        """
        # NOTE(review): nothing here records the rejection (no registry or
        # anti-exploit state is updated); the "will not appear in future
        # queries" promise depends on code outside this class — confirm.
        if rec_id < 1:
            return self._make_obs("Error: invalid recommendation ID.", step, budget, 0.0)
        response = f"Recommendation {rec_id} rejected. It will not appear in future queries."
        reward = compute_total_reward(0.0, r_process + 0.01, 0.0) + trunc_pen
        return self._make_obs(response, step, budget, reward)

    def _cmd_undo(self, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Restore the most recent snapshot (max 3 levels deep).

        A successful undo costs 0.03 process reward; an undo with nothing to
        revert costs 0.05.
        """
        if self._dataset_history:
            self._working_copy = self._dataset_history.pop()
            self._evaluator.invalidate_cache()
            orig_len = self._metadata["original_length"]
            rows_pct = len(self._working_copy) / orig_len
            quality = compute_lightweight_score(
                self._working_copy, self._ground_truth,
                orig_len, self._metadata["col_meta"],
                initial_missing=self._metadata.get("initial_missing"),
            )
            response = (
                f"Undo successful. Reverted to previous dataset state.\n"
                f"Row count: {len(self._working_copy)}/{orig_len} ({rows_pct*100:.1f}% preserved)\n"
                f"Estimated quality: {quality:.4f}\n"
                f"Snapshots remaining: {len(self._dataset_history)}"
            )
            reward = compute_total_reward(0.0, r_process - 0.03, 0.0) + trunc_pen  # small cost
        else:
            response = "Nothing to undo. No previous state available."
            reward = compute_total_reward(0.0, r_process - 0.05, 0.0) + trunc_pen  # larger cost
        return self._make_obs(response, step, budget, reward)

    def _cmd_validate(self, step, budget, r_process, trunc_pen) -> DataCentricObservation:
        """Retrain/evaluate and reward the accuracy change since the last validate.

        Blocked (0 reward) while the cooldown is active. A cached evaluation
        pays no validate reward and is not counted by ``record_validate``.
        """
        allowed, cooldown_msg = check_validate_cooldown(self._exploit)
        if not allowed:
            return self._make_obs(cooldown_msg, step, budget, 0.0)

        prev_rf = self._evaluator.last_accuracy
        prev_lr = self._evaluator.last_lr_accuracy
        acc, per_class, from_cache, lr_acc = self._evaluator.evaluate(
            self._working_copy, self._ground_truth
        )
        cache_label = " (cached)" if from_cache else ""
        if from_cache:
            r_validate = 0.0
        else:
            r_validate = get_validate_reward(self._exploit)
            record_validate(self._exploit)

        r_accuracy = compute_accuracy_reward(
            acc, self._current_accuracy,
            self._metadata["baseline_accuracy"],
            TASK_CONFIGS[self._task]["target_accuracy"],
        )
        self._previous_accuracy = self._current_accuracy
        self._current_accuracy = acc

        target = TASK_CONFIGS[self._task]["target_accuracy"]
        agreement = self._evaluator.agreement_signal(acc, lr_acc, prev_rf, prev_lr)
        feat_text = self._evaluator.feature_importance_text()
        lines = [
            f"=== Validate{cache_label} ===",
            f"RF Accuracy: {acc:.4f} (primary)",
            f"LR Accuracy: {lr_acc:.4f} (secondary)",
            f"Agreement: {agreement}",
        ]
        for cls, metrics in per_class.items():
            if isinstance(metrics, dict):
                lines.append(
                    f" Class {cls}: p={metrics.get('precision',0):.3f} "
                    f"r={metrics.get('recall',0):.3f} f1={metrics.get('f1-score',0):.3f}"
                )
        lines.append(f"Target: {target:.4f} | {'HIT ✓' if acc >= target else 'Not yet'}")
        if feat_text:
            lines.append(feat_text)
        response = "\n".join(lines)
        reward = compute_total_reward(r_accuracy, r_process + r_validate, 0.0) + trunc_pen
        return self._make_obs(response, step, budget, reward)

    # ── submit ────────────────────────────────────────────────────────────────
    def _do_submit(self, penalty: float = 0.0, extra_msg: str = "") -> DataCentricObservation:
        """Finalize the episode: evaluate, compute the terminal reward, mark done.

        Args:
            penalty: Added to the total (negative for forced submits).
            extra_msg: Appended to the report, e.g. the reason for a forced end.

        Returns:
            A ``done=True`` observation whose reward is the rounded total.
        """
        cfg = TASK_CONFIGS[self._task]
        orig_len = self._metadata["original_length"]
        budget_remaining = cfg["budget"] - self._state.step_count

        # Final accuracy
        acc, per_class, _, lr_acc = self._evaluator.evaluate(
            self._working_copy, self._ground_truth
        )
        self._current_accuracy = acc
        # NOTE(review): the delta here is against ``_previous_accuracy`` (the
        # value before the last validate), whereas ``_cmd_validate`` uses
        # ``_current_accuracy``; confirm this is intended for is_submit=True.
        r_accuracy = compute_accuracy_reward(
            acc, self._previous_accuracy,
            cfg["baseline_accuracy"], cfg["target_accuracy"],
            is_submit=True,
            budget_used=self._state.step_count,
            budget_total=cfg["budget"],
        )
        r_process = compute_process_reward(self._action_history[:-1], "submit")
        r_preservation = compute_preservation_reward(len(self._working_copy), orig_len)
        r_efficiency = compute_efficiency_reward(
            acc, cfg["baseline_accuracy"], cfg["budget"], max(budget_remaining, 0),
            target_accuracy=cfg["target_accuracy"],
        )
        total = compute_total_reward(r_accuracy, r_process, r_preservation, r_efficiency)
        total += penalty

        hit = acc >= cfg["target_accuracy"]
        response = (
            f"{'=' * 40}\n"
            f"EPISODE COMPLETE\n"
            f"{'=' * 40}\n"
            f"Final accuracy: {acc:.4f}\n"
            f"Target accuracy: {cfg['target_accuracy']:.4f}\n"
            f"Baseline: {cfg['baseline_accuracy']:.4f}\n"
            f"Result: {'TARGET HIT ✓' if hit else 'Target not reached'}\n\n"
            f"Reward breakdown:\n"
            f" Accuracy: {r_accuracy:+.4f}\n"
            f" Process: {r_process:+.4f}\n"
            f" Preservation: {r_preservation:+.4f}\n"
            f" Efficiency: {r_efficiency:+.4f}\n"
            f" Penalty: {penalty:+.4f}\n"
            f" TOTAL: {total:+.4f}\n"
            + (f"\n{extra_msg}" if extra_msg else "")
        )
        # Same quality baseline (episode-start missing count) as every other step.
        quality = compute_lightweight_score(
            self._working_copy, self._ground_truth,
            orig_len, self._metadata["col_meta"],
            initial_missing=self._metadata.get("initial_missing"),
        )
        rows_pct = len(self._working_copy) / orig_len
        self._done = True
        return DataCentricObservation(
            response=response,
            current_accuracy=acc,
            baseline_accuracy=cfg["baseline_accuracy"],
            target_accuracy=cfg["target_accuracy"],
            estimated_quality=quality,
            dataset_shape=f"{len(self._working_copy)} rows × {len(self._working_copy.columns)-1} columns",
            rows_preserved_pct=rows_pct,
            budget_remaining=max(budget_remaining, 0),
            step_number=self._state.step_count,
            max_steps=cfg["budget"],
            active_session=self._active_session,
            validate_calls_remaining=validate_calls_remaining(self._exploit),
            done=True,
            reward=round(total, 4),
        )

    # ── helpers ───────────────────────────────────────────────────────────────
    def _make_obs(self, response: str, step: int, budget: int, reward: float,
                  quality: Optional[float] = None, rows_pct: Optional[float] = None,
                  done: bool = False) -> DataCentricObservation:
        """Build an observation from the current working copy.

        ``quality`` and ``rows_pct`` are computed when not supplied; ``budget``
        is clamped at 0 and ``reward`` rounded to 4 decimals.
        """
        cfg = TASK_CONFIGS[self._task]
        orig_len = self._metadata["original_length"]
        wc = self._working_copy
        if quality is None:
            quality = compute_lightweight_score(
                wc, self._ground_truth, orig_len, self._metadata["col_meta"],
                initial_missing=self._metadata.get("initial_missing"),
            )
        if rows_pct is None:
            rows_pct = len(wc) / orig_len
        return DataCentricObservation(
            response=response,
            current_accuracy=self._current_accuracy,
            baseline_accuracy=cfg["baseline_accuracy"],
            target_accuracy=cfg["target_accuracy"],
            estimated_quality=quality,
            dataset_shape=f"{len(wc)} rows × {len(wc.columns)-1} columns",
            rows_preserved_pct=rows_pct,
            budget_remaining=max(budget, 0),
            step_number=step,
            max_steps=cfg["budget"],
            active_session=self._active_session,
            validate_calls_remaining=validate_calls_remaining(self._exploit),
            done=done,
            reward=round(reward, 4),
        )

    def _error_obs(self, msg: str) -> DataCentricObservation:
        """Minimal observation for use before ``reset()`` has created any state."""
        return DataCentricObservation(response=msg, done=False, reward=0.0)

    def _unknown_cmd_obs(self, msg: str, step: int, budget: int,
                         reward: float) -> DataCentricObservation:
        """Echo an unrecognised command back together with the command list."""
        return self._make_obs(
            f"Unknown command: '{msg}'\n\n{AVAILABLE_COMMANDS}", step, budget, reward
        )

    def state(self) -> State:
        """Return the current OpenEnv ``State`` (episode id and step count)."""
        # NOTE(review): OpenEnv's ``Environment`` base may declare ``state`` as a
        # property; this override is a plain method. Confirm how callers access
        # it before changing either side.
        return self._state