| """Procedural scenario generator.
|
|
|
| Composes biologically coherent ``Scenario`` objects from the curated
|
| palette in ``bio_palette``, producing fully populated
|
| ``LatentBiologicalState`` instances that drive every simulator tool
|
| (clustering, DE, pathway enrichment, trajectory, regulatory networks,
|
| marker validation) with realistic intermediate outputs.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import logging
|
| from typing import Any, Dict, List, Optional, Tuple
|
|
|
| import numpy as np
|
|
|
| from models import TaskSpec
|
|
|
| from server.simulator.latent_state import (
|
| CellPopulation,
|
| LatentBiologicalState,
|
| TechnicalState,
|
| )
|
|
|
| from .bio_palette import (
|
| DISEASE_PROFILES,
|
| HIDDEN_FAILURE_TEMPLATES,
|
| PATHWAY_LIBRARY,
|
| PERTURBATION_TEMPLATES,
|
| REGULATORY_TEMPLATES,
|
| TISSUE_CELL_TYPES,
|
| TRAJECTORY_TEMPLATES,
|
| CellTypeTemplate,
|
| DiseaseProfile,
|
| )
|
| from .scenarios import Scenario
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Scenario families ``generate_scenario`` draws from when the caller does not
# pass an explicit ``scenario_type``.
SCENARIO_TYPES = ("de", "trajectory", "perturbation", "biomarker")

# Sampling configuration keyed by difficulty level.
#
# Two-element tuples are (low, high) ranges that the builder functions in this
# module unpack into ``rng.uniform`` / ``rng.integers``.  The ``include_*``
# booleans gate whether ``generate_scenario`` may add the corresponding
# optional component (trajectory, perturbation, regulatory network, hidden
# failure conditions) for scenario types that do not force it.
_DIFFICULTY_PARAMS = {
    "easy": {
        # Population count and DE effect-size scaling.
        "n_pops": (4, 5),
        "de_scale": (1.2, 1.6),
        # Technical-noise ranges.
        "noise_dropout": (0.05, 0.10),
        "noise_doublet": (0.03, 0.06),
        "noise_ambient": (0.02, 0.05),
        "noise_batch_strength": (0.05, 0.12),
        "n_batches": (1, 2),
        # Task resource limits.
        "budget_range": (70_000, 100_000),
        "time_range": (100, 150),
        "sample_quality": (0.85, 0.95),
        # Optional components.
        "include_trajectory": False,
        "include_perturbation": False,
        "include_network": False,
        "include_failure_conditions": False,
    },
    "medium": {
        # Population count and DE effect-size scaling.
        "n_pops": (5, 7),
        "de_scale": (0.9, 1.3),
        # Technical-noise ranges.
        "noise_dropout": (0.08, 0.14),
        "noise_doublet": (0.04, 0.08),
        "noise_ambient": (0.03, 0.07),
        "noise_batch_strength": (0.08, 0.18),
        "n_batches": (1, 3),
        # Task resource limits.
        "budget_range": (80_000, 120_000),
        "time_range": (120, 180),
        "sample_quality": (0.78, 0.92),
        # Optional components.
        "include_trajectory": True,
        "include_perturbation": False,
        "include_network": True,
        "include_failure_conditions": False,
    },
    "hard": {
        # Population count and DE effect-size scaling.
        "n_pops": (6, 8),
        "de_scale": (0.6, 1.0),
        # Technical-noise ranges.
        "noise_dropout": (0.10, 0.20),
        "noise_doublet": (0.06, 0.12),
        "noise_ambient": (0.05, 0.10),
        "noise_batch_strength": (0.12, 0.25),
        "n_batches": (2, 4),
        # Task resource limits.
        "budget_range": (90_000, 140_000),
        "time_range": (140, 200),
        "sample_quality": (0.65, 0.85),
        # Optional components.
        "include_trajectory": True,
        "include_perturbation": True,
        "include_network": True,
        "include_failure_conditions": True,
    },
}
|
|
|
|
|
def generate_scenario(
    seed: int,
    difficulty: str = "medium",
    scenario_type: Optional[str] = None,
) -> Scenario:
    """Generate a single procedural scenario with complete latent state.

    All randomness comes from one ``numpy`` generator seeded with ``seed``,
    and the draws happen in a fixed order, so the same arguments always yield
    the same scenario.

    Parameters
    ----------
    seed
        RNG seed for reproducibility.
    difficulty
        One of ``"easy"``, ``"medium"``, ``"hard"``.
    scenario_type
        One of ``"de"``, ``"trajectory"``, ``"perturbation"``,
        ``"biomarker"``, or ``None`` for random selection.  The value is not
        validated here.

    Returns
    -------
    Scenario
        Scenario bundling the task, latent biology, technical state and any
        hidden failure conditions.

    Raises
    ------
    KeyError
        If ``difficulty`` is not a key of ``_DIFFICULTY_PARAMS``.
    """
    rng = np.random.default_rng(seed)
    params = _DIFFICULTY_PARAMS[difficulty]

    if scenario_type is None:
        # ``Generator.choice`` hands back a NumPy string scalar; coerce it so
        # the name, tags and task carry a plain ``str`` regardless of whether
        # the type was supplied by the caller or drawn at random.
        scenario_type = str(rng.choice(SCENARIO_TYPES))

    disease_key = rng.choice(list(DISEASE_PROFILES.keys()))
    disease = DISEASE_PROFILES[disease_key]
    tissue = disease.tissue

    cell_templates = TISSUE_CELL_TYPES.get(tissue, [])
    if not cell_templates:
        # The disease's tissue has no cell-type templates: fall back to a
        # randomly chosen tissue that does (again coerced to plain ``str``).
        tissue = str(rng.choice(list(TISSUE_CELL_TYPES.keys())))
        cell_templates = TISSUE_CELL_TYPES[tissue]

    # Core biology.  The order of these calls fixes the RNG draw order.
    populations = _sample_populations(rng, cell_templates, disease, params)
    de_genes = _build_de_genes(rng, disease, params)
    pathways = _build_pathways(rng, disease)
    # NOTE: ``_derive_markers`` also inserts missing marker genes into
    # ``de_genes`` in place.
    markers = _derive_markers(rng, de_genes, disease)
    mechanisms = list(disease.mechanism_templates)
    n_cells = int(rng.integers(8_000, 22_000))

    # Optional components: forced by the scenario type, otherwise added with
    # a fixed probability when the difficulty level allows them.
    trajectory = None
    if scenario_type == "trajectory" or (
        params["include_trajectory"] and rng.random() < 0.4
    ):
        trajectory = _build_trajectory(rng, tissue, populations)

    reg_network: Dict[str, List[str]] = {}
    if scenario_type == "trajectory" or (
        params["include_network"] and rng.random() < 0.5
    ):
        reg_network = _build_regulatory_network(rng, tissue, populations)

    perturbation_effects: Dict[str, Dict[str, float]] = {}
    if scenario_type == "perturbation" or (
        params["include_perturbation"] and rng.random() < 0.5
    ):
        perturbation_effects = _build_perturbation(rng, disease)

    technical = _build_technical(rng, params)

    hidden_failures: List[str] = []
    if params["include_failure_conditions"] and rng.random() < 0.6:
        # One or two distinct failure templates (upper bound is exclusive).
        n_failures = int(rng.integers(1, 3))
        indices = rng.choice(
            len(HIDDEN_FAILURE_TEMPLATES),
            size=min(n_failures, len(HIDDEN_FAILURE_TEMPLATES)),
            replace=False,
        )
        hidden_failures = [HIDDEN_FAILURE_TEMPLATES[i] for i in indices]

    task = _build_task(
        rng, disease, tissue, scenario_type, params, perturbation_effects
    )

    biology = LatentBiologicalState(
        cell_populations=populations,
        true_de_genes=de_genes,
        true_pathways=pathways,
        true_trajectory=trajectory,
        true_regulatory_network=reg_network,
        perturbation_effects=perturbation_effects,
        true_markers=markers,
        causal_mechanisms=mechanisms,
        n_true_cells=n_cells,
    )

    name = f"proc_{disease.name}_{scenario_type}_{seed}"
    tags = [scenario_type, "scRNA-seq", tissue, disease.name, difficulty]

    return Scenario(
        name=name,
        task=task,
        biology=biology,
        technical=technical,
        hidden_failure_conditions=hidden_failures,
        difficulty=difficulty,
        tags=tags,
    )
|
|
|
|
|
def generate_procedural_scenarios(
    n: int = 20,
    seed: int = 42,
) -> List[Scenario]:
    """Pre-generate a pool of procedural scenarios across difficulties.

    Difficulty cycles easy -> medium -> hard by position in the pool, and
    each scenario gets its own child seed drawn from a generator seeded with
    ``seed``.  The scenario type is always left to random selection.

    Parameters
    ----------
    n
        Number of scenarios to generate.
    seed
        Seed of the generator that produces the per-scenario child seeds.

    Returns
    -------
    List[Scenario]
        The generated scenarios, in generation order.
    """
    levels = ("easy", "medium", "hard")
    seed_stream = np.random.default_rng(seed)

    pool: List[Scenario] = [
        generate_scenario(
            seed=int(seed_stream.integers(0, 2**31)),
            difficulty=levels[position % len(levels)],
            scenario_type=None,
        )
        for position in range(n)
    ]

    logger.info("Generated %d procedural scenarios.", len(pool))
    return pool
|
|
|
|
|
|
|
|
|
|
|
def _sample_populations(
    rng: np.random.Generator,
    templates: List[CellTypeTemplate],
    disease: DiseaseProfile,
    params: dict,
) -> List[CellPopulation]:
    """Sample cell populations from the tissue's cell-type templates.

    Draws a population count from the inclusive ``params["n_pops"]`` range
    (capped at the number of templates), picks that many distinct templates,
    and instantiates each with a sampled proportion and state.  Templates
    flagged as disease-responsive and listed in
    ``disease.responding_cell_types`` also get a sampled response for the
    disease condition.  Proportions are finally normalised by their total and
    rounded to four decimals.
    """
    low, high = params["n_pops"]
    count = min(int(rng.integers(low, high + 1)), len(templates))

    # Distinct template positions, visited in ascending order.
    picked = rng.choice(len(templates), size=count, replace=False)
    responders = set(disease.responding_cell_types)

    result: List[CellPopulation] = []
    for position in sorted(picked):
        template = templates[position]
        share = float(rng.uniform(*template.proportion_range))
        chosen_state = rng.choice(template.states)

        if template.disease_responsive and template.name in responders:
            response: Dict[str, float] = {
                disease.condition_name: float(
                    rng.uniform(*template.response_range)
                ),
            }
        else:
            response = {}

        result.append(
            CellPopulation(
                name=template.name,
                proportion=share,
                marker_genes=list(template.marker_genes),
                state=chosen_state,
                condition_response=response,
            )
        )

    denominator = sum(pop.proportion for pop in result)
    if denominator > 0:
        for pop in result:
            pop.proportion = round(pop.proportion / denominator, 4)

    return result
|
|
|
|
|
def _build_de_genes(
    rng: np.random.Generator,
    disease: DiseaseProfile,
    params: dict,
) -> Dict[str, Dict[str, float]]:
    """Sample true differential-expression effects for the disease.

    For every gene in ``disease.de_genes`` a base effect is drawn uniformly
    from the gene's ``(lo, hi)`` range and multiplied by a scale drawn
    uniformly from ``params["de_scale"]``; the product is rounded to three
    decimals.  The sign of the base effect is carried through unchanged.

    Returns
    -------
    Dict[str, Dict[str, float]]
        A single-entry mapping ``{"<condition_name>_vs_healthy": effects}``.
    """
    comparison = f"{disease.condition_name}_vs_healthy"
    scale_lo, scale_hi = params["de_scale"]

    effects: Dict[str, float] = {}
    for gene, (lo, hi) in disease.de_genes.items():
        # Draw order (base, then scale) is part of the seeded contract.
        base = float(rng.uniform(lo, hi))
        scale = float(rng.uniform(scale_lo, scale_hi))
        effects[gene] = round(base * scale, 3)

    return {comparison: effects}
|
|
|
|
|
def _build_pathways(
    rng: np.random.Generator,
    disease: DiseaseProfile,
) -> Dict[str, float]:
    """Sample one value per disease pathway.

    Each entry of ``disease.pathways`` maps a pathway name to a ``(lo, hi)``
    range; the result maps the same names, in the same order, to a value
    drawn uniformly from that range and rounded to three decimals.
    """
    return {
        pathway_name: round(float(rng.uniform(low, high)), 3)
        for pathway_name, (low, high) in disease.pathways.items()
    }
|
|
|
|
|
def _derive_markers(
    rng: np.random.Generator,
    de_genes: Dict[str, Dict[str, float]],
    disease: DiseaseProfile,
) -> List[str]:
    """Select the scenario's true markers and back-fill them into the DE data.

    Any gene in ``disease.markers`` that has no effect in any comparison of
    ``de_genes`` receives a freshly drawn effect (uniform in [1.0, 2.5)),
    which is written into every comparison.  ``de_genes`` is therefore
    mutated in place.

    Returns
    -------
    List[str]
        The leading disease markers, truncated to a randomly drawn length
        between 3 and 6 (or to all markers when fewer exist).
    """
    candidates = list(disease.markers)

    # Genes that already carry an effect in at least one comparison.
    covered = set()
    for comparison_effects in de_genes.values():
        covered.update(comparison_effects)

    for marker in candidates:
        if marker in covered:
            continue
        effect = float(rng.uniform(1.0, 2.5))
        covered.add(marker)
        for comparison_effects in de_genes.values():
            comparison_effects[marker] = effect

    # Upper bound of ``integers`` is exclusive, so the limit is 3..6.
    limit = int(rng.integers(3, 7))
    return candidates[:min(len(candidates), limit)]
|
|
|
|
|
def _build_trajectory(
    rng: np.random.Generator,
    tissue: str,
    populations: List[CellPopulation],
) -> Optional[Dict[str, Any]]:
    """Build a trajectory description for the sampled populations.

    The first template in ``TRAJECTORY_TEMPLATES`` for ``tissue`` that has at
    least one branch made up entirely of sampled population names wins; only
    those fully covered branches are kept.  If no template qualifies and at
    least three populations exist, a generic trajectory is synthesised with
    the first population as root and two-node branches to a randomly sized
    prefix of the remaining populations.

    Returns
    -------
    Optional[Dict[str, Any]]
        Dict with keys ``root``, ``n_lineages``, ``branching`` and
        ``branches``, or ``None`` when no trajectory can be built.
    """
    present = {pop.name for pop in populations}

    def _describe(root_name: str, lineages: List[List[str]]) -> Dict[str, Any]:
        # Shared result layout for the template and fallback paths.
        return {
            "root": root_name,
            "n_lineages": len(lineages),
            "branching": len(lineages) > 1,
            "branches": lineages,
        }

    for template in TRAJECTORY_TEMPLATES:
        if template.tissue != tissue:
            continue
        usable = [
            path
            for path in template.branches
            if all(node in present for node in path)
        ]
        if usable:
            return _describe(template.root_population, usable)

    if len(populations) < 3:
        return None

    # Fallback: star-shaped trajectory rooted at the first population.
    root_name = populations[0].name
    candidates = [[root_name, pop.name] for pop in populations[1:]]
    keep = int(rng.integers(2, min(4, len(candidates)) + 1))
    return _describe(root_name, candidates[:keep])
|
|
|
|
|
# Regulatory program names (keys of ``REGULATORY_TEMPLATES``) used per tissue.
_TISSUE_REGULATORY_PROGRAMS: Dict[str, List[str]] = {
    "bone_marrow": ["erythroid", "myeloid", "stem_cell"],
    "thymus": ["lymphoid"],
    "blood": ["lymphoid", "myeloid"],
    "spleen": ["lymphoid"],
    "brain": ["neuronal", "inflammatory"],
    "heart": ["fibrotic", "inflammatory"],
    "lung": ["fibrotic", "inflammatory"],
    "liver": ["fibrotic", "inflammatory"],
    "kidney": ["fibrotic", "inflammatory"],
    "colon": ["inflammatory", "stem_cell"],
    "pancreas": ["inflammatory"],
    "skin": ["inflammatory"],
    "breast": ["inflammatory"],
    "synovium": ["inflammatory", "lymphoid"],
    "aorta": ["inflammatory"],
}


def _build_regulatory_network(
    rng: np.random.Generator,
    tissue: str,
    populations: List[CellPopulation],
) -> Dict[str, List[str]]:
    """Assemble the true TF -> target-gene network for a tissue.

    The tissue selects one or more regulatory programs (defaulting to
    ``["inflammatory"]`` for tissues not in the table); every TF/target entry
    of those programs in ``REGULATORY_TEMPLATES`` is copied into the result.
    When the same TF occurs in several programs, the later program's targets
    replace the earlier ones.

    If that yields nothing, a minimal network is derived from the first two
    populations: each one with at least two marker genes contributes its
    first marker as the TF and the remaining markers as targets.

    ``rng`` is accepted for signature parity with the other builders but is
    not used; this function consumes no random draws.
    """
    network: Dict[str, List[str]] = {}

    programs = _TISSUE_REGULATORY_PROGRAMS.get(tissue, ["inflammatory"])
    for prog_name in programs:
        prog = REGULATORY_TEMPLATES.get(prog_name, {})
        for tf, targets in prog.items():
            network[tf] = list(targets)

    if not network:
        for p in populations[:2]:
            if len(p.marker_genes) >= 2:
                tf = p.marker_genes[0]
                network[tf] = p.marker_genes[1:]

    return network
|
|
|
|
|
def _build_perturbation(
    rng: np.random.Generator,
    disease: DiseaseProfile,
) -> Dict[str, Dict[str, float]]:
    """Pick a perturbation and sample its per-gene effects.

    Perturbation templates whose ``target_pathway`` is one of the disease's
    pathways are preferred; one of them is chosen at random.  If none match,
    any template from ``PERTURBATION_TEMPLATES`` is chosen at random.  Each
    gene effect of the chosen template is multiplied by a factor drawn
    uniformly from [0.7, 1.3) and rounded to three decimals.

    Returns
    -------
    Dict[str, Dict[str, float]]
        A single-entry mapping ``{perturbation_name: {gene: effect}}``.
    """
    disease_pathways = set(disease.pathways.keys())

    matching = [
        (name, tmpl)
        for name, tmpl in PERTURBATION_TEMPLATES.items()
        if tmpl.target_pathway in disease_pathways
    ]

    if matching:
        name, tmpl = matching[int(rng.integers(0, len(matching)))]
    else:
        # ``Generator.choice`` returns a NumPy string scalar; coerce it so
        # the returned key is a plain ``str`` on both paths.
        name = str(rng.choice(list(PERTURBATION_TEMPLATES.keys())))
        tmpl = PERTURBATION_TEMPLATES[name]

    scaled: Dict[str, float] = {}
    for gene, effect in tmpl.gene_effects.items():
        scale = float(rng.uniform(0.7, 1.3))
        scaled[gene] = round(effect * scale, 3)

    return {name: scaled}
|
|
|
|
|
def _build_technical(
    rng: np.random.Generator,
    params: dict,
) -> TechnicalState:
    """Sample the technical-noise state for a scenario.

    Draws a batch count from the inclusive ``params["n_batches"]`` range and
    one batch-effect strength per batch (keyed ``batch_0``, ``batch_1``, ...),
    then dropout rate, doublet rate, ambient RNA fraction and sample quality
    from their respective ``params`` ranges.  Every value is rounded to three
    decimals.
    """
    lo, hi = params["n_batches"]
    # ``Generator.integers`` excludes the upper bound by default, which made
    # the configured maximum unreachable (e.g. ``(1, 2)`` always produced a
    # single batch).  Treat the range as inclusive, as is done for ``n_pops``.
    n_batches = int(rng.integers(lo, hi, endpoint=True))

    batch_effects: Dict[str, float] = {}
    # Always emit at least one batch, even for a misconfigured range.
    for i in range(max(1, n_batches)):
        strength = float(rng.uniform(*params["noise_batch_strength"]))
        batch_effects[f"batch_{i}"] = round(strength, 3)

    return TechnicalState(
        batch_effects=batch_effects,
        dropout_rate=round(float(rng.uniform(*params["noise_dropout"])), 3),
        doublet_rate=round(float(rng.uniform(*params["noise_doublet"])), 3),
        ambient_rna_fraction=round(
            float(rng.uniform(*params["noise_ambient"])), 3
        ),
        sample_quality=round(
            float(rng.uniform(*params["sample_quality"])), 3
        ),
    )
|
|
|
|
|
def _build_task(
    rng: np.random.Generator,
    disease: DiseaseProfile,
    tissue: str,
    scenario_type: str,
    params: dict,
    perturbation_effects: Dict[str, Dict[str, float]],
) -> TaskSpec:
    """Compose the task specification shown for a scenario.

    Budget and time limit are drawn (in that order) as integers from
    ``params["budget_range"]`` and ``params["time_range"]`` and stored as
    floats.  The problem statement and success criteria depend on
    ``scenario_type``; any value other than ``"de"``, ``"trajectory"`` or
    ``"perturbation"`` is handled as a biomarker-validation task.

    Conditions default to healthy versus the disease condition.  For a
    perturbation scenario with a non-empty ``perturbation_effects`` they
    become untreated disease versus treatment with the first listed
    perturbation.
    """
    budget_limit = float(rng.integers(*params["budget_range"]))
    time_limit = float(rng.integers(*params["time_range"]))

    study_conditions = ["healthy", disease.condition_name]

    if scenario_type == "de":
        statement = (
            f"Identify differentially expressed genes between "
            f"{disease.display_name} and healthy {tissue} tissue "
            f"using single-cell RNA sequencing."
        )
        success = [
            f"Identify DE genes between {disease.condition_name} and healthy",
            "Validate at least one candidate marker",
        ]
    elif scenario_type == "trajectory":
        statement = (
            f"Infer the developmental trajectory of cell populations "
            f"in {tissue} tissue in the context of {disease.display_name}."
        )
        success = [
            "Reconstruct branching lineage structure",
            "Identify key transcription factors driving fate decisions",
        ]
    elif scenario_type == "perturbation":
        # First perturbation name, or a generic label when none was built.
        pert_name = next(iter(perturbation_effects), "treatment")
        template = PERTURBATION_TEMPLATES.get(pert_name)
        if template:
            pert_desc = template.description
        else:
            pert_desc = pert_name
        statement = (
            f"Determine the effect of {pert_desc} on cell states "
            f"in {tissue} tissue affected by {disease.display_name}."
        )
        success = [
            "Quantify shift in cell activation states",
            f"Identify pathways modulated by {pert_name}",
            "Propose validation strategy",
        ]
        if perturbation_effects:
            study_conditions = [
                f"untreated_{disease.condition_name}",
                f"{pert_name}_treated",
            ]
    else:
        lead_marker = disease.markers[0] if disease.markers else "candidate"
        statement = (
            f"Validate candidate biomarker {lead_marker} for "
            f"{disease.display_name} in {tissue} tissue using "
            f"single-cell RNA sequencing."
        )
        success = [
            f"Validate {lead_marker} as a disease marker",
            "Confirm expression specificity across cell types",
        ]

    return TaskSpec(
        problem_statement=statement,
        modality="scRNA-seq",
        organism="human",
        tissue=tissue,
        conditions=study_conditions,
        budget_limit=budget_limit,
        time_limit_days=time_limit,
        success_criteria=success,
    )
|
|
|