| """Task generator — produces (TaskSpec, FullLatentState) pairs for episodes.
|
|
|
| Supports three modes:
|
| 1. Select from the pre-defined scenario library.
|
| 2. Randomly perturb a scenario for domain-randomisation.
|
| 3. Compose a fully procedural scenario (tissue × modality × difficulty).
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| from typing import List, Optional, Tuple
|
|
|
| import numpy as np
|
|
|
| from models import TaskSpec, tools_for_modality, assays_for_modality
|
|
|
| from server.simulator.latent_state import (
|
| CellPopulation,
|
| ExperimentProgress,
|
| FullLatentState,
|
| GeneProgram,
|
| LatentBiologicalState,
|
| ResourceState,
|
| TechnicalState,
|
| )
|
| from .scenarios import SCENARIO_LIBRARY, Scenario
|
| from .procedural_generator import generate_procedural_scenarios
|
|
|
|
|
class TaskGenerator:
    """Generates task + latent-state pairs for environment episodes.

    Each call to :meth:`generate` deep-copies one scenario from the pool,
    optionally perturbs the copy (domain randomisation), and wraps it in a
    fresh ``FullLatentState``. The scenarios held in the pool are never
    mutated by generation.
    """

    def __init__(
        self,
        scenarios: Optional[List[Scenario]] = None,
        domain_randomise: bool = True,
    ):
        """Store the scenario pool and the randomisation switch.

        Args:
            scenarios: Scenarios to draw from. The list is stored as-is (not
                copied). When ``None``, the pool is a copy of
                ``SCENARIO_LIBRARY`` followed by the result of
                ``generate_procedural_scenarios(n=20, seed=42)``.
            domain_randomise: When true, every generated episode is perturbed
                by :meth:`_randomise` before being returned.
        """
        if scenarios is not None:
            self.scenarios = scenarios
        else:
            self.scenarios = (
                list(SCENARIO_LIBRARY)
                + generate_procedural_scenarios(n=20, seed=42)
            )
        self.domain_randomise = domain_randomise

    def generate(
        self,
        *,
        seed: Optional[int] = None,
        scenario_name: Optional[str] = None,
    ) -> Tuple[TaskSpec, FullLatentState]:
        """Build one ``(task, latent_state)`` pair for an episode.

        Args:
            seed: Seed for the NumPy generator that drives scenario selection
                and randomisation. ``None`` gives a non-deterministic draw.
            scenario_name: Name of the scenario to use. When falsy (``None``
                or ``""``), a scenario is drawn uniformly from the pool.

        Returns:
            The (possibly randomised) task copy and the matching latent state.

        Raises:
            ValueError: If ``scenario_name`` matches no scenario, or if a
                random draw is requested while the pool is empty.
        """
        rng = np.random.default_rng(seed)

        if scenario_name:
            scenario = self._find_scenario(scenario_name)
        else:
            n_scenarios = len(self.scenarios)
            if n_scenarios == 0:
                # rng.integers(0, 0) would fail with an opaque
                # "low >= high" error; report the real cause instead.
                raise ValueError(
                    "Cannot pick a random scenario: the scenario list is empty."
                )
            idx = int(rng.integers(0, n_scenarios))
            scenario = self.scenarios[idx]

        # Work on deep copies so the pooled scenario stays pristine.
        task = scenario.task.model_copy(deep=True)
        biology = scenario.biology.model_copy(deep=True)
        technical = scenario.technical.model_copy(deep=True)

        if self.domain_randomise:
            self._randomise(rng, task, biology, technical)

        # Override the advertised tools/assays with the modality-compatible
        # ones, but only when the lookup returns something; otherwise the
        # scenario's own lists are kept.
        compatible_tools = [t.name for t in tools_for_modality(task.modality)]
        compatible_assays = [a.name for a in assays_for_modality(task.modality)]
        if compatible_tools:
            task.available_tools = compatible_tools
        if compatible_assays:
            task.available_assays = compatible_assays

        latent = FullLatentState(
            biology=biology,
            technical=technical,
            progress=ExperimentProgress(),
            resources=ResourceState(
                # Taken from the task copy, so randomised limits carry over.
                budget_total=task.budget_limit,
                time_limit_days=task.time_limit_days,
            ),
            hidden_failure_conditions=list(scenario.hidden_failure_conditions),
            task_modality=task.modality,
            rng_seed=seed or 0,  # a missing seed is recorded as 0
        )
        return task, latent

    def list_scenarios(self) -> List[str]:
        """Return the names of all scenarios in the pool, in pool order."""
        return [s.name for s in self.scenarios]

    def _find_scenario(self, name: str) -> Scenario:
        """Return the first scenario whose ``name`` equals ``name``.

        Raises:
            ValueError: If no scenario matches; the message lists the
                available names.
        """
        for s in self.scenarios:
            if s.name == name:
                return s
        available = ", ".join(self.list_scenarios())
        raise ValueError(f"Unknown scenario '{name}'. Available: {available}")

    @staticmethod
    def _jitter(
        rng: np.random.Generator,
        value: float,
        sigma: float,
        low: float,
        high: float,
    ) -> float:
        """Return ``value`` plus zero-mean Gaussian noise, clipped to [low, high]."""
        return float(np.clip(value + rng.normal(0, sigma), low, high))

    def _randomise(
        self,
        rng: np.random.Generator,
        task: TaskSpec,
        bio: LatentBiologicalState,
        tech: TechnicalState,
    ) -> None:
        """Perturb ``task``, ``bio`` and ``tech`` in place.

        The draws are made in a fixed order (budget, time limit, technical
        noise parameters, batch effects, population proportions, DE effect
        sizes, cell count), so a given ``rng`` state always yields the same
        perturbation. Do not reorder the statements below without accepting
        that seeded episodes will change.

        Args:
            rng: Source of randomness; advanced by this call.
            task: Task whose ``budget_limit`` and ``time_limit_days`` are scaled.
            bio: Biology whose population proportions, DE effect sizes and
                ``n_true_cells`` are perturbed.
            tech: Technical state whose noise parameters and batch effects
                are jittered.
        """
        # Resource limits: multiplicative scaling.
        budget_scale = float(rng.uniform(0.7, 1.3))
        task.budget_limit *= budget_scale
        task.time_limit_days *= float(rng.uniform(0.8, 1.2))

        # Technical noise: additive Gaussian jitter, clipped to a fixed range.
        tech.dropout_rate = self._jitter(rng, tech.dropout_rate, 0.02, 0.01, 0.3)
        tech.doublet_rate = self._jitter(rng, tech.doublet_rate, 0.01, 0.01, 0.15)
        tech.sample_quality = self._jitter(rng, tech.sample_quality, 0.05, 0.5, 1.0)
        tech.ambient_rna_fraction = self._jitter(
            rng, tech.ambient_rna_fraction, 0.01, 0.01, 0.15
        )
        # Snapshot the keys; only the values are reassigned.
        for batch_id in list(tech.batch_effects.keys()):
            tech.batch_effects[batch_id] = self._jitter(
                rng, tech.batch_effects[batch_id], 0.03, 0.0, 0.4
            )

        # Cell-type proportions: scale each, clip, then renormalise to sum to 1.
        for pop in bio.cell_populations:
            pop.proportion = float(np.clip(
                pop.proportion * rng.uniform(0.8, 1.2), 0.01, 0.8
            ))
        # "or 1.0" guards the division when there are no populations.
        total = sum(p.proportion for p in bio.cell_populations) or 1.0
        for pop in bio.cell_populations:
            pop.proportion /= total

        # Differential-expression effect sizes: independent scaling per gene.
        for effects in bio.true_de_genes.values():
            for gene in list(effects.keys()):
                effects[gene] *= float(rng.uniform(0.8, 1.2))

        # Cell count: scaled, truncated to int, floored at 1000.
        bio.n_true_cells = max(
            1000,
            int(bio.n_true_cells * rng.uniform(0.6, 1.4)),
        )
|
|
|