File size: 8,664 Bytes
df98fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""Biological rule engine β€” hard and soft constraint checking.



Hard constraints block action execution entirely.

Soft constraints allow execution but degrade output quality and incur penalties.

"""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import List

from models import ActionType, ExperimentAction

from server.simulator.latent_state import FullLatentState


class Severity(str, Enum):
    """How strongly a rule violation affects the requested action.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value.
    """

    # Hard constraints block action execution entirely.
    HARD = "hard"
    # Soft constraints allow execution but degrade output quality and incur
    # penalties.
    SOFT = "soft"


@dataclass
class RuleViolation:
    """A single rule breach reported by ``RuleEngine``."""

    # Identifier of the rule that fired, e.g. "budget_exhausted".
    rule_id: str
    # Whether the breach is a hard or a soft constraint.
    severity: Severity
    # Human-readable explanation of the breach.
    message: str


class RuleEngine:
    """Evaluate biological and resource constraints before an action is applied.

    ``check`` runs every rule family against the requested action and the
    current latent state and returns all violations found.  Each violation
    carries a ``Severity``; ``hard_violations`` and ``soft_violations`` reduce
    a result list to the message strings of one severity.
    """

    def check(
        self, action: ExperimentAction, state: FullLatentState
    ) -> List[RuleViolation]:
        """Return every rule violation ``action`` incurs in ``state``.

        The rule families run in a fixed order (prerequisites, resource
        constraints, redundancy, causal validity) and their results are
        concatenated in that order.  An empty list means no rule objected.
        """
        violations: List[RuleViolation] = []
        violations.extend(self._check_prerequisites(action, state))
        violations.extend(self._check_resource_constraints(action, state))
        violations.extend(self._check_redundancy(action, state))
        violations.extend(self._check_causal_validity(action, state))
        return violations

    def hard_violations(self, violations: List[RuleViolation]) -> List[str]:
        """Return the messages of all ``Severity.HARD`` entries, in order."""
        return [v.message for v in violations if v.severity == Severity.HARD]

    def soft_violations(self, violations: List[RuleViolation]) -> List[str]:
        """Return the messages of all ``Severity.SOFT`` entries, in order."""
        return [v.message for v in violations if v.severity == Severity.SOFT]

    # ── prerequisite rules ──────────────────────────────────────────────

    def _check_prerequisites(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Report HARD violations for pipeline steps attempted out of order.

        Each listed action type maps to ``(progress_flag, message)`` pairs.
        A pair produces a violation when the named attribute on
        ``s.progress`` is falsy; a missing attribute counts as not done.
        Action types absent from the table have no prerequisites.
        """
        vs: List[RuleViolation] = []
        at = action.action_type
        p = s.progress

        REQUIRES = {
            ActionType.PREPARE_LIBRARY: [
                ("samples_collected", "Cannot prepare library without collected samples"),
            ],
            ActionType.SEQUENCE_CELLS: [
                ("library_prepared", "Cannot sequence without library preparation"),
            ],
            ActionType.RUN_QC: [
                ("cells_sequenced", "Cannot run QC before sequencing"),
            ],
            ActionType.FILTER_DATA: [
                ("qc_performed", "Cannot filter data before QC"),
            ],
            ActionType.NORMALIZE_DATA: [
                ("data_filtered", "Cannot normalise before filtering"),
            ],
            ActionType.INTEGRATE_BATCHES: [
                ("data_normalized", "Cannot integrate batches before normalisation"),
            ],
            ActionType.CLUSTER_CELLS: [
                ("data_normalized", "Cannot cluster before normalisation"),
            ],
            ActionType.DIFFERENTIAL_EXPRESSION: [
                ("data_normalized", "Cannot run DE before normalisation"),
            ],
            ActionType.TRAJECTORY_ANALYSIS: [
                ("data_normalized", "Cannot infer trajectories before normalisation"),
            ],
            ActionType.PATHWAY_ENRICHMENT: [
                ("de_performed", "Cannot run pathway enrichment without DE results"),
            ],
            ActionType.REGULATORY_NETWORK_INFERENCE: [
                ("data_normalized", "Cannot infer networks before normalisation"),
            ],
            ActionType.MARKER_SELECTION: [
                ("de_performed", "Cannot select markers without DE results"),
            ],
            ActionType.VALIDATE_MARKER: [
                ("markers_discovered", "Cannot validate markers before discovery"),
            ],
            ActionType.PERTURB_GENE: [
                ("samples_collected", "Cannot perturb without samples"),
            ],
            ActionType.PERTURB_COMPOUND: [
                ("samples_collected", "Cannot perturb without samples"),
            ],
            ActionType.CULTURE_CELLS: [
                ("samples_collected", "Cannot culture without samples"),
            ],
        }

        for flag, msg in REQUIRES.get(at, []):
            if not getattr(p, flag, False):
                vs.append(RuleViolation(
                    rule_id=f"prereq_{at.value}_{flag}",
                    severity=Severity.HARD,
                    message=msg,
                ))
        return vs

    # ── resource constraints ────────────────────────────────────────────

    def _check_resource_constraints(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Report budget and time violations for ``action``.

        * HARD when ``s.resources.budget_exhausted`` is set.
        * HARD when ``s.resources.time_exhausted`` is set.
        * SOFT when the action's cost exceeds a still-positive remaining
          budget.
        """
        vs: List[RuleViolation] = []
        if s.resources.budget_exhausted:
            vs.append(RuleViolation(
                rule_id="budget_exhausted",
                severity=Severity.HARD,
                message="Budget exhausted — no further actions possible",
            ))
        if s.resources.time_exhausted:
            vs.append(RuleViolation(
                rule_id="time_exhausted",
                severity=Severity.HARD,
                message="Time limit reached — no further actions possible",
            ))

        remaining = s.resources.budget_remaining
        # NOTE(review): imported at call time, presumably to avoid a circular
        # import with the transition module; confirm before hoisting.
        from server.simulator.transition import ACTION_COSTS
        # Entries are 2-tuples; only the first element (monetary cost) is used
        # here.  Unknown action types are treated as free.
        cost, _ = ACTION_COSTS.get(action.action_type, (0, 0))
        # The ``remaining > 0`` guard keeps this warning from firing when no
        # budget is left at all.
        if cost > remaining and remaining > 0:
            vs.append(RuleViolation(
                rule_id="budget_insufficient",
                severity=Severity.SOFT,
                message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains",
            ))
        return vs

    # ── redundancy checks ───────────────────────────────────────────────

    def _check_redundancy(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Report a SOFT violation when a one-off pipeline step is repeated.

        Only the action types listed in the table are checked; a violation is
        produced when the matching attribute on ``s.progress`` is already
        truthy.
        """
        vs: List[RuleViolation] = []
        at = action.action_type
        p = s.progress

        REDUNDANT = {
            ActionType.COLLECT_SAMPLE: "samples_collected",
            ActionType.PREPARE_LIBRARY: "library_prepared",
            ActionType.SEQUENCE_CELLS: "cells_sequenced",
            ActionType.RUN_QC: "qc_performed",
            ActionType.FILTER_DATA: "data_filtered",
            ActionType.NORMALIZE_DATA: "data_normalized",
        }
        flag = REDUNDANT.get(at)
        if flag and getattr(p, flag, False):
            vs.append(RuleViolation(
                rule_id=f"redundant_{at.value}",
                severity=Severity.SOFT,
                message=f"Step '{at.value}' already completed — redundant action",
            ))
        return vs

    # ── causal validity ─────────────────────────────────────────────────

    def _check_causal_validity(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Report SOFT violations for conclusions the evidence does not support.

        For ``SYNTHESIZE_CONCLUSION``:

        * ``premature_conclusion`` when neither DE nor clustering has run.
        * ``unsupported_causal_claim`` (at most once) when any entry of the
          ``"claims"`` parameter is a dict with ``claim_type == "causal"``
          and neither marker validation nor network inference has run.

        For ``PATHWAY_ENRICHMENT``:

        * ``pathway_without_de`` when DE has not been performed.
        """
        vs: List[RuleViolation] = []
        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
            if not s.progress.de_performed and not s.progress.cells_clustered:
                vs.append(RuleViolation(
                    rule_id="premature_conclusion",
                    severity=Severity.SOFT,
                    message="Synthesising conclusion without substantive analysis",
                ))

            # Tolerate ``parameters`` being None and an explicit
            # ``"claims": None`` entry instead of raising on either.
            params = action.parameters or {}
            claims = params.get("claims") or []
            has_causal_claim = any(
                isinstance(claim, dict) and claim.get("claim_type") == "causal"
                for claim in claims
            )
            if (
                has_causal_claim
                and not s.progress.markers_validated
                and not s.progress.networks_inferred
            ):
                vs.append(RuleViolation(
                    rule_id="unsupported_causal_claim",
                    severity=Severity.SOFT,
                    message="Causal claim without validation or network evidence",
                ))

        if action.action_type == ActionType.PATHWAY_ENRICHMENT:
            # This fires under the same condition as the HARD ``de_performed``
            # prerequisite for pathway enrichment, so both are reported.
            if not s.progress.de_performed:
                vs.append(RuleViolation(
                    rule_id="pathway_without_de",
                    severity=Severity.SOFT,
                    message="Pathway enrichment without DE may yield unreliable results",
                ))
        return vs