Spaces:
Running
Running
import os | |
import subprocess | |
from datetime import datetime | |
from pathlib import Path | |
from fastapi import FastAPI, HTTPException, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.middleware.gzip import GZipMiddleware | |
from fastapi.responses import FileResponse, JSONResponse, ORJSONResponse | |
from fastapi.staticfiles import StaticFiles | |
from dotenv import load_dotenv | |
load_dotenv() | |
from app.config import settings | |
from app.routers import upload, caption, metadata, models | |
from app.routers.images import router as images_router | |
from app.routers.prompts import router as prompts_router | |
from app.routers.admin import router as admin_router | |
from app.routers.schemas import router as schemas_router | |
# Application instance. ORJSONResponse is the default response class for
# handlers that do not return an explicit Response.
app = FastAPI(title="PromptAid Vision", default_response_class=ORJSONResponse)

# --------------------------------------------------------------------
# Compression
# --------------------------------------------------------------------
# Responses smaller than 500 bytes are left uncompressed.
app.add_middleware(GZipMiddleware, minimum_size=500)
# -------------------------------------------------------------------- | |
# Logging middleware (simple) | |
# -------------------------------------------------------------------- | |
async def log_requests(request: Request, call_next):
    """Print a debug line before and after the request is handled.

    Args:
        request: Incoming request; only ``method`` and ``url.path`` are read.
        call_next: Awaitable callable that produces the downstream response.

    Returns:
        The downstream response, unmodified.
    """
    method = request.method
    path = request.url.path

    print(f"DEBUG: {method} {path}")
    response = await call_next(request)
    print(f"DEBUG: {method} {path} -> {response.status_code}")
    return response
# -------------------------------------------------------------------- | |
# Cache headers (assets long-cache, HTML no-cache, API no-store) | |
# -------------------------------------------------------------------- | |
async def add_cache_headers(request: Request, call_next):
    """Attach a caching policy to the response based on the request path.

    Policy (first match wins; any other path is left untouched):

    * ``/assets/*`` and ``/images/*``: one year, ``immutable``. Applied only
      to non-error responses so that a transient 404/500 is never pinned in
      caches for a year.
    * ``/sw.js``: ``no-cache`` so service-worker updates are detected.
    * ``/manifest.webmanifest`` and ``/vite.svg``: one hour.
    * ``/`` and ``*.html``: ``no-cache``.
    * ``/api/*``: ``no-store`` plus the legacy ``Pragma``/``Expires`` pair.

    Args:
        request: Incoming request; only ``url.path`` is read.
        call_next: Awaitable callable that produces the downstream response.

    Returns:
        The downstream response with cache headers applied in place.
    """
    response = await call_next(request)
    path = request.url.path
    headers = response.headers

    if path.startswith(("/assets/", "/images/")):
        # Error responses keep whatever caching headers they already carry.
        if response.status_code < 400:
            headers["Cache-Control"] = "public, max-age=31536000, immutable"
            _add_vary_accept_encoding(headers)
    elif path in ("/sw.js", "/manifest.webmanifest", "/vite.svg"):
        if path == "/sw.js":
            headers["Cache-Control"] = "no-cache"
        else:
            headers["Cache-Control"] = "public, max-age=3600"
        _add_vary_accept_encoding(headers)
    elif path == "/" or path.endswith(".html"):
        headers["Cache-Control"] = "no-cache"
    elif path.startswith("/api/"):
        headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
        headers["Pragma"] = "no-cache"
        headers["Expires"] = "0"
        # Content-Type is left alone; the response class already sets it.
    return response


def _add_vary_accept_encoding(headers) -> None:
    """Ensure ``Vary`` lists ``Accept-Encoding`` without dropping existing values."""
    existing = headers.get("Vary")
    if not existing:
        headers["Vary"] = "Accept-Encoding"
        return
    listed = {token.strip().lower() for token in existing.split(",")}
    if "*" in listed or "accept-encoding" in listed:
        return
    headers["Vary"] = f"{existing}, Accept-Encoding"
# -------------------------------------------------------------------- | |
# CORS | |
# -------------------------------------------------------------------- | |
# Allowed origins: the local dev servers (exact match) plus any
# https://*.hf.space host (regex). Credentials are not allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[f"http://localhost:{port}" for port in (3000, 5173, 8000)],
    allow_origin_regex=r"https://.*\.hf\.space$",
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------------------------------------------------------------------- | |
# API Routers | |
# -------------------------------------------------------------------- | |
# (router, URL prefix, OpenAPI tag), registered in this order.
for _router, _prefix, _tag in (
    (caption.router, "/api", "captions"),
    (metadata.router, "/api", "metadata"),
    (models.router, "/api", "models"),
    (upload.router, "/api/images", "images"),
    (images_router, "/api/contribute", "contribute"),
    (prompts_router, "/api/prompts", "prompts"),
    (admin_router, "/api/admin", "admin"),
    (schemas_router, "/api", "schemas"),
):
    app.include_router(_router, prefix=_prefix, tags=[_tag])
del _router, _prefix, _tag
# Handle /api/images and /api/prompts without trailing slash (avoid 307) | |
async def list_images_no_slash():
    """Return the result of ``list_images`` using a short-lived DB session.

    The session is opened here and always closed, whether or not
    ``list_images`` raises.
    """
    # Imported lazily, at call time rather than module import time.
    from app.routers.upload import list_images
    from app.database import SessionLocal

    session = SessionLocal()
    try:
        images = list_images(session)
    finally:
        session.close()
    return images
async def list_prompts_no_slash():
    """Return the result of ``get_prompts`` using a short-lived DB session.

    The session is opened here and always closed, whether or not
    ``get_prompts`` raises.
    """
    # Imported lazily, at call time rather than module import time.
    from app.routers.prompts import get_prompts
    from app.database import SessionLocal

    session = SessionLocal()
    try:
        prompts = get_prompts(session)
    finally:
        session.close()
    return prompts
async def create_prompt_no_slash(prompt_data: dict):
    """Validate ``prompt_data`` and create a prompt via ``create_prompt``.

    Args:
        prompt_data: Raw mapping used as keyword arguments for
            ``PromptCreate``.

    Returns:
        Whatever ``create_prompt`` returns.

    Raises:
        HTTPException: 422 when ``prompt_data`` cannot be turned into a
            ``PromptCreate`` (previously this surfaced as an unhandled
            server error).
    """
    from app.routers.prompts import create_prompt
    from app.database import SessionLocal
    from app.schemas import PromptCreate

    # Validate before touching the database so bad input never opens a session.
    # NOTE(review): assumes PromptCreate's validation error derives from
    # ValueError (true for pydantic models) - confirm against app.schemas.
    try:
        prompt_create = PromptCreate(**prompt_data)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    db = SessionLocal()
    try:
        return create_prompt(prompt_create, db)
    finally:
        db.close()
# -------------------------------------------------------------------- | |
# Health / Performance | |
# -------------------------------------------------------------------- | |
async def health():
    """Return a constant liveness payload."""
    payload = {"status": "ok"}
    return payload
async def performance():
    """Return a snapshot of host memory/CPU usage plus static feature flags.

    ``memory_usage`` and ``cpu_usage`` are the percentages reported by
    ``psutil``; the three boolean flags are hard-coded to ``True``.
    """
    import psutil
    import time

    now = time.time()
    memory_percent = psutil.virtual_memory().percent
    cpu_percent = psutil.cpu_percent()

    return {
        "timestamp": now,
        "memory_usage": memory_percent,
        "cpu_usage": cpu_percent,
        "compression_enabled": True,
        "orjson_enabled": True,
        "cache_headers": True,
    }
# -------------------------------------------------------------------- | |
# Static dir resolution (ALWAYS a Path) | |
# -------------------------------------------------------------------- | |
# Directory containing this module.
APP_DIR = Path(__file__).resolve().parent

# Places the built frontend may live, in priority order.
CANDIDATES = [
    APP_DIR.joinpath("static"),         # py_backend/static
    APP_DIR.parent.joinpath("static"),  # ../static
    Path("/app", "static"),             # container path
    Path("/app/app", "static"),         # some containers use /app/app
]

# First candidate that is an existing directory; otherwise fall back to
# APP_DIR/static even though it does not exist, so STATIC_DIR is always a Path.
STATIC_DIR = next(filter(Path.is_dir, CANDIDATES), APP_DIR.joinpath("static"))
print(f"Serving static from: {STATIC_DIR}")
# -------------------------------------------------------------------- | |
# Explicit top-level static files | |
# -------------------------------------------------------------------- | |
def manifest():
    """Serve ``manifest.webmanifest`` from the static directory.

    The manifest file name is not content-hashed, so it is cached for one
    hour rather than marked immutable; this matches the policy that
    ``add_cache_headers`` applies to ``/manifest.webmanifest``.

    Raises:
        HTTPException: 404 when the manifest file is missing (instead of
            failing while the response is being sent).
    """
    manifest_path = STATIC_DIR / "manifest.webmanifest"
    if not manifest_path.is_file():
        raise HTTPException(status_code=404, detail="Manifest not found")
    return FileResponse(
        manifest_path,
        media_type="application/manifest+json",
        headers={"Cache-Control": "public, max-age=3600"},
    )
def sw():
    """Serve ``sw.js`` with ``Cache-Control: no-cache``.

    Raises:
        HTTPException: 404 when the file is missing, consistent with
            ``test_service_worker``.
    """
    sw_path = STATIC_DIR / "sw.js"
    if not sw_path.is_file():
        raise HTTPException(status_code=404, detail="Service Worker not found")
    return FileResponse(sw_path, headers={"Cache-Control": "no-cache"})
def vite_svg():
    """Serve ``vite.svg`` from the static directory.

    Raises:
        HTTPException: 404 when the icon file does not exist.
    """
    icon_path = STATIC_DIR / "vite.svg"
    if not icon_path.is_file():
        raise HTTPException(status_code=404, detail="Icon not found")
    return FileResponse(icon_path)
# -------------------------------------------------------------------- | |
# Mount hashed assets at /assets | |
# -------------------------------------------------------------------- | |
# Only mount when the build output exists.
_assets_dir = STATIC_DIR / "assets"
if _assets_dir.is_dir():
    app.mount("/assets", StaticFiles(directory=_assets_dir), name="assets")
# Serve index at / | |
def index():
    """Serve ``index.html`` with ``Cache-Control: no-cache``.

    Raises:
        HTTPException: 404 when ``index.html`` is not present.
    """
    page = STATIC_DIR / "index.html"
    if not page.is_file():
        raise HTTPException(status_code=404, detail="App not found")
    return FileResponse(
        page,
        media_type="text/html",
        headers={"Cache-Control": "no-cache"},
    )
# -------------------------------------------------------------------- | |
# Uploads (local storage only) | |
# -------------------------------------------------------------------- | |
async def serve_upload(file_path: str):
    """Serve an uploaded file from local storage.

    ``file_path`` is treated as untrusted: it is resolved against
    ``settings.STORAGE_DIR`` and rejected when the result falls outside
    that directory (``../`` segments, absolute paths, symlinks pointing
    elsewhere).

    Args:
        file_path: Path of the requested file relative to the storage root.

    Returns:
        A ``FileResponse`` for the resolved file.

    Raises:
        HTTPException: 404 when local storage is not the active provider,
            when the path escapes the storage root, or when it does not
            name a regular file.
    """
    if settings.STORAGE_PROVIDER != "local":
        raise HTTPException(status_code=404, detail="Local storage not enabled")

    try:
        storage_root = Path(settings.STORAGE_DIR).resolve()
        target = (storage_root / file_path).resolve()
    except (OSError, ValueError):
        # e.g. embedded NUL bytes or unresolvable symlink loops
        raise HTTPException(status_code=404, detail="File not found")

    # Containment check: the storage root must be an ancestor of the target.
    # The same 404 is used so callers cannot probe the filesystem layout.
    if storage_root not in target.parents:
        raise HTTPException(status_code=404, detail="File not found")
    if not target.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(target)
# -------------------------------------------------------------------- | |
# Debug endpoints | |
# -------------------------------------------------------------------- | |
async def debug_routes():
    """List path, name and methods for every registered route that has a path.

    Routes without a ``path`` attribute are skipped; a missing ``name``
    is reported as ``"N/A"`` and missing ``methods`` as an empty list.
    """
    return {
        "routes": [
            {
                "path": entry.path,
                "name": getattr(entry, "name", "N/A"),
                "methods": list(entry.methods) if hasattr(entry, "methods") else [],
            }
            for entry in app.routes
            if hasattr(entry, "path")
        ]
    }
async def debug():
    """Return a liveness message, a local timestamp and the route paths.

    Routes without a ``path`` attribute are skipped, matching
    ``debug_routes``, so one unusual route cannot break this endpoint.
    """
    return {
        "message": "Backend is working",
        "timestamp": datetime.now().isoformat(),
        "routes": [route.path for route in app.routes if hasattr(route, "path")],
    }
async def debug_static():
    """Report where static files are expected and what is actually there.

    ``static_files`` lists the entry names inside ``STATIC_DIR`` and is
    empty when ``STATIC_DIR`` is missing or is not a directory.
    """
    sw_path = STATIC_DIR / "sw.js"
    # Guard with is_dir(): iterdir() raises if the path exists but is a file.
    static_files = [p.name for p in STATIC_DIR.iterdir()] if STATIC_DIR.is_dir() else []
    return {
        "static_dir": str(STATIC_DIR),
        "exists": STATIC_DIR.exists(),
        "is_dir": STATIC_DIR.is_dir(),
        "current_dir": os.getcwd(),
        "app_dir": str(APP_DIR),
        "parent_dir": str(APP_DIR.parent),
        "sw_exists": sw_path.exists(),
        "sw_path": str(sw_path),
        "static_files": static_files,
    }
def test_service_worker():
    """Serve ``sw.js`` as ``application/javascript``.

    Raises:
        HTTPException: 404 when ``sw.js`` is not present.
    """
    worker = STATIC_DIR / "sw.js"
    if not worker.is_file():
        raise HTTPException(status_code=404, detail="Service Worker not found")
    return FileResponse(worker, media_type="application/javascript")
# -------------------------------------------------------------------- | |
# SPA fallback LAST (doesn't swallow API/debug) | |
# -------------------------------------------------------------------- | |
def spa_fallback(full_path: str):
    """Serve ``index.html`` for any path that is not an API route.

    The HTML is sent with ``Cache-Control: no-cache``, the same policy
    ``index`` uses, so deep links do not serve a stale shell that points
    at old asset files.

    Args:
        full_path: Requested path; the ``api/`` prefix check implies it
            arrives without a leading slash.

    Raises:
        HTTPException: 404 for ``api`` / ``api/...`` paths, and 404 when
            ``index.html`` is not present.
    """
    # Unknown API paths must not be answered with the SPA shell.
    if full_path == "api" or full_path.startswith("api/"):
        raise HTTPException(status_code=404, detail="API route not found")

    index_html = STATIC_DIR / "index.html"
    if not index_html.is_file():
        raise HTTPException(status_code=404, detail="App not found")
    return FileResponse(
        index_html,
        media_type="text/html",
        headers={"Cache-Control": "no-cache"},
    )
# -------------------------------------------------------------------- | |
# Startup helpers | |
# -------------------------------------------------------------------- | |
def run_migrations():
    """Run ``alembic upgrade head``, falling back to ``create_all`` on failure.

    Steps:
      1. Print diagnostics (cwd, alembic executable location, ``/app``).
      2. Locate ``alembic.ini`` among a fixed list of candidate paths and
         run Alembic from the directory that contains it (or from the
         current directory when none is found).
      3. If Alembic exits non-zero, try creating the tables directly
         through SQLAlchemy metadata.

    Every failure is printed and swallowed; this function never raises.
    """
    import shutil  # local import: only used for the diagnostic lookup below

    try:
        print("Running database migrations...")
        current_dir = os.getcwd()
        print(f"Current working directory: {current_dir}")

        # Diagnostics only; the result does not influence what runs next.
        print(f"Alembic location: {shutil.which('alembic') or 'Not found'}")
        print(f"Checking if /app exists: {os.path.exists('/app')}")
        if os.path.exists("/app"):
            print(f"Contents of /app: {os.listdir('/app')}")

        alembic_dir = None
        for ini_path in (
            "alembic.ini",
            "../alembic.ini",
            "py_backend/alembic.ini",
            "/app/alembic.ini",
        ):
            if os.path.exists(ini_path):
                # abspath first: dirname("alembic.ini") is "" and would be
                # mistaken for "not found" by the check below.
                alembic_dir = os.path.dirname(os.path.abspath(ini_path))
                print(f"Found alembic.ini at: {ini_path}")
                break
        if alembic_dir is None:
            print("Could not find alembic.ini - using current directory")
            alembic_dir = current_dir

        try:
            print(f"Running alembic upgrade head from: {alembic_dir}")
            result = subprocess.run(
                ["alembic", "upgrade", "head"],
                cwd=alembic_dir,
                capture_output=True,
                text=True,
                timeout=60,
            )
            print(f"Alembic return code: {result.returncode}")
            print(f"Alembic stdout: {result.stdout}")
            print(f"Alembic stderr: {result.stderr}")

            if result.returncode == 0:
                print("Database migrations completed successfully")
            else:
                print("Database migrations failed")
                print("Trying fallback: create tables directly...")
                try:
                    from app.database import engine
                    from app.models import Base

                    Base.metadata.create_all(bind=engine)
                    print("Tables created directly via SQLAlchemy")
                except Exception as fallback_error:
                    print(f"Fallback also failed: {fallback_error}")
        except Exception as e:
            # Covers a missing alembic executable and the 60 s timeout.
            print(f"Error running alembic: {e}")
    except Exception as e:
        print(f"Could not run migrations: {e}")
def ensure_storage_ready():
    """Print the storage configuration and, for S3, check the bucket.

    For the ``s3`` provider ``_ensure_bucket`` is called; any failure is
    printed as a warning and not re-raised. ``local`` needs no check, and
    any other provider value is only reported.
    """
    provider = settings.STORAGE_PROVIDER
    print(f"Storage provider from settings: '{provider}'")
    print(f"S3 endpoint from settings: '{settings.S3_ENDPOINT}'")
    print(f"S3 bucket from settings: '{settings.S3_BUCKET}'")

    if provider == "local":
        print("Using local storage - no external dependencies")
        return
    if provider != "s3":
        print(f"Unknown storage provider: {provider}")
        return

    try:
        print("Checking S3 storage connection...")
        from app.storage import _ensure_bucket

        _ensure_bucket()
        print("S3 storage ready")
    except Exception as e:
        print(f"Warning: S3 storage not ready: {e}")
        print("Storage operations may fail until S3 is available")
# -------------------------------------------------------------------- | |
# VLM service registration on startup | |
# -------------------------------------------------------------------- | |
from app.services.vlm_service import vlm_manager | |
# Providers | |
from app.services.stub_vlm_service import StubVLMService | |
from app.services.gpt4v_service import GPT4VService | |
from app.services.gemini_service import GeminiService | |
from app.services.huggingface_service import ProvidersGenericVLMService | |
from app.database import SessionLocal | |
from app import crud | |
import asyncio | |
async def register_vlm_services() -> None:
    """Register the stub, OpenAI, Gemini and Hugging Face VLM services.

    Each provider is registered independently: a failure in one is printed
    and does not prevent the others from registering. Hugging Face models
    come from the rows returned by ``crud.get_models``; only rows whose
    provider is ``"huggingface"``, that have a ``model_id``, and whose
    ``m_code`` is not reserved for another provider are registered.

    Finally ``vlm_manager.probe_all()`` is scheduled as a background task
    so that probing does not delay startup.
    """
    print("Registering VLM services...")

    # Always have a stub as a safe fallback
    try:
        vlm_manager.register_service(StubVLMService())
        print("✓ STUB_MODEL registered")
    except Exception as e:
        print(f"✗ Failed to register STUB_MODEL: {e}")

    # OpenAI GPT-4V (if configured)
    if settings.OPENAI_API_KEY:
        try:
            vlm_manager.register_service(GPT4VService(settings.OPENAI_API_KEY))
            print("✓ GPT-4 Vision service registered")
        except Exception as e:
            print(f"✗ GPT-4 Vision service failed to register: {e}")
    else:
        print("⚠ GPT-4 Vision not configured (OPENAI_API_KEY missing)")

    # Google Gemini (if configured)
    if settings.GOOGLE_API_KEY:
        try:
            vlm_manager.register_service(GeminiService(settings.GOOGLE_API_KEY))
            print("✓ Gemini service registered")
        except Exception as e:
            print(f"✗ Gemini service failed to register: {e}")
    else:
        print("⚠ Gemini not configured (GOOGLE_API_KEY missing)")

    # Hugging Face Inference Providers (if configured)
    if settings.HF_API_KEY:
        db = SessionLocal()
        try:
            # Named model_rows so the imported ``models`` router module is
            # not shadowed inside this function.
            model_rows = crud.get_models(db)
            registered = 0
            skipped = 0
            for m in model_rows:
                # Only register HF rows; skip "logical" names that map to other providers
                if (
                    getattr(m, "provider", "") == "huggingface"
                    and getattr(m, "model_id", None)
                    and m.m_code not in {"STUB_MODEL", "GPT-4O", "GEMINI15"}
                ):
                    try:
                        svc = ProvidersGenericVLMService(
                            api_key=settings.HF_API_KEY,
                            model_id=m.model_id,
                            public_name=m.m_code,  # stable name your UI/DB uses
                        )
                        vlm_manager.register_service(svc)
                        print(f"✓ HF registered: {m.m_code} -> {m.model_id}")
                        registered += 1
                    except Exception as e:
                        print(f"✗ HF model {m.m_code} failed to register: {e}")
                else:
                    skipped += 1

            if registered:
                print(f"✓ Hugging Face services registered: {registered}")
            else:
                print("⚠ No Hugging Face models registered (none found or all skipped)")
            if skipped:
                print(f"ℹ HF skipped entries: {skipped}")
        except Exception as e:
            # A database problem must not abort startup; the providers
            # registered so far remain usable.
            print(f"✗ Hugging Face model lookup failed: {e}")
        finally:
            db.close()
    else:
        print("⚠ Hugging Face not configured (HF_API_KEY missing)")

    # Kick off lightweight probes in the background (don't block startup)
    try:
        probe_task = asyncio.create_task(vlm_manager.probe_all())
    except Exception as e:
        print(f"Probe scheduling failed: {e}")
    else:
        # The event loop keeps only weak references to tasks, so hold a
        # strong one until the probe finishes.
        _BACKGROUND_TASKS.add(probe_task)
        probe_task.add_done_callback(_BACKGROUND_TASKS.discard)

    print(f"✓ Available models now: {', '.join(vlm_manager.get_available_models())}")
    print(f"✓ Total services: {len(vlm_manager.services)}")


# Strong references to fire-and-forget tasks started by register_vlm_services.
_BACKGROUND_TASKS: set = set()
# Run startup tasks | |
# Bring the database schema and storage backend up before serving.
run_migrations()
ensure_storage_ready()

# Startup banner, one item per line.
print(
    "PromptAid Vision API server ready",
    "Endpoints: /api/images, /api/captions, /api/metadata, /api/models",
    f"Environment: {settings.ENVIRONMENT}",
    "CORS: localhost + *.hf.space",
    sep="\n",
)