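# Page-header residue from the source listing (author avatar "samu's picture",
# commit message "1st", commit 7c7ef49), kept as a comment so the module parses.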
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Literal, Optional, Dict, Any, Union
from backend.utils import async_generate_text_and_image, async_generate_with_image_input
from backend.category_config import CATEGORY_CONFIGS
from backend.logging_utils import log_category_usage, get_category_statistics
import backend.config as config # keep for reference if needed
import traceback
# The ASGI application object served by this module.
app = FastAPI()

# Fully permissive CORS: every origin, HTTP method and request header is
# accepted. allow_credentials is not passed, so the middleware default applies.
# NOTE(review): wildcard origins are convenient for a demo; restrict for production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class TextGenerateRequest(BaseModel):
    # Request body for POST /generate.
    # prompt: the text prompt passed to the generator.
    prompt: str
    # category: optional category name forwarded to the generator and to usage logging.
    category: Union[str, None] = None
class ImageTextGenerateRequest(BaseModel):
    # Request body for POST /generate_with_image.
    # text: optional prompt; the endpoint substitutes a default when it is empty.
    text: Union[str, None] = None
    # image: the input image as a string (the endpoint docstring says base64).
    image: str
    # category: optional category name forwarded to the generator and to usage logging.
    category: Union[str, None] = None
class Part(BaseModel):
    # One piece of generated output, tagged with its kind.
    type: Literal["text", "image"]
    # Payload: either a plain string or a str -> str mapping.
    # NOTE(review): the original note said a string is used for image parts and a
    # dict for text parts; confirm against backend.utils before relying on it.
    data: Union[str, Dict[str, str]]
class GenerationResponse(BaseModel):
    # Response envelope for the generation endpoints: every part produced by
    # the backend generator, collected into one list.
    results: List[Part]
@app.post("/generate", response_model=GenerationResponse)
async def generate(request: TextGenerateRequest):
    """
    Generate text and image from a text prompt with optional category.

    Collects every part yielded by ``async_generate_text_and_image`` (in the
    order yielded) and returns them in a single ``GenerationResponse``. Usage
    is recorded with ``log_category_usage`` whether or not generation succeeded.

    Raises:
        HTTPException: re-raised unchanged if the generator raises one;
            any other exception is reported as a 500.
    """
    success = False
    try:
        results = [
            part
            async for part in async_generate_text_and_image(
                request.prompt, request.category
            )
        ]
        # Build the response before flagging success: a validation error here
        # must be logged as a failure, not as a successful request.
        response = GenerationResponse(results=results)
        success = True
        return response
    except HTTPException:
        # Keep any deliberate HTTP error (e.g. a 4xx) instead of turning it into a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        # NOTE(review): this exposes the raw exception text to clients.
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
    finally:
        try:
            log_category_usage(request.category, "/generate", success)
        except Exception:
            # Usage logging is auxiliary. An exception escaping `finally` would
            # replace the request's real outcome (response or error), so report it only.
            traceback.print_exc()
@app.post("/generate_with_image", response_model=GenerationResponse)
async def generate_with_image(request: ImageTextGenerateRequest):
    """
    Generate text and image given a text and base64 image with optional category.

    When ``request.text`` is missing or empty, ``config.DEFAULT_TEXT`` is used
    as the prompt. Every part yielded by ``async_generate_with_image_input`` is
    collected into a single ``GenerationResponse``. Usage is recorded with
    ``log_category_usage`` whether or not generation succeeded.

    Raises:
        HTTPException: re-raised unchanged if the generator raises one;
            any other exception is reported as a 500.
    """
    success = False
    try:
        text = request.text or config.DEFAULT_TEXT
        results = [
            part
            async for part in async_generate_with_image_input(
                text, request.image, request.category
            )
        ]
        # Build the response before flagging success: a validation error here
        # must be logged as a failure, not as a successful request.
        response = GenerationResponse(results=results)
        success = True
        return response
    except HTTPException:
        # Keep any deliberate HTTP error (e.g. a 4xx) instead of turning it into a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        # NOTE(review): this exposes the raw exception text to clients.
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
    finally:
        try:
            log_category_usage(request.category, "/generate_with_image", success)
        except Exception:
            # Usage logging is auxiliary. An exception escaping `finally` would
            # replace the request's real outcome (response or error), so report it only.
            traceback.print_exc()
@app.get("/categories")
async def get_categories():
    """
    Return the CATEGORY_CONFIGS object as-is, describing every available
    engineering category and its configuration.
    """
    configs = CATEGORY_CONFIGS
    return configs
@app.get("/category-stats")
async def get_usage_statistics():
    """
    Return per-category usage statistics as computed by get_category_statistics().
    """
    stats = get_category_statistics()
    return stats
@app.get("/")
async def read_root():
    """Health-style root endpoint: report that the API is running."""
    return dict(message="Image generation API is up")