"""Content variation engine for generating combinatorial batches.

Given a prompt template and a character, the variation engine produces
multiple generation jobs with different combinations of poses, outfits,
emotions, camera angles, and other variable attributes.
"""

from __future__ import annotations

import itertools
import math
import random
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any

from content_engine.services.template_engine import TemplateEngine

# Inclusive upper bound for randomly drawn seeds (unsigned 32-bit range).
_MAX_SEED = 2**32 - 1


@dataclass
class CharacterProfile:
    """Character configuration loaded from YAML."""

    id: str
    name: str
    trigger_word: str
    lora_filename: str
    lora_strength: float = 0.85
    default_checkpoint: str | None = None
    # Extra LoRA entries appended after the character LoRA on every job.
    style_loras: list[dict[str, Any]] = field(default_factory=list)
    description: str = ""
    physical_traits: dict[str, str] = field(default_factory=dict)


@dataclass
class VariationJob:
    """A single generation job produced by the variation engine."""

    job_id: str
    # Shared by every job created in the same generate_batch() call.
    batch_id: str
    character: CharacterProfile
    template_id: str
    content_rating: str
    # Template variable values plus "character_trigger" / "character_lora".
    variables: dict[str, str]
    seed: int
    loras: list[dict[str, Any]]


class VariationEngine:
    """Generates batches of variation jobs from templates."""

    def __init__(self, template_engine: TemplateEngine):
        self.template_engine = template_engine

    def generate_batch(
        self,
        template_id: str,
        character: CharacterProfile,
        *,
        content_rating: str = "sfw",
        count: int = 10,
        variation_mode: str = "random",  # curated | random | exhaustive
        pin: dict[str, str] | None = None,
        seed_strategy: str = "random",  # random | sequential | fixed
        base_seed: int | None = None,
    ) -> list[VariationJob]:
        """Generate a batch of variation jobs.

        Args:
            template_id: Which prompt template to use.
            character: Character profile for LoRA and trigger word.
            content_rating: "sfw" or "nsfw".
            count: Maximum number of variations to generate. Exhaustive mode
                returns fewer when the template has fewer combinations.
            variation_mode: How to select variable combinations. "random" and
                "exhaustive" are implemented; "curated" (and any unrecognised
                value) currently behaves like "random".
            pin: Variables to keep fixed across all variations.
            seed_strategy: How to assign seeds. "random" gives every job an
                independent seed, "fixed" uses the same seed for every job,
                and "sequential" uses base_seed, base_seed + 1, ...
            base_seed: Seed used by the "fixed" and "sequential" strategies.
                If omitted, one random base seed is drawn for the whole batch.

        Returns:
            One VariationJob per selected combination, all sharing one
            batch_id. Each job owns its own ``variables`` and ``loras`` dicts.
        """
        pin = pin or {}
        batch_id = str(uuid.uuid4())

        # Build variable combinations (each combo is a distinct dict).
        combos = self._select_combinations(template_id, count, variation_mode, pin)

        # Inject character-specific variables
        for combo in combos:
            combo["character_trigger"] = character.trigger_word
            combo["character_lora"] = character.lora_filename

        # Character LoRA first, then any style LoRAs from the profile.
        base_loras: list[dict[str, Any]] = [
            {
                "name": character.lora_filename,
                "strength_model": character.lora_strength,
                "strength_clip": character.lora_strength,
            },
            *character.style_loras,
        ]

        if base_seed is None and seed_strategy in ("fixed", "sequential"):
            # Without an explicit base, draw one per batch so "fixed" and
            # "sequential" keep their meaning; leave headroom so sequential
            # seeds stay within [0, _MAX_SEED].
            headroom = max(len(combos) - 1, 0)
            base_seed = random.randint(0, _MAX_SEED - headroom)

        return [
            VariationJob(
                job_id=str(uuid.uuid4()),
                batch_id=batch_id,
                character=character,
                template_id=template_id,
                content_rating=content_rating,
                variables=combo,
                seed=self._get_seed(seed_strategy, base_seed, i),
                # Copy each LoRA dict so jobs never share (or mutate) each
                # other's entries or the profile's style_loras.
                loras=[dict(lora) for lora in base_loras],
            )
            for i, combo in enumerate(combos)
        ]

    def _select_combinations(
        self,
        template_id: str,
        count: int,
        mode: str,
        pin: dict[str, str],
    ) -> list[dict[str, str]]:
        """Select variable combinations for *template_id* according to *mode*.

        "exhaustive" enumerates the cartesian product of choice options;
        anything else ("random", "curated", unknown values) draws each
        combination independently at random.
        """
        template = self.template_engine.get(template_id)
        if mode == "exhaustive":
            return self._exhaustive_combos(template.variables, count, pin)
        # "random"; "curated" falls back to random for now, as do unknown modes.
        return self._random_combos(template.variables, count, pin)

    def _random_combos(
        self,
        variables: dict,
        count: int,
        pin: dict[str, str],
    ) -> list[dict[str, str]]:
        """Draw *count* independent random combinations.

        For each variable, a pinned value wins. Otherwise a "choice" variable
        with options gets a uniformly random option, and otherwise a truthy
        default is used. Variables matching none of these are omitted.

        NOTE(review): assumes each value in *variables* exposes ``type``,
        ``options`` and ``default`` attributes (template variable definitions).
        """
        combos: list[dict[str, str]] = []
        for _ in range(count):
            combo: dict[str, str] = {}
            for var_name, var_def in variables.items():
                if var_name in pin:
                    combo[var_name] = pin[var_name]
                elif var_def.type == "choice" and var_def.options:
                    combo[var_name] = random.choice(var_def.options)
                elif var_def.default:
                    combo[var_name] = var_def.default
            combos.append(combo)
        return combos

    def _exhaustive_combos(
        self,
        variables: dict,
        count: int,
        pin: dict[str, str],
    ) -> list[dict[str, str]]:
        """Generate cartesian-product combinations, capped at *count*.

        Pinned variables and truthy defaults of non-choice variables each
        contribute a single value, matching _random_combos. "choice" variables
        contribute every option. If the product has at most *count* entries,
        all of them are returned in product order. Otherwise *count* distinct
        combinations are sampled uniformly without materializing the product.
        """
        if count <= 0:
            return []

        axes: list[list[tuple[str, str]]] = []
        for var_name, var_def in variables.items():
            if var_name in pin:
                axes.append([(var_name, pin[var_name])])
            elif var_def.type == "choice" and var_def.options:
                axes.append([(var_name, opt) for opt in var_def.options])
            elif var_def.default:
                # Single-value axis: keeps the default without growing the product.
                axes.append([(var_name, var_def.default)])

        if not axes:
            # Distinct dicts (not ``[{}] * count``) so each job can be mutated alone.
            return [{} for _ in range(count)]

        total = math.prod(len(axis) for axis in axes)
        if total <= count:
            return [dict(combo) for combo in itertools.product(*axes)]

        if total <= sys.maxsize:
            # random.sample over a range does not materialize the population.
            indices = random.sample(range(total), count)
        else:
            # len(range(total)) would overflow; with total this large relative
            # to count, rejection sampling almost never repeats a draw.
            seen: set[int] = set()
            indices = []
            while len(indices) < count:
                candidate = random.randrange(total)
                if candidate not in seen:
                    seen.add(candidate)
                    indices.append(candidate)
        return [self._decode_combo(index, axes) for index in indices]

    @staticmethod
    def _decode_combo(
        index: int, axes: list[list[tuple[str, str]]]
    ) -> dict[str, str]:
        """Return the *index*-th element of ``itertools.product(*axes)`` as a dict."""
        picks: list[tuple[str, str]] = []
        # The last axis varies fastest in itertools.product, so decode it first.
        for axis in reversed(axes):
            index, offset = divmod(index, len(axis))
            picks.append(axis[offset])
        picks.reverse()
        return dict(picks)

    def _get_seed(
        self, strategy: str, base_seed: int | None, index: int
    ) -> int:
        """Return the seed for the *index*-th job of a batch.

        "fixed" returns base_seed and "sequential" returns base_seed + index.
        Any other strategy, or a missing base_seed, yields a fresh random seed
        in [0, 2**32 - 1].
        """
        if base_seed is not None:
            if strategy == "fixed":
                return base_seed
            if strategy == "sequential":
                return base_seed + index
        return random.randint(0, _MAX_SEED)