# synapse-openenv / validate.py
# Last updated by vicky0406 in commit 6ee6823 ("Update validate.py")
#!/usr/bin/env python3
"""
Quick validation script for Medical Diagnostic Environment
This script validates that the core environment works correctly without
requiring the server to be running or external dependencies beyond models.
Run with: python validate.py
"""
import sys
import traceback
from pathlib import Path
from typing import Dict, List, Optional
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent))
from models import DiagnosticAction, PatientObservation, ClinicalState
from server.environment import MedicalDiagnosticEnvironment
from server.medical_data import (
PATIENT_CASES,
calculate_question_reward,
calculate_test_reward,
calculate_diagnosis_accuracy,
)
class ValidationResult:
    """Outcome of a single validation check.

    Attributes are documented here rather than as bare strings:
        name: Human-readable name of the check.
        passed: True if the check succeeded.
        error: Failure detail, or None when there is nothing to report.
    """

    def __init__(self, name: str, passed: bool, error: Optional[str] = None):
        self.name = name
        self.passed = passed
        self.error = error

    def __str__(self) -> str:
        """Render as ``PASS: name`` / ``FAIL: name`` plus any error detail."""
        status = "PASS" if self.passed else "FAIL"
        msg = f"{status}: {self.name}"
        if self.error:
            msg += f"\n Error: {self.error}"
        elif not self.passed:
            # A bare ``assert`` failure stringifies to "", which would
            # otherwise leave a FAIL line with no explanation at all.
            msg += "\n Error: (no details)"
        return msg

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"passed={self.passed!r}, error={self.error!r})"
        )
def validate_imports() -> ValidationResult:
    """Check that the project modules and every name this script uses import.

    The same names are also imported at module level, so a missing module
    would normally abort the script before this check runs; re-importing
    here keeps the result visible in the report.
    """
    try:
        from models import (  # noqa: F401
            DiagnosticAction,
            PatientObservation,
            ClinicalState,
        )
        from server.environment import MedicalDiagnosticEnvironment  # noqa: F401
        from server.medical_data import (  # noqa: F401
            PATIENT_CASES,
            calculate_question_reward,
            calculate_test_reward,
            calculate_diagnosis_accuracy,
        )
        return ValidationResult("Imports", True)
    except Exception as e:
        return ValidationResult("Imports", False, f"{type(e).__name__}: {e}")
def validate_model_creation() -> ValidationResult:
    """Check that a ``DiagnosticAction`` can be built and keeps its fields."""
    try:
        action = DiagnosticAction(
            action_type="ask_question",
            question="Test question?",
        )
        assert action.action_type == "ask_question", (
            f"action_type round-trip failed: got {action.action_type!r}"
        )
        assert action.question == "Test question?", (
            f"question round-trip failed: got {action.question!r}"
        )
        return ValidationResult("Model Creation", True)
    except Exception as e:
        # Include the exception type: a bare AssertionError stringifies to "".
        return ValidationResult(
            "Model Creation", False, f"{type(e).__name__}: {e}"
        )
def validate_environment_init() -> ValidationResult:
    """Check that the environment constructs and exposes callable reset/step."""
    try:
        env = MedicalDiagnosticEnvironment()
        assert env is not None, "constructor returned None"
        for method_name in ("reset", "step"):
            assert callable(getattr(env, method_name, None)), (
                f"environment has no callable {method_name!r}"
            )
        return ValidationResult("Environment Initialization", True)
    except Exception as e:
        return ValidationResult(
            "Environment Initialization", False, f"{type(e).__name__}: {e}"
        )
def validate_reset_all_difficulties() -> ValidationResult:
    """Check that ``reset`` works for each of "easy", "medium" and "hard".

    One environment instance is reset once per difficulty, so later
    iterations also exercise resetting an environment that was already reset.
    """
    try:
        env = MedicalDiagnosticEnvironment()
        for difficulty in ("easy", "medium", "hard"):
            obs = env.reset(difficulty=difficulty)
            assert obs is not None, f"reset({difficulty!r}) returned None"
            assert env.current_difficulty == difficulty, (
                f"reset({difficulty!r}) left current_difficulty="
                f"{env.current_difficulty!r}"
            )
            assert env.current_case_id is not None, (
                f"reset({difficulty!r}) did not set current_case_id"
            )
        return ValidationResult("Reset All Difficulties", True)
    except Exception as e:
        return ValidationResult(
            "Reset All Difficulties", False, f"{type(e).__name__}: {e}"
        )
def validate_question_action() -> ValidationResult:
    """Check that an ``ask_question`` step yields a reward >= 0 and no episode end."""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="ask_question",
            question="Does the patient have symptoms?",
        )
        result = env.step(action)
        assert result is not None, "step() returned None"
        assert result.reward >= 0, (
            f"expected reward >= 0, got {result.reward!r}"
        )
        # Asking a question must not end the episode.
        assert result.done is False, (
            f"expected done=False after a question, got {result.done!r}"
        )
        return ValidationResult("Question Action", True)
    except Exception as e:
        return ValidationResult(
            "Question Action", False, f"{type(e).__name__}: {e}"
        )
def validate_test_action() -> ValidationResult:
    """Check that an ``order_test`` step yields a reward >= 0 and no episode end."""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="order_test",
            test_name="Complete Blood Count",
        )
        result = env.step(action)
        assert result is not None, "step() returned None"
        assert result.reward >= 0, (
            f"expected reward >= 0, got {result.reward!r}"
        )
        # Ordering a test must not end the episode.
        assert result.done is False, (
            f"expected done=False after a test, got {result.done!r}"
        )
        return ValidationResult("Test Action", True)
    except Exception as e:
        return ValidationResult(
            "Test Action", False, f"{type(e).__name__}: {e}"
        )
def validate_diagnosis_action() -> ValidationResult:
    """Check that ``submit_diagnosis`` returns a reward and ends the episode."""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="submit_diagnosis",
            diagnosis="Common Flu",
        )
        result = env.step(action)
        assert result is not None, "step() returned None"
        assert result.reward is not None, "diagnosis step returned reward=None"
        # Submitting a diagnosis is the terminal action.
        assert result.done is True, (
            f"expected done=True after a diagnosis, got {result.done!r}"
        )
        return ValidationResult("Diagnosis Action", True)
    except Exception as e:
        return ValidationResult(
            "Diagnosis Action", False, f"{type(e).__name__}: {e}"
        )
def validate_episode_summary() -> ValidationResult:
    """Check that a finished episode's summary contains the expected keys.

    The episode is ended by submitting a placeholder diagnosis; the summary
    must then expose case_id, difficulty, accuracy, total_reward and steps.
    """
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="submit_diagnosis",
            diagnosis="Test",
        )
        env.step(action)
        summary = env.get_episode_summary()
        assert summary is not None, "get_episode_summary() returned None"
        required = ("case_id", "difficulty", "accuracy", "total_reward", "steps")
        # Report every missing key at once instead of stopping at the first.
        missing = [key for key in required if key not in summary]
        assert not missing, f"summary is missing keys: {missing}"
        return ValidationResult("Episode Summary", True)
    except Exception as e:
        return ValidationResult(
            "Episode Summary", False, f"{type(e).__name__}: {e}"
        )
def validate_reward_functions() -> ValidationResult:
    """Check that the reward helpers return floats in [0, 1] for every case.

    Each entry of ``PATIENT_CASES`` is scored with a generic question, the
    ``"CBC"`` test, and the case's own ``true_diagnosis`` (empty string when
    the key is absent).
    """
    try:
        # An empty table would otherwise make the loop below pass vacuously.
        assert PATIENT_CASES, "PATIENT_CASES is empty"
        for case_id, case in PATIENT_CASES.items():
            scores = (
                ("calculate_question_reward",
                 calculate_question_reward(case_id, "Test question?")),
                ("calculate_test_reward",
                 calculate_test_reward(case_id, "CBC")),
                ("calculate_diagnosis_accuracy",
                 calculate_diagnosis_accuracy(
                     case_id, case.get("true_diagnosis", "")
                 )),
            )
            for fn_name, value in scores:
                assert isinstance(value, float), (
                    f"{fn_name} for case {case_id!r} returned "
                    f"{type(value).__name__}, expected float"
                )
                assert 0.0 <= value <= 1.0, (
                    f"{fn_name} for case {case_id!r} returned {value!r}, "
                    f"outside [0, 1]"
                )
        return ValidationResult("Reward Functions", True)
    except Exception as e:
        return ValidationResult(
            "Reward Functions", False, f"{type(e).__name__}: {e}"
        )
def validate_state_property() -> ValidationResult:
    """Check that ``env.state`` exists after reset and has the expected fields."""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        state = env.state
        assert state is not None, "env.state is None after reset()"
        expected = ("patient_id", "step_count", "true_diagnosis", "final_accuracy")
        missing = [attr for attr in expected if not hasattr(state, attr)]
        assert not missing, f"state is missing attributes: {missing}"
        return ValidationResult("State Property", True)
    except Exception as e:
        return ValidationResult(
            "State Property", False, f"{type(e).__name__}: {e}"
        )
def validate_concurrent_support() -> ValidationResult:
    """Check that the environment declares ``SUPPORTS_CONCURRENT_SESSIONS = True``.

    Only the declared flag is inspected; concurrent behaviour itself is not
    exercised here.
    """
    try:
        env = MedicalDiagnosticEnvironment()
        assert hasattr(env, "SUPPORTS_CONCURRENT_SESSIONS"), (
            "SUPPORTS_CONCURRENT_SESSIONS is not defined"
        )
        assert env.SUPPORTS_CONCURRENT_SESSIONS is True, (
            "expected SUPPORTS_CONCURRENT_SESSIONS is True, got "
            f"{env.SUPPORTS_CONCURRENT_SESSIONS!r}"
        )
        return ValidationResult("Concurrent Sessions Support", True)
    except Exception as e:
        return ValidationResult(
            "Concurrent Sessions Support", False, f"{type(e).__name__}: {e}"
        )
def main() -> int:
    """Run every validation check, print a report, and return an exit code.

    Returns:
        0 if every check passed, otherwise 1. Also returns 1, without running
        any check, when Python runs with ``-O``: optimization mode strips
        the ``assert`` statements the checks rely on, so they would all
        "pass" vacuously.
    """
    print("=" * 70)
    print("MEDICAL DIAGNOSTIC ENVIRONMENT - VALIDATION SUITE")
    print("=" * 70)
    print()

    if not __debug__:
        print("ERROR: assertions are disabled (python -O / PYTHONOPTIMIZE),")
        print("so the checks cannot detect failures. Re-run without -O.")
        print("=" * 70)
        return 1

    validators = [
        validate_imports,
        validate_model_creation,
        validate_environment_init,
        validate_reset_all_difficulties,
        validate_question_action,
        validate_test_action,
        validate_diagnosis_action,
        validate_episode_summary,
        validate_reward_functions,
        validate_state_property,
        validate_concurrent_support,
    ]
    results: List[ValidationResult] = []
    for validator in validators:
        try:
            result = validator()
        except Exception:
            # Validators handle their own errors; this last-resort guard keeps
            # one misbehaving check from aborting the whole suite.
            result = ValidationResult(
                validator.__name__,
                False,
                traceback.format_exc(),
            )
        results.append(result)
        print(result)
        print()

    print("=" * 70)
    passed = sum(1 for r in results if r.passed)
    total = len(results)
    print(f"SUMMARY: {passed}/{total} checks passed")
    print("=" * 70)
    return 0 if passed == total else 1
if __name__ == "__main__":
    # Surface the suite's pass/fail status as the process exit code.
    exit_code = main()
    sys.exit(exit_code)