"""Biological rule engine — hard and soft constraint checking.
Hard constraints block action execution entirely.
Soft constraints allow execution but degrade output quality and incur penalties.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import List
from models import ActionType, ExperimentAction
from server.simulator.latent_state import FullLatentState
class Severity(str, Enum):
    """How strongly a rule violation affects the action it was raised for.

    Because the enum mixes in ``str``, members compare equal to their
    lowercase string values (``Severity.HARD == "hard"``).
    """

    # Blocks the action from executing (see module docstring).
    HARD = "hard"
    # Action still executes, but with degraded quality / penalties.
    SOFT = "soft"
@dataclass
class RuleViolation:
    """One rule failure produced while checking an action.

    Attributes:
        rule_id: Identifier of the rule that fired, e.g.
            ``"budget_exhausted"`` or ``"prereq_<action>_<flag>"``.
        severity: ``Severity.HARD`` blocks execution; ``Severity.SOFT``
            lets the action run with a penalty.
        message: Human-readable explanation of the violation.
    """

    rule_id: str  # which rule fired
    severity: Severity  # hard (blocking) or soft (penalty only)
    message: str  # explanation surfaced to callers
class RuleEngine:
    """Evaluates biological and resource constraints against the current
    latent state before each action is applied.

    ``check`` collects violations from four rule families: prerequisites,
    resources, redundancy and causal validity. ``hard_violations`` and
    ``soft_violations`` split that result by severity and return only the
    messages.
    """

    # ── rule tables ───────────────────────────────────────────────────────
    # Pipeline prerequisites: action type -> ((progress flag, message), ...).
    # Every listed flag on ``state.progress`` must be truthy, otherwise a
    # HARD violation is emitted. The table is defined once on the class
    # instead of being rebuilt on every call.
    _PREREQUISITES = {
        ActionType.PREPARE_LIBRARY: (
            ("samples_collected", "Cannot prepare library without collected samples"),
        ),
        ActionType.SEQUENCE_CELLS: (
            ("library_prepared", "Cannot sequence without library preparation"),
        ),
        ActionType.RUN_QC: (
            ("cells_sequenced", "Cannot run QC before sequencing"),
        ),
        ActionType.FILTER_DATA: (
            ("qc_performed", "Cannot filter data before QC"),
        ),
        ActionType.NORMALIZE_DATA: (
            ("data_filtered", "Cannot normalise before filtering"),
        ),
        ActionType.INTEGRATE_BATCHES: (
            ("data_normalized", "Cannot integrate batches before normalisation"),
        ),
        ActionType.CLUSTER_CELLS: (
            ("data_normalized", "Cannot cluster before normalisation"),
        ),
        ActionType.DIFFERENTIAL_EXPRESSION: (
            ("data_normalized", "Cannot run DE before normalisation"),
        ),
        ActionType.TRAJECTORY_ANALYSIS: (
            ("data_normalized", "Cannot infer trajectories before normalisation"),
        ),
        ActionType.PATHWAY_ENRICHMENT: (
            ("de_performed", "Cannot run pathway enrichment without DE results"),
        ),
        ActionType.REGULATORY_NETWORK_INFERENCE: (
            ("data_normalized", "Cannot infer networks before normalisation"),
        ),
        ActionType.MARKER_SELECTION: (
            ("de_performed", "Cannot select markers without DE results"),
        ),
        ActionType.VALIDATE_MARKER: (
            ("markers_discovered", "Cannot validate markers before discovery"),
        ),
        ActionType.PERTURB_GENE: (
            ("samples_collected", "Cannot perturb without samples"),
        ),
        ActionType.PERTURB_COMPOUND: (
            ("samples_collected", "Cannot perturb without samples"),
        ),
        ActionType.CULTURE_CELLS: (
            ("samples_collected", "Cannot culture without samples"),
        ),
    }

    # One-shot pipeline steps: action type -> progress flag that records the
    # step as completed. Repeating a completed step is a SOFT violation.
    _ONE_SHOT_STEPS = {
        ActionType.COLLECT_SAMPLE: "samples_collected",
        ActionType.PREPARE_LIBRARY: "library_prepared",
        ActionType.SEQUENCE_CELLS: "cells_sequenced",
        ActionType.RUN_QC: "qc_performed",
        ActionType.FILTER_DATA: "data_filtered",
        ActionType.NORMALIZE_DATA: "data_normalized",
    }

    # ── public API ───────────────────────────────────────────────────────
    def check(
        self, action: ExperimentAction, state: FullLatentState
    ) -> List[RuleViolation]:
        """Run every rule family against ``action`` and return all violations.

        Args:
            action: The action about to be applied.
            state: The current latent state.

        Returns:
            Violations ordered by rule family (prerequisites, resources,
            redundancy, causal validity); empty when no rule fires.
        """
        violations: List[RuleViolation] = []
        violations.extend(self._check_prerequisites(action, state))
        violations.extend(self._check_resource_constraints(action, state))
        violations.extend(self._check_redundancy(action, state))
        violations.extend(self._check_causal_validity(action, state))
        return violations

    def hard_violations(self, violations: List[RuleViolation]) -> List[str]:
        """Return the messages of the HARD violations, in input order."""
        return [v.message for v in violations if v.severity == Severity.HARD]

    def soft_violations(self, violations: List[RuleViolation]) -> List[str]:
        """Return the messages of the SOFT violations, in input order."""
        return [v.message for v in violations if v.severity == Severity.SOFT]

    # ── prerequisite rules ───────────────────────────────────────────────
    def _check_prerequisites(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Return one HARD violation per unmet prerequisite of ``action``.

        A progress flag that does not exist on ``s.progress`` counts as unmet.
        Actions with no entry in ``_PREREQUISITES`` have no prerequisites.
        """
        at = action.action_type
        p = s.progress
        return [
            RuleViolation(
                rule_id=f"prereq_{at.value}_{flag}",
                severity=Severity.HARD,
                message=msg,
            )
            for flag, msg in self._PREREQUISITES.get(at, ())
            if not getattr(p, flag, False)
        ]

    # ── resource constraints ─────────────────────────────────────────────
    def _check_resource_constraints(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Check budget and time limits for ``action``.

        An exhausted budget or time limit is HARD. An action whose cost
        exceeds the (still positive) remaining budget is only SOFT.
        """
        vs: List[RuleViolation] = []
        if s.resources.budget_exhausted:
            vs.append(RuleViolation(
                rule_id="budget_exhausted",
                severity=Severity.HARD,
                message="Budget exhausted — no further actions possible",
            ))
        if s.resources.time_exhausted:
            vs.append(RuleViolation(
                rule_id="time_exhausted",
                severity=Severity.HARD,
                message="Time limit reached — no further actions possible",
            ))
        remaining = s.resources.budget_remaining
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with the transition module — confirm before hoisting.
        from server.simulator.transition import ACTION_COSTS
        # Unknown actions are treated as free; only the first tuple element
        # (the monetary cost) is used here.
        cost, _ = ACTION_COSTS.get(action.action_type, (0, 0))
        # Only warn while some budget remains; a non-positive remainder
        # produces no SOFT warning here.
        if 0 < remaining < cost:
            vs.append(RuleViolation(
                rule_id="budget_insufficient",
                severity=Severity.SOFT,
                message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains",
            ))
        return vs

    # ── redundancy checks ────────────────────────────────────────────────
    def _check_redundancy(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Return a SOFT violation when ``action`` repeats a completed one-shot step."""
        at = action.action_type
        flag = self._ONE_SHOT_STEPS.get(at)
        if not flag or not getattr(s.progress, flag, False):
            return []
        return [RuleViolation(
            rule_id=f"redundant_{at.value}",
            severity=Severity.SOFT,
            message=f"Step '{at.value}' already completed — redundant action",
        )]

    # ── causal validity ──────────────────────────────────────────────────
    def _check_causal_validity(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Flag conclusions and analyses that lack supporting evidence (SOFT).

        For SYNTHESIZE_CONCLUSION, this flags:

        * a conclusion drawn with neither DE nor clustering done;
        * any ``{"claim_type": "causal"}`` claim made with neither
          validated markers nor an inferred network (reported at most once).

        For PATHWAY_ENRICHMENT, it flags enrichment run without DE results.
        """
        vs: List[RuleViolation] = []
        progress = s.progress
        at = action.action_type

        if at == ActionType.SYNTHESIZE_CONCLUSION:
            if not progress.de_performed and not progress.cells_clustered:
                vs.append(RuleViolation(
                    rule_id="premature_conclusion",
                    severity=Severity.SOFT,
                    message="Synthesising conclusion without substantive analysis",
                ))
            # Tolerate missing parameters or an explicit ``"claims": None``,
            # either of which previously raised on iteration.
            claims = (action.parameters or {}).get("claims") or []
            has_causal_claim = any(
                isinstance(claim, dict) and claim.get("claim_type") == "causal"
                for claim in claims
            )
            if (
                has_causal_claim
                and not progress.markers_validated
                and not progress.networks_inferred
            ):
                vs.append(RuleViolation(
                    rule_id="unsupported_causal_claim",
                    severity=Severity.SOFT,
                    message="Causal claim without validation or network evidence",
                ))

        if at == ActionType.PATHWAY_ENRICHMENT:
            # NOTE(review): _check_prerequisites already emits a HARD
            # violation for this same condition, so this SOFT one always
            # accompanies it. Kept so soft_violations() output is unchanged.
            if not progress.de_performed:
                vs.append(RuleViolation(
                    rule_id="pathway_without_de",
                    severity=Severity.SOFT,
                    message="Pathway enrichment without DE may yield unreliable results",
                ))
        return vs