"""HTTP API for text/image generation, category listing and usage statistics."""
import traceback
from typing import Any, AsyncIterable, Callable, Dict, List, Literal, Optional, Union

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from backend.utils import async_generate_text_and_image, async_generate_with_image_input
from backend.category_config import CATEGORY_CONFIGS
from backend.logging_utils import log_category_usage, get_category_statistics
import backend.config as config  # supplies DEFAULT_TEXT for /generate_with_image

app = FastAPI()

# NOTE(review): wildcard origins/methods/headers accept cross-origin requests
# from anywhere; tighten before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class TextGenerateRequest(BaseModel):
    """Request body for POST /generate."""

    prompt: str
    category: Optional[str] = None


class ImageTextGenerateRequest(BaseModel):
    """Request body for POST /generate_with_image."""

    text: Optional[str] = None  # empty/missing falls back to config.DEFAULT_TEXT
    image: str  # base64 image, per the endpoint docstring
    category: Optional[str] = None


class Part(BaseModel):
    """One item of a generation result: either a text part or an image part."""

    type: Literal["text", "image"]
    # NOTE(review): the original comment said str for image parts and dict for
    # text parts -- confirm against what backend.utils actually yields.
    data: Union[str, Dict[str, str]]


class GenerationResponse(BaseModel):
    """Response body shared by both generation endpoints."""

    results: List[Part]


async def _collect_and_log(
    start: Callable[[], AsyncIterable[Any]],
    category: Optional[str],
    endpoint: str,
) -> GenerationResponse:
    """Drain the generator returned by *start* into a GenerationResponse.

    Args:
        start: zero-arg factory returning the async iterable of parts. It is
            invoked inside the ``try`` so argument preparation and generator
            start-up failures are reported like any other generation error.
        category: optional category, forwarded to the usage log.
        endpoint: route path recorded in the usage log.

    Returns:
        The validated GenerationResponse.

    Raises:
        HTTPException: status 500 on any failure while generating or
            validating the response; the original error is chained as cause.
    """
    success = False
    try:
        results = [part async for part in start()]
        # Validate before flagging success so a malformed part is not
        # recorded as a successful generation.
        response = GenerationResponse(results=results)
        success = True
        return response
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
    finally:
        # An exception escaping a ``finally`` clause replaces the pending
        # return value or exception, so usage logging is kept best-effort:
        # a logging failure must not turn a finished request into a crash.
        try:
            log_category_usage(category, endpoint, success)
        except Exception:
            traceback.print_exc()


@app.post("/generate", response_model=GenerationResponse)
async def generate(request: TextGenerateRequest):
    """
    Generate text and image from a text prompt with optional category.
    """
    return await _collect_and_log(
        lambda: async_generate_text_and_image(request.prompt, request.category),
        request.category,
        "/generate",
    )


@app.post("/generate_with_image", response_model=GenerationResponse)
async def generate_with_image(request: ImageTextGenerateRequest):
    """
    Generate text and image given a text and base64 image with optional category.
    """
    # Any falsy text (None or "") is replaced by the configured default prompt.
    return await _collect_and_log(
        lambda: async_generate_with_image_input(
            request.text or config.DEFAULT_TEXT, request.image, request.category
        ),
        request.category,
        "/generate_with_image",
    )


@app.get("/categories")
async def get_categories():
    """
    Get all available engineering categories with their descriptions and configurations.
    """
    return CATEGORY_CONFIGS


@app.get("/category-stats")
async def get_usage_statistics():
    """
    Get usage statistics for all categories.
    """
    return get_category_statistics()


@app.get("/")
async def read_root():
    return {"message": "Image generation API is up"}