# Bio-EnvRL/server/rules/engine.py
# Hosting-page residue kept for provenance: "Ev3Dev's picture",
# "Upload folder using huggingface_hub", commit df98fca (verified).
"""Biological rule engine β€” hard and soft constraint checking.
Hard constraints block action execution entirely.
Soft constraints allow execution but degrade output quality and incur penalties.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import List
from models import ActionType, ExperimentAction
from server.simulator.latent_state import FullLatentState
class Severity(str, Enum):
    """Severity level attached to a rule violation.

    Members are also ``str`` instances, so each one compares equal to its
    raw value (for example ``Severity.HARD == "hard"``).
    """

    # Hard constraints block action execution entirely.
    HARD = "hard"
    # Soft constraints allow execution but degrade output and incur penalties.
    SOFT = "soft"
@dataclass
class RuleViolation:
    """One rule breach reported for a proposed action."""

    # Identifier of the rule that fired, e.g. "budget_exhausted".
    rule_id: str
    # Whether the breach is a hard or a soft constraint.
    severity: Severity
    # Human-readable explanation of the breach.
    message: str
class RuleEngine:
    """Evaluate biological and resource constraints for a proposed action.

    ``check`` runs four rule families (prerequisites, resources, redundancy,
    causal validity) against the current latent state and returns every
    ``RuleViolation`` found.  ``hard_violations`` / ``soft_violations`` then
    split that list by severity and return only the messages.
    """

    def check(
        self, action: ExperimentAction, state: FullLatentState
    ) -> List[RuleViolation]:
        """Return all violations raised by ``action`` in ``state``.

        The rule families are evaluated in a fixed order: prerequisites,
        resource constraints, redundancy, causal validity.  Nothing is
        short-circuited; every family always runs.
        """
        violations: List[RuleViolation] = []
        violations.extend(self._check_prerequisites(action, state))
        violations.extend(self._check_resource_constraints(action, state))
        violations.extend(self._check_redundancy(action, state))
        violations.extend(self._check_causal_validity(action, state))
        return violations

    def hard_violations(self, violations: List[RuleViolation]) -> List[str]:
        """Return the messages of all ``Severity.HARD`` violations, in order."""
        return [v.message for v in violations if v.severity == Severity.HARD]

    def soft_violations(self, violations: List[RuleViolation]) -> List[str]:
        """Return the messages of all ``Severity.SOFT`` violations, in order."""
        return [v.message for v in violations if v.severity == Severity.SOFT]

    # ── prerequisite rules ──────────────────────────────────────────────

    def _check_prerequisites(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Emit a HARD violation for each unmet pipeline prerequisite.

        Each action type maps to a list of ``(progress_flag, message)``
        pairs; a violation is produced for every flag on ``s.progress``
        that is falsy.  Action types without an entry have no prerequisites.
        """
        vs: List[RuleViolation] = []
        at = action.action_type
        p = s.progress

        REQUIRES = {
            ActionType.PREPARE_LIBRARY: [
                ("samples_collected", "Cannot prepare library without collected samples"),
            ],
            ActionType.SEQUENCE_CELLS: [
                ("library_prepared", "Cannot sequence without library preparation"),
            ],
            ActionType.RUN_QC: [
                ("cells_sequenced", "Cannot run QC before sequencing"),
            ],
            ActionType.FILTER_DATA: [
                ("qc_performed", "Cannot filter data before QC"),
            ],
            ActionType.NORMALIZE_DATA: [
                ("data_filtered", "Cannot normalise before filtering"),
            ],
            ActionType.INTEGRATE_BATCHES: [
                ("data_normalized", "Cannot integrate batches before normalisation"),
            ],
            ActionType.CLUSTER_CELLS: [
                ("data_normalized", "Cannot cluster before normalisation"),
            ],
            ActionType.DIFFERENTIAL_EXPRESSION: [
                ("data_normalized", "Cannot run DE before normalisation"),
            ],
            ActionType.TRAJECTORY_ANALYSIS: [
                ("data_normalized", "Cannot infer trajectories before normalisation"),
            ],
            ActionType.PATHWAY_ENRICHMENT: [
                ("de_performed", "Cannot run pathway enrichment without DE results"),
            ],
            ActionType.REGULATORY_NETWORK_INFERENCE: [
                ("data_normalized", "Cannot infer networks before normalisation"),
            ],
            ActionType.MARKER_SELECTION: [
                ("de_performed", "Cannot select markers without DE results"),
            ],
            ActionType.VALIDATE_MARKER: [
                ("markers_discovered", "Cannot validate markers before discovery"),
            ],
            ActionType.PERTURB_GENE: [
                ("samples_collected", "Cannot perturb without samples"),
            ],
            ActionType.PERTURB_COMPOUND: [
                ("samples_collected", "Cannot perturb without samples"),
            ],
            ActionType.CULTURE_CELLS: [
                ("samples_collected", "Cannot culture without samples"),
            ],
        }

        for flag, msg in REQUIRES.get(at, []):
            # A flag that is absent from the progress object counts as unmet.
            if not getattr(p, flag, False):
                vs.append(RuleViolation(
                    rule_id=f"prereq_{at.value}_{flag}",
                    severity=Severity.HARD,
                    message=msg,
                ))
        return vs

    # ── resource constraints ────────────────────────────────────────────

    def _check_resource_constraints(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Check budget and time limits for ``action``.

        Exhausted budget or time yields a HARD violation.  An action whose
        cost exceeds a still-positive remaining budget yields only a SOFT
        violation.
        """
        vs: List[RuleViolation] = []

        if s.resources.budget_exhausted:
            vs.append(RuleViolation(
                rule_id="budget_exhausted",
                severity=Severity.HARD,
                message="Budget exhausted — no further actions possible",
            ))
        if s.resources.time_exhausted:
            vs.append(RuleViolation(
                rule_id="time_exhausted",
                severity=Severity.HARD,
                message="Time limit reached — no further actions possible",
            ))

        remaining = s.resources.budget_remaining
        # Imported at call time rather than module level; presumably this
        # avoids a circular import with the simulator package — TODO confirm.
        from server.simulator.transition import ACTION_COSTS

        # ACTION_COSTS values are 2-tuples; only the first element (the
        # monetary cost) is used here.  Unknown action types cost nothing.
        cost, _ = ACTION_COSTS.get(action.action_type, (0, 0))
        # ``remaining > 0`` keeps this from piling on when the budget is
        # already at or below zero.
        if cost > remaining and remaining > 0:
            vs.append(RuleViolation(
                rule_id="budget_insufficient",
                severity=Severity.SOFT,
                message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains",
            ))
        return vs

    # ── redundancy checks ───────────────────────────────────────────────

    def _check_redundancy(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Emit a SOFT violation when a one-off pipeline step is repeated.

        Only the action types listed in the table below are checked; each
        maps to the progress flag that marks the step as already done.
        """
        vs: List[RuleViolation] = []
        at = action.action_type
        p = s.progress

        REDUNDANT = {
            ActionType.COLLECT_SAMPLE: "samples_collected",
            ActionType.PREPARE_LIBRARY: "library_prepared",
            ActionType.SEQUENCE_CELLS: "cells_sequenced",
            ActionType.RUN_QC: "qc_performed",
            ActionType.FILTER_DATA: "data_filtered",
            ActionType.NORMALIZE_DATA: "data_normalized",
        }

        flag = REDUNDANT.get(at)
        if flag and getattr(p, flag, False):
            vs.append(RuleViolation(
                rule_id=f"redundant_{at.value}",
                severity=Severity.SOFT,
                message=f"Step '{at.value}' already completed — redundant action",
            ))
        return vs

    # ── causal validity ─────────────────────────────────────────────────

    def _check_causal_validity(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Emit SOFT violations for conclusions the analysis cannot support.

        Covers three cases: synthesising a conclusion with neither DE nor
        clustering done, making a causal claim with neither validated
        markers nor inferred networks, and running pathway enrichment
        without DE results.
        """
        vs: List[RuleViolation] = []
        at = action.action_type
        progress = s.progress

        if at == ActionType.SYNTHESIZE_CONCLUSION:
            if not progress.de_performed and not progress.cells_clustered:
                vs.append(RuleViolation(
                    rule_id="premature_conclusion",
                    severity=Severity.SOFT,
                    message="Synthesising conclusion without substantive analysis",
                ))

            # Treat a missing/None ``parameters`` mapping or an explicit
            # ``claims: None`` as "no claims" instead of raising.
            params = getattr(action, "parameters", None) or {}
            claims = params.get("claims") or []
            has_causal_claim = any(
                isinstance(claim, dict) and claim.get("claim_type") == "causal"
                for claim in claims
            )
            # Reported at most once, however many causal claims are present.
            if (
                has_causal_claim
                and not progress.markers_validated
                and not progress.networks_inferred
            ):
                vs.append(RuleViolation(
                    rule_id="unsupported_causal_claim",
                    severity=Severity.SOFT,
                    message="Causal claim without validation or network evidence",
                ))

        if at == ActionType.PATHWAY_ENRICHMENT:
            # NOTE: the prerequisite rules already raise a HARD violation for
            # this same condition; this SOFT entry is kept so the set of
            # reported violations stays unchanged.
            if not progress.de_performed:
                vs.append(RuleViolation(
                    rule_id="pathway_without_de",
                    severity=Severity.SOFT,
                    message="Pathway enrichment without DE may yield unreliable results",
                ))
        return vs