"""Prompt template system using YAML definitions and Jinja2 rendering. Templates define structured prompts with variable slots for character traits, poses, outfits, emotions, camera angles, lighting, and scenes. The engine renders these templates with provided variables to produce final prompts for ComfyUI workflows. """ from __future__ import annotations import logging import os from dataclasses import dataclass, field from pathlib import Path from typing import Any import yaml from jinja2 import Environment, BaseLoader logger = logging.getLogger(__name__) IS_HF_SPACES = os.environ.get("HF_SPACES") == "1" or os.environ.get("SPACE_ID") is not None PROMPTS_DIR = Path("/app/config/templates/prompts") if IS_HF_SPACES else Path("D:/AI automation/content_engine/config/templates/prompts") @dataclass class VariableDefinition: """Definition of a template variable with its allowed values.""" name: str type: str = "choice" # choice | string | number options: list[str] = field(default_factory=list) default: str = "" required: bool = False description: str = "" @dataclass class PromptTemplate: """A parsed prompt template.""" id: str name: str category: str = "" rating: str = "sfw" # sfw | nsfw base_model: str = "realistic_vision" # LoRA specs (with Jinja2 variable references) loras: list[dict[str, Any]] = field(default_factory=list) # Prompt text (Jinja2 templates) positive_prompt: str = "" negative_prompt: str = "" # Sampler defaults steps: int | None = None cfg: float | None = None sampler_name: str | None = None scheduler: str | None = None width: int | None = None height: int | None = None # Variable definitions variables: dict[str, VariableDefinition] = field(default_factory=dict) # Motion (for future video support) motion: dict[str, Any] = field(default_factory=dict) class TemplateEngine: """Loads, manages, and renders prompt templates.""" def __init__(self, templates_dir: Path | None = None): self.templates_dir = templates_dir or PROMPTS_DIR self._templates: dict[str, PromptTemplate] = {} self._jinja_env = Environment(loader=BaseLoader()) def load_all(self) -> None: """Load all YAML templates from the templates directory.""" if not self.templates_dir.exists(): logger.warning("Templates directory does not exist: %s", self.templates_dir) return for path in self.templates_dir.glob("*.yaml"): try: template = self._parse_template(path) self._templates[template.id] = template logger.info("Loaded template: %s", template.id) except Exception: logger.error("Failed to load template %s", path, exc_info=True) def _parse_template(self, path: Path) -> PromptTemplate: """Parse a YAML file into a PromptTemplate.""" with open(path) as f: data = yaml.safe_load(f) variables = {} for var_name, var_def in data.get("variables", {}).items(): variables[var_name] = VariableDefinition( name=var_name, type=var_def.get("type", "string"), options=var_def.get("options", []), default=var_def.get("default", ""), required=var_def.get("required", False), description=var_def.get("description", ""), ) sampler = data.get("sampler", {}) return PromptTemplate( id=data.get("id", path.stem), name=data.get("name", path.stem), category=data.get("category", ""), rating=data.get("rating", "sfw"), base_model=data.get("base_model", "realistic_vision"), loras=data.get("loras", []), positive_prompt=data.get("positive_prompt", ""), negative_prompt=data.get("negative_prompt", ""), steps=sampler.get("steps"), cfg=sampler.get("cfg"), sampler_name=sampler.get("sampler_name"), scheduler=sampler.get("scheduler"), width=sampler.get("width"), height=sampler.get("height"), variables=variables, motion=data.get("motion", {}), ) def get(self, template_id: str) -> PromptTemplate: """Get a loaded template by ID.""" if template_id not in self._templates: raise KeyError(f"Template not found: {template_id}") return self._templates[template_id] def list_templates(self) -> list[PromptTemplate]: """List all loaded templates.""" return list(self._templates.values()) def render( self, template_id: str, variables: dict[str, str], ) -> RenderedPrompt: """Render a template with the given variables. Returns the rendered positive/negative prompts and resolved LoRA specs. """ template = self.get(template_id) # Fill in defaults for missing variables resolved_vars = {} for var_name, var_def in template.variables.items(): if var_name in variables: resolved_vars[var_name] = variables[var_name] elif var_def.default: resolved_vars[var_name] = var_def.default elif var_def.required: # Character-specific vars default to empty when no character selected if var_name in ("character_trigger", "character_lora"): resolved_vars[var_name] = "" else: raise ValueError(f"Required variable '{var_name}' not provided") # Also pass through any extra variables not in the definition for k, v in variables.items(): if k not in resolved_vars: resolved_vars[k] = v # Render prompts positive = self._render_string(template.positive_prompt, resolved_vars) negative = self._render_string(template.negative_prompt, resolved_vars) # Render LoRA names (they may contain {{character_lora}} etc.) rendered_loras = [] for lora_spec in template.loras: rendered_loras.append({ "name": self._render_string(lora_spec.get("name", ""), resolved_vars), "strength_model": lora_spec.get("strength_model", 0.85), "strength_clip": lora_spec.get("strength_clip", 0.85), }) return RenderedPrompt( positive_prompt=positive, negative_prompt=negative, loras=rendered_loras, variables=resolved_vars, template=template, ) def _render_string(self, template_str: str, variables: dict[str, str]) -> str: """Render a Jinja2 template string with variables.""" if not template_str: return "" tmpl = self._jinja_env.from_string(template_str) return tmpl.render(**variables) @dataclass class RenderedPrompt: """Result of rendering a template with variables.""" positive_prompt: str negative_prompt: str loras: list[dict[str, Any]] variables: dict[str, str] template: PromptTemplate