#!/usr/bin/env python3
"""
Quick validation script for the Medical Diagnostic Environment.

This script validates that the core environment works correctly without
requiring the server to be running or external dependencies beyond models.

Run with:
    python validate.py

The process exits with status 0 when every check passes and 1 otherwise.
"""

import sys
import traceback
from pathlib import Path
from typing import Dict, List, Optional

# Put this script's own directory on the import path so the sibling
# ``models`` module and ``server`` package resolve regardless of the
# current working directory.
sys.path.insert(0, str(Path(__file__).parent))

from models import DiagnosticAction, PatientObservation, ClinicalState  # noqa: E402
from server.environment import MedicalDiagnosticEnvironment  # noqa: E402
from server.medical_data import (  # noqa: E402
    PATIENT_CASES,
    calculate_question_reward,
    calculate_test_reward,
    calculate_diagnosis_accuracy,
)


def _describe_exception(exc: BaseException) -> str:
    """Return a non-empty, one-line description of *exc*.

    ``str(exc)`` is empty for a bare ``assert`` failure (and for exceptions
    such as ``StopIteration``), which would leave a FAIL line with no
    explanation. This always includes the exception type and, when a
    traceback is attached, the file/line that raised it.
    """
    summary = "".join(traceback.format_exception_only(type(exc), exc)).strip()
    frames = traceback.extract_tb(exc.__traceback__)
    if frames:
        last = frames[-1]
        location = f"{Path(last.filename).name}:{last.lineno}"
        if last.line:
            location += f": {last.line}"
        summary += f" (at {location})"
    return summary


class ValidationResult:
    """Result of a single validation check.

    Attributes:
        name: Human-readable name of the check.
        passed: Whether the check succeeded.
        error: Description of the failure, or ``None`` when there is none.
    """

    def __init__(self, name: str, passed: bool, error: Optional[str] = None):
        self.name = name
        self.passed = passed
        self.error = error

    def __str__(self) -> str:
        status = "PASS" if self.passed else "FAIL"
        msg = f"{status}: {self.name}"
        if self.error:
            msg += f"\n Error: {self.error}"
        return msg


def validate_imports() -> ValidationResult:
    """Check that all imports work.

    NOTE: the same names are imported at module level, which runs first; this
    check re-imports them so the report explicitly lists the import step.
    """
    try:
        from models import DiagnosticAction, PatientObservation, ClinicalState  # noqa: F401
        from server.environment import MedicalDiagnosticEnvironment  # noqa: F401
        from server.medical_data import (  # noqa: F401
            PATIENT_CASES,
            calculate_question_reward,
            calculate_test_reward,
            calculate_diagnosis_accuracy,
        )

        return ValidationResult("Imports", True)
    except Exception as e:
        return ValidationResult("Imports", False, _describe_exception(e))


def validate_model_creation() -> ValidationResult:
    """Check that a ``DiagnosticAction`` can be instantiated and read back."""
    try:
        action = DiagnosticAction(
            action_type="ask_question", question="Test question?"
        )
        assert action.action_type == "ask_question"
        assert action.question == "Test question?"
        return ValidationResult("Model Creation", True)
    except Exception as e:
        return ValidationResult("Model Creation", False, _describe_exception(e))


def validate_environment_init() -> ValidationResult:
    """Check that the environment initializes and exposes reset/step."""
    try:
        env = MedicalDiagnosticEnvironment()
        assert env is not None
        assert hasattr(env, "reset")
        assert hasattr(env, "step")
        return ValidationResult("Environment Initialization", True)
    except Exception as e:
        return ValidationResult(
            "Environment Initialization", False, _describe_exception(e)
        )


def validate_reset_all_difficulties() -> ValidationResult:
    """Check that reset works for the easy, medium and hard difficulties."""
    try:
        env = MedicalDiagnosticEnvironment()
        for difficulty in ["easy", "medium", "hard"]:
            obs = env.reset(difficulty=difficulty)
            assert obs is not None
            assert env.current_difficulty == difficulty
            assert env.current_case_id is not None
        return ValidationResult("Reset All Difficulties", True)
    except Exception as e:
        return ValidationResult(
            "Reset All Difficulties", False, _describe_exception(e)
        )


def validate_question_action() -> ValidationResult:
    """Check that asking a question yields a non-negative, non-terminal step."""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="ask_question", question="Does the patient have symptoms?"
        )
        result = env.step(action)
        assert result is not None
        assert result.reward >= 0
        assert result.done is False  # Should not end on question
        return ValidationResult("Question Action", True)
    except Exception as e:
        return ValidationResult("Question Action", False, _describe_exception(e))


def validate_test_action() -> ValidationResult:
    """Check that ordering a test yields a non-negative, non-terminal step."""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="order_test", test_name="Complete Blood Count"
        )
        result = env.step(action)
        assert result is not None
        assert result.reward >= 0
        assert result.done is False  # Should not end on test
        return ValidationResult("Test Action", True)
    except Exception as e:
        return ValidationResult("Test Action", False, _describe_exception(e))


def validate_diagnosis_action() -> ValidationResult:
    """Check that submitting a diagnosis yields a reward and ends the episode."""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="submit_diagnosis", diagnosis="Common Flu"
        )
        result = env.step(action)
        assert result is not None
        assert result.reward is not None
        assert result.done is True  # Should end on diagnosis
        return ValidationResult("Diagnosis Action", True)
    except Exception as e:
        return ValidationResult("Diagnosis Action", False, _describe_exception(e))


def validate_episode_summary() -> ValidationResult:
    """Check that an episode summary contains the expected keys."""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        action = DiagnosticAction(action_type="submit_diagnosis", diagnosis="Test")
        env.step(action)
        summary = env.get_episode_summary()
        assert summary is not None
        assert "case_id" in summary
        assert "difficulty" in summary
        assert "accuracy" in summary
        assert "total_reward" in summary
        assert "steps" in summary
        return ValidationResult("Episode Summary", True)
    except Exception as e:
        return ValidationResult("Episode Summary", False, _describe_exception(e))


def validate_reward_functions() -> ValidationResult:
    """Check that the reward functions return floats in [0.0, 1.0]."""
    try:
        # next() on an empty mapping raises a message-less StopIteration;
        # fail with an explicit reason instead.
        assert PATIENT_CASES, "PATIENT_CASES is empty"
        case_id = next(iter(PATIENT_CASES))

        q_reward = calculate_question_reward(case_id, "Test question?")
        assert isinstance(q_reward, float)
        assert 0.0 <= q_reward <= 1.0

        t_reward = calculate_test_reward(case_id, "CBC")
        assert isinstance(t_reward, float)
        assert 0.0 <= t_reward <= 1.0

        true_diag = PATIENT_CASES[case_id].get("true_diagnosis", "")
        d_accuracy = calculate_diagnosis_accuracy(case_id, true_diag)
        assert isinstance(d_accuracy, float)
        assert 0.0 <= d_accuracy <= 1.0

        return ValidationResult("Reward Functions", True)
    except Exception as e:
        return ValidationResult("Reward Functions", False, _describe_exception(e))


def validate_state_property() -> ValidationResult:
    """Check that the ``state`` property exposes the expected attributes."""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        state = env.state
        assert state is not None
        assert hasattr(state, "patient_id")
        assert hasattr(state, "step_count")
        assert hasattr(state, "true_diagnosis")
        assert hasattr(state, "final_accuracy")
        return ValidationResult("State Property", True)
    except Exception as e:
        return ValidationResult("State Property", False, _describe_exception(e))


def validate_concurrent_support() -> ValidationResult:
    """Check that the environment declares concurrent-session support."""
    try:
        env = MedicalDiagnosticEnvironment()
        assert hasattr(env, "SUPPORTS_CONCURRENT_SESSIONS")
        assert env.SUPPORTS_CONCURRENT_SESSIONS is True
        return ValidationResult("Concurrent Sessions Support", True)
    except Exception as e:
        return ValidationResult(
            "Concurrent Sessions Support", False, _describe_exception(e)
        )


def main() -> int:
    """Run all validation checks, print a report, and return an exit code.

    Returns:
        0 if every check passed, otherwise 1.
    """
    print("=" * 70)
    print("MEDICAL DIAGNOSTIC ENVIRONMENT - VALIDATION SUITE")
    print("=" * 70)
    print()

    validators = [
        validate_imports,
        validate_model_creation,
        validate_environment_init,
        validate_reset_all_difficulties,
        validate_question_action,
        validate_test_action,
        validate_diagnosis_action,
        validate_episode_summary,
        validate_reward_functions,
        validate_state_property,
        validate_concurrent_support,
    ]

    results: List[ValidationResult] = []
    for validator in validators:
        try:
            result = validator()
        except Exception:
            # Validators catch their own errors; this is a last-resort guard
            # so one misbehaving check cannot abort the whole suite.
            result = ValidationResult(
                validator.__name__, False, traceback.format_exc()
            )
        results.append(result)
        print(result)
        print()

    print("=" * 70)
    passed = sum(1 for r in results if r.passed)
    total = len(results)
    print(f"SUMMARY: {passed}/{total} checks passed")
    print("=" * 70)

    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())