"""Biological rule engine — hard and soft constraint checking. Hard constraints block action execution entirely. Soft constraints allow execution but degrade output quality and incur penalties. """ from __future__ import annotations from dataclasses import dataclass from enum import Enum from typing import List from models import ActionType, ExperimentAction from server.simulator.latent_state import FullLatentState class Severity(str, Enum): HARD = "hard" SOFT = "soft" @dataclass class RuleViolation: rule_id: str severity: Severity message: str class RuleEngine: """Evaluates biological and resource constraints against the current latent state before each action is applied. """ def check( self, action: ExperimentAction, state: FullLatentState ) -> List[RuleViolation]: violations: List[RuleViolation] = [] violations.extend(self._check_prerequisites(action, state)) violations.extend(self._check_resource_constraints(action, state)) violations.extend(self._check_redundancy(action, state)) violations.extend(self._check_causal_validity(action, state)) return violations def hard_violations(self, violations: List[RuleViolation]) -> List[str]: return [v.message for v in violations if v.severity == Severity.HARD] def soft_violations(self, violations: List[RuleViolation]) -> List[str]: return [v.message for v in violations if v.severity == Severity.SOFT] # ── prerequisite rules ────────────────────────────────────────────── def _check_prerequisites( self, action: ExperimentAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] at = action.action_type p = s.progress REQUIRES = { ActionType.PREPARE_LIBRARY: [ ("samples_collected", "Cannot prepare library without collected samples"), ], ActionType.SEQUENCE_CELLS: [ ("library_prepared", "Cannot sequence without library preparation"), ], ActionType.RUN_QC: [ ("cells_sequenced", "Cannot run QC before sequencing"), ], ActionType.FILTER_DATA: [ ("qc_performed", "Cannot filter data before QC"), ], ActionType.NORMALIZE_DATA: [ ("data_filtered", "Cannot normalise before filtering"), ], ActionType.INTEGRATE_BATCHES: [ ("data_normalized", "Cannot integrate batches before normalisation"), ], ActionType.CLUSTER_CELLS: [ ("data_normalized", "Cannot cluster before normalisation"), ], ActionType.DIFFERENTIAL_EXPRESSION: [ ("data_normalized", "Cannot run DE before normalisation"), ], ActionType.TRAJECTORY_ANALYSIS: [ ("data_normalized", "Cannot infer trajectories before normalisation"), ], ActionType.PATHWAY_ENRICHMENT: [ ("de_performed", "Cannot run pathway enrichment without DE results"), ], ActionType.REGULATORY_NETWORK_INFERENCE: [ ("data_normalized", "Cannot infer networks before normalisation"), ], ActionType.MARKER_SELECTION: [ ("de_performed", "Cannot select markers without DE results"), ], ActionType.VALIDATE_MARKER: [ ("markers_discovered", "Cannot validate markers before discovery"), ], ActionType.PERTURB_GENE: [ ("samples_collected", "Cannot perturb without samples"), ], ActionType.PERTURB_COMPOUND: [ ("samples_collected", "Cannot perturb without samples"), ], ActionType.CULTURE_CELLS: [ ("samples_collected", "Cannot culture without samples"), ], } for flag, msg in REQUIRES.get(at, []): if not getattr(p, flag, False): vs.append(RuleViolation( rule_id=f"prereq_{at.value}_{flag}", severity=Severity.HARD, message=msg, )) return vs # ── resource constraints ──────────────────────────────────────────── def _check_resource_constraints( self, action: ExperimentAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] if s.resources.budget_exhausted: vs.append(RuleViolation( rule_id="budget_exhausted", severity=Severity.HARD, message="Budget exhausted — no further actions possible", )) if s.resources.time_exhausted: vs.append(RuleViolation( rule_id="time_exhausted", severity=Severity.HARD, message="Time limit reached — no further actions possible", )) remaining = s.resources.budget_remaining from server.simulator.transition import ACTION_COSTS cost, _ = ACTION_COSTS.get(action.action_type, (0, 0)) if cost > remaining and remaining > 0: vs.append(RuleViolation( rule_id="budget_insufficient", severity=Severity.SOFT, message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains", )) return vs # ── redundancy checks ─────────────────────────────────────────────── def _check_redundancy( self, action: ExperimentAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] at = action.action_type p = s.progress REDUNDANT = { ActionType.COLLECT_SAMPLE: "samples_collected", ActionType.PREPARE_LIBRARY: "library_prepared", ActionType.SEQUENCE_CELLS: "cells_sequenced", ActionType.RUN_QC: "qc_performed", ActionType.FILTER_DATA: "data_filtered", ActionType.NORMALIZE_DATA: "data_normalized", } flag = REDUNDANT.get(at) if flag and getattr(p, flag, False): vs.append(RuleViolation( rule_id=f"redundant_{at.value}", severity=Severity.SOFT, message=f"Step '{at.value}' already completed — redundant action", )) return vs # ── causal validity ───────────────────────────────────────────────── def _check_causal_validity( self, action: ExperimentAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] if action.action_type == ActionType.SYNTHESIZE_CONCLUSION: if not s.progress.de_performed and not s.progress.cells_clustered: vs.append(RuleViolation( rule_id="premature_conclusion", severity=Severity.SOFT, message="Synthesising conclusion without substantive analysis", )) claims = action.parameters.get("claims", []) for claim in claims: if isinstance(claim, dict) and claim.get("claim_type") == "causal": if not s.progress.markers_validated and not s.progress.networks_inferred: vs.append(RuleViolation( rule_id="unsupported_causal_claim", severity=Severity.SOFT, message="Causal claim without validation or network evidence", )) break if action.action_type == ActionType.PATHWAY_ENRICHMENT: if not s.progress.de_performed: vs.append(RuleViolation( rule_id="pathway_without_de", severity=Severity.SOFT, message="Pathway enrichment without DE may yield unreliable results", )) return vs