dippoo's picture
Initial deployment - Content Engine
ed37502
raw
history blame
7.26 kB
"""Prompt template system using YAML definitions and Jinja2 rendering.
Templates define structured prompts with variable slots for character traits,
poses, outfits, emotions, camera angles, lighting, and scenes. The engine
renders these templates with provided variables to produce final prompts
for ComfyUI workflows.
"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import yaml
from jinja2 import Environment, BaseLoader
logger = logging.getLogger(__name__)

# Running on Hugging Face Spaces when HF_SPACES is exactly "1" or when
# SPACE_ID is present in the environment (with any value).
IS_HF_SPACES = "SPACE_ID" in os.environ or os.environ.get("HF_SPACES") == "1"

# Default directory that holds the YAML prompt templates: the container path
# on Spaces, otherwise the hard-coded Windows path below.
if IS_HF_SPACES:
    PROMPTS_DIR = Path("/app/config/templates/prompts")
else:
    PROMPTS_DIR = Path("D:/AI automation/content_engine/config/templates/prompts")
@dataclass
class VariableDefinition:
    """Describe one variable slot of a prompt template.

    Attributes:
        name: Variable name as used inside the template text.
        type: Kind of value; one of ``choice``, ``string`` or ``number``.
        options: Allowed values for the variable.
        default: Value to use when the caller supplies none; an empty
            string means "no default".
        required: Whether the variable has to be provided.
        description: Free-form explanation of the variable.
    """

    name: str
    type: str = "choice"
    options: list[str] = field(default_factory=list)
    default: str = ""
    required: bool = False
    description: str = ""
@dataclass
class PromptTemplate:
    """In-memory form of one prompt template.

    Attributes:
        id: Identifier the template is stored and looked up under.
        name: Display name of the template.
        category: Optional category label.
        rating: Content rating, ``sfw`` or ``nsfw``.
        base_model: Name of the base model the template targets.
        loras: LoRA specs; their values may contain Jinja2 variable
            references.
        positive_prompt: Jinja2 template text for the positive prompt.
        negative_prompt: Jinja2 template text for the negative prompt.
        steps, cfg, sampler_name, scheduler, width, height: Sampler
            defaults; ``None`` when the template does not set them.
        variables: Variable definitions keyed by variable name.
        motion: Motion settings, reserved for future video support.
    """

    # --- identity / classification ---
    id: str
    name: str
    category: str = ""
    rating: str = "sfw"
    base_model: str = "realistic_vision"

    # --- LoRA specs and prompt text (Jinja2) ---
    loras: list[dict[str, Any]] = field(default_factory=list)
    positive_prompt: str = ""
    negative_prompt: str = ""

    # --- sampler defaults ---
    steps: int | None = None
    cfg: float | None = None
    sampler_name: str | None = None
    scheduler: str | None = None
    width: int | None = None
    height: int | None = None

    # --- variable definitions ---
    variables: dict[str, VariableDefinition] = field(default_factory=dict)

    # --- motion (future video support) ---
    motion: dict[str, Any] = field(default_factory=dict)
class TemplateEngine:
    """Loads, manages, and renders prompt templates.

    Templates are read from ``templates_dir`` by :meth:`load_all`, kept in
    memory keyed by template ID, and rendered on demand by :meth:`render`.
    """

    def __init__(self, templates_dir: Path | None = None):
        """Create an engine with no templates loaded yet.

        Args:
            templates_dir: Directory containing ``*.yaml`` template files.
                Falls back to the module-level ``PROMPTS_DIR`` when omitted.
        """
        self.templates_dir = templates_dir or PROMPTS_DIR
        self._templates: dict[str, PromptTemplate] = {}
        # Prompts are rendered from in-memory strings, so no file loader
        # is configured.
        self._jinja_env = Environment(loader=BaseLoader())

    def load_all(self) -> None:
        """Load all ``*.yaml`` templates from the templates directory.

        A missing directory only logs a warning. A file that fails to load
        is logged with its traceback and skipped, so one bad template does
        not prevent the remaining ones from loading.
        """
        if not self.templates_dir.exists():
            logger.warning("Templates directory does not exist: %s", self.templates_dir)
            return

        # Sorted so that load order -- and therefore which file wins when two
        # files declare the same ID -- does not depend on filesystem order.
        for path in sorted(self.templates_dir.glob("*.yaml")):
            try:
                template = self._parse_template(path)
                self._templates[template.id] = template
                logger.info("Loaded template: %s", template.id)
            except Exception:
                # Deliberately broad: loading is best-effort per file.
                logger.exception("Failed to load template %s", path)

    def _parse_template(self, path: Path) -> PromptTemplate:
        """Parse a YAML file into a PromptTemplate.

        Sections that are absent or explicitly null in the YAML are treated
        as empty.

        Args:
            path: Path of the YAML template file.

        Returns:
            The parsed template. ``id`` and ``name`` fall back to the file
            stem when the file does not define them.

        Raises:
            ValueError: If the top level of the file is not a YAML mapping
                (for example, an empty file).
        """
        # Decode as UTF-8 explicitly; relying on the platform default
        # encoding misreads non-ASCII prompt text on some systems.
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        if not isinstance(data, dict):
            raise ValueError(f"Template file {path} must contain a YAML mapping")

        variables: dict[str, VariableDefinition] = {}
        for var_name, var_def in (data.get("variables") or {}).items():
            # A variable declared with no body (``pose:``) parses as None.
            var_def = var_def or {}
            variables[var_name] = VariableDefinition(
                name=var_name,
                type=var_def.get("type", "string"),
                options=var_def.get("options") or [],
                default=var_def.get("default", ""),
                required=var_def.get("required", False),
                description=var_def.get("description", ""),
            )

        sampler = data.get("sampler") or {}

        return PromptTemplate(
            id=data.get("id", path.stem),
            name=data.get("name", path.stem),
            category=data.get("category", ""),
            rating=data.get("rating", "sfw"),
            base_model=data.get("base_model", "realistic_vision"),
            loras=data.get("loras") or [],
            positive_prompt=data.get("positive_prompt", ""),
            negative_prompt=data.get("negative_prompt", ""),
            steps=sampler.get("steps"),
            cfg=sampler.get("cfg"),
            sampler_name=sampler.get("sampler_name"),
            scheduler=sampler.get("scheduler"),
            width=sampler.get("width"),
            height=sampler.get("height"),
            variables=variables,
            motion=data.get("motion") or {},
        )

    def get(self, template_id: str) -> PromptTemplate:
        """Get a loaded template by ID.

        Raises:
            KeyError: If no template with ``template_id`` has been loaded.
        """
        try:
            return self._templates[template_id]
        except KeyError:
            raise KeyError(f"Template not found: {template_id}") from None

    def list_templates(self) -> list[PromptTemplate]:
        """List all loaded templates."""
        return list(self._templates.values())

    def render(
        self,
        template_id: str,
        variables: dict[str, str],
    ) -> RenderedPrompt:
        """Render a template with the given variables.

        For each variable the template defines, the caller's value is used
        if given, otherwise the definition's default. Caller-supplied
        variables that the template does not define are passed through to
        the rendering context unchanged.

        Args:
            template_id: ID of a previously loaded template.
            variables: Values for the template's variables.

        Returns:
            The rendered positive/negative prompts, the resolved LoRA specs,
            the final variable mapping, and the source template.

        Raises:
            KeyError: If ``template_id`` is not loaded.
            ValueError: If a required variable has neither a supplied value
                nor a default.
        """
        template = self.get(template_id)

        # Fill in defaults for missing variables
        resolved_vars: dict[str, Any] = {}
        for var_name, var_def in template.variables.items():
            if var_name in variables:
                resolved_vars[var_name] = variables[var_name]
            elif var_def.default is not None and var_def.default != "":
                # Only None / "" mean "no default"; a plain truthiness test
                # would also discard meaningful defaults such as 0.
                resolved_vars[var_name] = var_def.default
            elif var_def.required:
                # Character-specific vars default to empty when no character selected
                if var_name in ("character_trigger", "character_lora"):
                    resolved_vars[var_name] = ""
                else:
                    raise ValueError(f"Required variable '{var_name}' not provided")

        # Also pass through any extra variables not in the definition
        for key, value in variables.items():
            resolved_vars.setdefault(key, value)

        positive = self._render_string(template.positive_prompt, resolved_vars)
        negative = self._render_string(template.negative_prompt, resolved_vars)

        # LoRA names may contain {{character_lora}} etc.; strengths are
        # copied as-is with 0.85 as the fallback.
        rendered_loras = [
            {
                "name": self._render_string(lora_spec.get("name", ""), resolved_vars),
                "strength_model": lora_spec.get("strength_model", 0.85),
                "strength_clip": lora_spec.get("strength_clip", 0.85),
            }
            for lora_spec in template.loras
        ]

        return RenderedPrompt(
            positive_prompt=positive,
            negative_prompt=negative,
            loras=rendered_loras,
            variables=resolved_vars,
            template=template,
        )

    def _render_string(self, template_str: str, variables: dict[str, str]) -> str:
        """Render a Jinja2 template string with variables.

        An empty (or otherwise falsy) ``template_str`` yields ``""`` without
        invoking Jinja2.
        """
        if not template_str:
            return ""
        tmpl = self._jinja_env.from_string(template_str)
        return tmpl.render(**variables)
@dataclass
class RenderedPrompt:
    """Outcome of rendering a prompt template with a set of variables.

    Attributes:
        positive_prompt: Rendered positive prompt text.
        negative_prompt: Rendered negative prompt text.
        loras: LoRA specs with their names rendered.
        variables: The variable mapping that was used for rendering.
        template: The template the prompts were rendered from.
    """

    positive_prompt: str
    negative_prompt: str
    loras: list[dict[str, Any]]
    variables: dict[str, str]
    template: PromptTemplate