Spaces:
Running
Running
| """Prompt template system using YAML definitions and Jinja2 rendering. | |
| Templates define structured prompts with variable slots for character traits, | |
| poses, outfits, emotions, camera angles, lighting, and scenes. The engine | |
| renders these templates with provided variables to produce final prompts | |
| for ComfyUI workflows. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| import yaml | |
| from jinja2 import Environment, BaseLoader | |
logger = logging.getLogger(__name__)

# Running inside a Hugging Face Space? True when SPACE_ID is present in the
# environment (presumably injected by the Spaces runtime — confirm) or when
# HF_SPACES=1 is set explicitly.
IS_HF_SPACES = (
    "SPACE_ID" in os.environ
    or os.environ.get("HF_SPACES") == "1"
)

# Directory holding the YAML prompt templates: the container path on Spaces,
# otherwise a local development path.
# NOTE(review): the local path is a machine-specific Windows drive path;
# consider making it configurable.
if IS_HF_SPACES:
    PROMPTS_DIR = Path("/app/config/templates/prompts")
else:
    PROMPTS_DIR = Path("D:/AI automation/content_engine/config/templates/prompts")
@dataclass
class VariableDefinition:
    """Definition of a template variable and its allowed values.

    Instances are built from the ``variables:`` section of a template YAML
    file and consulted by ``TemplateEngine.render`` when resolving values.
    """

    name: str
    # choice | string | number
    type: str = "choice"
    # Allowed values; TemplateEngine.render does not enforce them.
    options: list[str] = field(default_factory=list)
    # Used when the caller supplies no value (only if truthy).
    default: str = ""
    # If True, TemplateEngine.render raises ValueError when neither a value nor
    # a default is available (character_trigger/character_lora excepted).
    required: bool = False
    description: str = ""
@dataclass
class PromptTemplate:
    """A prompt template parsed from one YAML file.

    ``positive_prompt``, ``negative_prompt`` and each LoRA ``name`` hold
    Jinja2 source that ``TemplateEngine.render`` fills with variables.
    """

    id: str
    name: str
    category: str = ""
    rating: str = "sfw"  # sfw | nsfw
    base_model: str = "realistic_vision"

    # LoRA specs: dicts with a "name" (may contain Jinja2 variable references
    # such as {{character_lora}}) and optional "strength_model" /
    # "strength_clip" values.
    loras: list[dict[str, Any]] = field(default_factory=list)

    # Prompt text (Jinja2 template source).
    positive_prompt: str = ""
    negative_prompt: str = ""

    # Sampler settings from the YAML "sampler" section; None when the YAML
    # does not set them (presumably the workflow's own default applies —
    # confirm with the consumer).
    steps: int | None = None
    cfg: float | None = None
    sampler_name: str | None = None
    scheduler: str | None = None
    width: int | None = None
    height: int | None = None

    # Declared variables, keyed by variable name.
    variables: dict[str, VariableDefinition] = field(default_factory=dict)

    # Motion settings (for future video support); not used when rendering.
    motion: dict[str, Any] = field(default_factory=dict)
class TemplateEngine:
    """Load, manage, and render prompt templates.

    ``load_all()`` parses each ``*.yaml`` file in ``templates_dir`` into a
    ``PromptTemplate`` keyed by its ``id``. ``render()`` fills a template's
    Jinja2 prompt strings and LoRA names with caller-supplied variables.
    """

    def __init__(self, templates_dir: Path | str | None = None):
        """Create an engine. No templates are loaded until ``load_all()``.

        Args:
            templates_dir: Directory containing the ``*.yaml`` templates.
                Falls back to the module-level ``PROMPTS_DIR`` when omitted
                (or empty). A ``str`` is accepted and converted to ``Path``.
        """
        self.templates_dir = Path(templates_dir) if templates_dir else PROMPTS_DIR
        self._templates: dict[str, PromptTemplate] = {}
        # Templates are compiled from in-memory strings via from_string(), so
        # a bare BaseLoader suffices. Autoescaping stays at Jinja2's default
        # (off) because the output is prompt text, not HTML.
        self._jinja_env = Environment(loader=BaseLoader())

    def load_all(self) -> None:
        """Load every ``*.yaml`` template in ``templates_dir``.

        Files are processed in sorted order, so when two files declare the
        same ``id`` the outcome is deterministic: the later file name wins.
        A file that fails to load is logged with its traceback and skipped.
        If the directory does not exist, a warning is logged and nothing is
        loaded.
        """
        if not self.templates_dir.exists():
            logger.warning("Templates directory does not exist: %s", self.templates_dir)
            return
        for path in sorted(self.templates_dir.glob("*.yaml")):
            try:
                template = self._parse_template(path)
                self._templates[template.id] = template
            except Exception:
                # Best-effort: one broken file must not block the others.
                logger.exception("Failed to load template %s", path)
                continue
            logger.info("Loaded template: %s", template.id)

    def _parse_template(self, path: Path) -> PromptTemplate:
        """Parse one YAML file into a ``PromptTemplate``.

        Absent keys fall back to defaults; ``id`` and ``name`` default to the
        file stem. Sections written as a bare key (e.g. ``loras:``), which
        YAML parses as ``None``, are treated as empty.

        Raises:
            ValueError: if the document's top level is not a mapping
                (for example, an empty file).
            OSError / yaml.YAMLError: if the file cannot be read or parsed.
        """
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can garble or reject non-ASCII prompt text.
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if not isinstance(data, dict):
            raise ValueError(
                f"Template {path} must contain a YAML mapping, got {type(data).__name__}"
            )

        variables: dict[str, VariableDefinition] = {}
        for var_name, var_def in (data.get("variables") or {}).items():
            var_def = var_def or {}  # a bare "name:" entry parses as None
            variables[var_name] = VariableDefinition(
                name=var_name,
                # NOTE(review): this fallback is "string" while the
                # VariableDefinition field default is "choice"; confirm which
                # one is intended.
                type=var_def.get("type", "string"),
                options=var_def.get("options") or [],
                default=var_def.get("default", ""),
                required=var_def.get("required", False),
                description=var_def.get("description", ""),
            )

        sampler = data.get("sampler") or {}
        return PromptTemplate(
            id=data.get("id", path.stem),
            name=data.get("name", path.stem),
            category=data.get("category", ""),
            rating=data.get("rating", "sfw"),
            base_model=data.get("base_model", "realistic_vision"),
            # `or []`: a bare "loras:" key would otherwise make render() fail
            # when it iterates the list.
            loras=data.get("loras") or [],
            positive_prompt=data.get("positive_prompt", ""),
            negative_prompt=data.get("negative_prompt", ""),
            steps=sampler.get("steps"),
            cfg=sampler.get("cfg"),
            sampler_name=sampler.get("sampler_name"),
            scheduler=sampler.get("scheduler"),
            width=sampler.get("width"),
            height=sampler.get("height"),
            variables=variables,
            motion=data.get("motion") or {},
        )

    def get(self, template_id: str) -> PromptTemplate:
        """Return the loaded template with ``template_id``.

        Raises:
            KeyError: if no template with that ID has been loaded.
        """
        try:
            return self._templates[template_id]
        except KeyError:
            raise KeyError(f"Template not found: {template_id}") from None

    def list_templates(self) -> list[PromptTemplate]:
        """Return all loaded templates, in load order."""
        return list(self._templates.values())

    def render(
        self,
        template_id: str,
        variables: dict[str, str],
    ) -> RenderedPrompt:
        """Render a template with the given variables.

        Each variable the template declares resolves to, in order of
        preference: the caller's value, then the declared default (if truthy).
        If neither exists and the variable is required, a ValueError is raised.
        The exceptions are ``character_trigger`` and ``character_lora``, which
        resolve to "" so the template renders without a character. Caller
        variables the template does not declare are passed through unchanged.

        Returns:
            A RenderedPrompt holding the rendered positive/negative prompts,
            the LoRA specs with rendered names (strengths default to 0.85),
            the resolved variables, and the source template.

        Raises:
            KeyError: if ``template_id`` is not loaded.
            ValueError: if a required variable is missing and has no default.
        """
        template = self.get(template_id)

        resolved_vars: dict[str, str] = {}
        for var_name, var_def in template.variables.items():
            if var_name in variables:
                resolved_vars[var_name] = variables[var_name]
            elif var_def.default:
                resolved_vars[var_name] = var_def.default
            elif var_def.required:
                # Character-specific vars default to empty when no character
                # is selected.
                if var_name not in ("character_trigger", "character_lora"):
                    raise ValueError(f"Required variable '{var_name}' not provided")
                resolved_vars[var_name] = ""

        # Pass through any extra caller variables the template does not declare.
        for key, value in variables.items():
            resolved_vars.setdefault(key, value)

        positive = self._render_string(template.positive_prompt, resolved_vars)
        negative = self._render_string(template.negative_prompt, resolved_vars)

        # LoRA names may reference variables too (e.g. {{character_lora}}).
        rendered_loras = [
            {
                "name": self._render_string(spec.get("name", ""), resolved_vars),
                "strength_model": spec.get("strength_model", 0.85),
                "strength_clip": spec.get("strength_clip", 0.85),
            }
            for spec in template.loras
        ]

        return RenderedPrompt(
            positive_prompt=positive,
            negative_prompt=negative,
            loras=rendered_loras,
            variables=resolved_vars,
            template=template,
        )

    def _render_string(self, template_str: str, variables: dict[str, str]) -> str:
        """Render Jinja2 source ``template_str`` with ``variables``.

        Returns "" for empty or None input. Names that the template references
        but ``variables`` lacks render as empty strings (Jinja2's default
        ``Undefined``).
        """
        if not template_str:
            return ""
        tmpl = self._jinja_env.from_string(template_str)
        # Pass the mapping positionally rather than as **kwargs: a variable
        # named "self" would otherwise collide with Template.render's own
        # parameter and raise TypeError.
        return tmpl.render(variables)
@dataclass
class RenderedPrompt:
    """Result of rendering a template with variables (built by TemplateEngine.render)."""

    # Rendered prompt text.
    positive_prompt: str
    negative_prompt: str
    # LoRA specs with rendered "name" plus "strength_model"/"strength_clip".
    loras: list[dict[str, Any]]
    # Final variable values used for rendering (declared + passed-through extras).
    variables: dict[str, str]
    # The template that was rendered.
    template: PromptTemplate