dippoo's picture
Add debug logging for DB path on startup
b341f22
raw
history blame
8.47 kB
"""SQLAlchemy database models for the content catalog and job queue."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import (
Boolean,
DateTime,
Float,
Index,
Integer,
String,
Text,
func,
)
from sqlalchemy.ext.asyncio import AsyncAttrs, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from content_engine.config import settings
class Base(AsyncAttrs, DeclarativeBase):
    """Declarative base class shared by every ORM model in this module.

    Combines SQLAlchemy's ``DeclarativeBase`` with the ``AsyncAttrs`` mixin.
    All tables declared on subclasses are collected in ``Base.metadata``.
    """
class Character(Base):
    """ORM model for one row of the ``characters`` table.

    Holds a character's name together with its LoRA settings (trigger word,
    LoRA file name and strength), an optional default checkpoint and a
    free-text description.
    """

    __tablename__ = "characters"

    # --- Identity ---
    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    name: Mapped[str] = mapped_column(String(128), nullable=False)

    # --- LoRA configuration ---
    trigger_word: Mapped[str] = mapped_column(String(128), nullable=False)
    lora_filename: Mapped[str] = mapped_column(String(256), nullable=False)
    # Python-side default: used when an insert does not supply a strength.
    lora_strength: Mapped[float] = mapped_column(Float, default=0.85)

    # --- Optional metadata ---
    default_checkpoint: Mapped[str | None] = mapped_column(String(256))
    description: Mapped[str | None] = mapped_column(Text)

    # Filled in by the database (server-side ``now()``) at insert time.
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
class Image(Base):
    """ORM model for one row of the ``images`` table.

    Each row records a generated image: the parameters used to generate it,
    searchable variation attributes, where the file is stored, and its
    approval / publishing state.
    """

    __tablename__ = "images"

    # --- Identity and grouping ---
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    batch_id: Mapped[str | None] = mapped_column(String(36), index=True)
    character_id: Mapped[str | None] = mapped_column(String(64), index=True)
    template_id: Mapped[str | None] = mapped_column(String(128))
    content_rating: Mapped[str] = mapped_column(String(8), index=True)  # sfw | nsfw

    # --- Generation parameters ---
    positive_prompt: Mapped[str | None] = mapped_column(Text)
    negative_prompt: Mapped[str | None] = mapped_column(Text)
    checkpoint: Mapped[str | None] = mapped_column(String(256))
    loras_json: Mapped[str | None] = mapped_column(Text)  # JSON array
    seed: Mapped[int | None] = mapped_column(Integer)
    steps: Mapped[int | None] = mapped_column(Integer)
    cfg: Mapped[float | None] = mapped_column(Float)
    sampler: Mapped[str | None] = mapped_column(String(64))
    scheduler: Mapped[str | None] = mapped_column(String(64))
    width: Mapped[int | None] = mapped_column(Integer)
    height: Mapped[int | None] = mapped_column(Integer)

    # --- Searchable variation attributes ---
    pose: Mapped[str | None] = mapped_column(String(128))
    outfit: Mapped[str | None] = mapped_column(String(128))
    emotion: Mapped[str | None] = mapped_column(String(128))
    camera_angle: Mapped[str | None] = mapped_column(String(128))
    lighting: Mapped[str | None] = mapped_column(String(128))
    scene: Mapped[str | None] = mapped_column(String(128))

    # --- File info ---
    file_path: Mapped[str] = mapped_column(String(512), nullable=False)
    file_hash: Mapped[str | None] = mapped_column(String(64))
    file_size: Mapped[int | None] = mapped_column(Integer)
    generation_backend: Mapped[str | None] = mapped_column(String(32))  # local | cloud
    comfyui_prompt_id: Mapped[str | None] = mapped_column(String(36))
    generation_time_seconds: Mapped[float | None] = mapped_column(Float)

    # --- Quality and publishing ---
    quality_score: Mapped[float | None] = mapped_column(Float)
    is_approved: Mapped[bool] = mapped_column(Boolean, default=False)
    is_published: Mapped[bool] = mapped_column(Boolean, default=False)
    published_platform: Mapped[str | None] = mapped_column(String(64))
    published_at: Mapped[datetime | None] = mapped_column(DateTime)
    scheduled_at: Mapped[datetime | None] = mapped_column(DateTime)

    # Filled in by the database (server-side ``now()``) at insert time.
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())

    __table_args__ = (
        # The WHERE predicate is a PostgreSQL dialect option; on PostgreSQL it
        # makes this a partial index over approved rows only.
        Index(
            "idx_images_approved",
            "is_approved",
            postgresql_where=(is_approved == True),  # noqa: E712
        ),
        # Composite index over the publishing and approval flags.
        Index("idx_images_unpublished", "is_published", "is_approved"),
    )
class GenerationJob(Base):
    """ORM model for one row of the ``generation_jobs`` table.

    Tracks a single generation request: what was asked for, which backend
    handles it, its current status, and identifiers linking it to the
    backend job and the resulting image.
    """

    __tablename__ = "generation_jobs"

    # --- Identity and grouping ---
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    batch_id: Mapped[str | None] = mapped_column(String(36), index=True)

    # --- Request description ---
    character_id: Mapped[str | None] = mapped_column(String(64))
    template_id: Mapped[str | None] = mapped_column(String(128))
    content_rating: Mapped[str | None] = mapped_column(String(8))
    variables_json: Mapped[str | None] = mapped_column(Text)
    workflow_json: Mapped[str | None] = mapped_column(Text)
    backend: Mapped[str | None] = mapped_column(String(32))  # local | replicate | runpod

    # --- Lifecycle ---
    # pending | queued | running | completed | failed
    status: Mapped[str] = mapped_column(String(16), default="pending", index=True)
    comfyui_prompt_id: Mapped[str | None] = mapped_column(String(36))
    cloud_job_id: Mapped[str | None] = mapped_column(String(128))
    result_image_id: Mapped[str | None] = mapped_column(String(36))
    error_message: Mapped[str | None] = mapped_column(Text)

    # --- Timestamps ---
    # created_at is filled in by the database (server-side ``now()``).
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
    started_at: Mapped[datetime | None] = mapped_column(DateTime)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime)
class ScheduledPost(Base):
    """ORM model for one row of the ``scheduled_posts`` table.

    Describes a post of one image to one platform at a given time, together
    with its caption, teaser flag and publishing outcome.
    """

    __tablename__ = "scheduled_posts"

    # --- What to post, where and when ---
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    image_id: Mapped[str] = mapped_column(String(36), nullable=False)
    platform: Mapped[str] = mapped_column(String(64), nullable=False)
    scheduled_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    caption: Mapped[str | None] = mapped_column(Text)
    is_teaser: Mapped[bool] = mapped_column(Boolean, default=False)

    # --- Outcome ---
    # pending | published | failed | cancelled
    status: Mapped[str] = mapped_column(String(16), default="pending")
    published_at: Mapped[datetime | None] = mapped_column(DateTime)
    error_message: Mapped[str | None] = mapped_column(Text)

    # Filled in by the database (server-side ``now()``) at insert time.
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())

    # Composite index on (status, scheduled_at).
    __table_args__ = (Index("idx_scheduled_pending", "status", "scheduled_at"),)
class TrainingJob(Base):
    """ORM model for one row of the ``training_jobs`` table.

    Records a training run: its progress counters, latest loss, output
    location, accumulated log text, and the backend/pod it runs on.
    """

    __tablename__ = "training_jobs"

    # --- Identity and status ---
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    status: Mapped[str] = mapped_column(String(32), default="pending", index=True)

    # --- Progress counters (all start at zero) ---
    progress: Mapped[float] = mapped_column(Float, default=0.0)
    current_epoch: Mapped[int] = mapped_column(Integer, default=0)
    total_epochs: Mapped[int] = mapped_column(Integer, default=0)
    current_step: Mapped[int] = mapped_column(Integer, default=0)
    total_steps: Mapped[int] = mapped_column(Integer, default=0)
    loss: Mapped[float | None] = mapped_column(Float)

    # NOTE(review): unlike the DateTime timestamps on the other models these
    # two are stored as floats - presumably Unix epoch seconds; confirm
    # against the code that writes them.
    started_at: Mapped[float | None] = mapped_column(Float)
    completed_at: Mapped[float | None] = mapped_column(Float)

    # --- Results and diagnostics ---
    output_path: Mapped[str | None] = mapped_column(String(512))
    error: Mapped[str | None] = mapped_column(Text)
    log_text: Mapped[str | None] = mapped_column(Text)  # newline-separated log lines

    # --- Execution environment ---
    pod_id: Mapped[str | None] = mapped_column(String(64))
    gpu_type: Mapped[str | None] = mapped_column(String(64))
    backend: Mapped[str] = mapped_column(String(16), default="runpod")

    # --- Training configuration ---
    base_model: Mapped[str | None] = mapped_column(String(64))
    model_type: Mapped[str | None] = mapped_column(String(16))
    trigger_word: Mapped[str | None] = mapped_column(String(128))
    image_upload_dir: Mapped[str | None] = mapped_column(String(512))

    # Filled in by the database (server-side ``now()``) at insert time.
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
# --- Engine / Session factories ---
import logging as _logging
from pathlib import Path as _Path

from sqlalchemy.engine import make_url as _make_url

_db_url = settings.database.url

# Parse the URL rather than stripping a hard-coded "sqlite+aiosqlite:///"
# prefix, so that query strings and non-SQLite URLs do not end up being
# treated as a filesystem path.
_parsed_db_url = _make_url(_db_url)
_is_sqlite = _parsed_db_url.get_backend_name() == "sqlite"
_db_path = _parsed_db_url.database or ""

# Only a file-backed SQLite database has a directory to prepare; in-memory
# databases, "file:" URIs and other backends are left alone.
_is_sqlite_file = (
    _is_sqlite
    and _db_path not in ("", ":memory:")
    and not _db_path.startswith("file:")
)
_db_dir = _Path(_db_path).parent if _is_sqlite_file else None
if _db_dir is not None:
    # Ensure the database directory exists (critical for HF Spaces first run).
    _db_dir.mkdir(parents=True, exist_ok=True)
    _logging.getLogger(__name__).info(
        "DB path: %s (dir exists: %s)", _db_path, _db_dir.exists()
    )

_catalog_engine = create_async_engine(
    _db_url,
    echo=False,
    # "check_same_thread" is a SQLite connect argument, so it is only passed
    # when the backend actually is SQLite.
    connect_args={"check_same_thread": False} if _is_sqlite else {},
)

# Session factory bound to the catalog engine; expire_on_commit=False keeps
# loaded attribute values on instances after a commit.
catalog_session_factory = async_sessionmaker(_catalog_engine, expire_on_commit=False)
async def init_db() -> None:
    """Create all tables registered on ``Base.metadata``.

    Meant to be awaited once during application startup. The work runs
    inside a single transaction opened on the module's catalog engine.
    """
    async with _catalog_engine.begin() as connection:
        # ``create_all`` is a synchronous API, so it is dispatched through
        # the async connection's ``run_sync`` bridge.
        await connection.run_sync(Base.metadata.create_all)