Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
from typing import List, Literal, Optional, Dict, Any, Union | |
from backend.utils import async_generate_text_and_image, async_generate_with_image_input | |
from backend.category_config import CATEGORY_CONFIGS | |
from backend.logging_utils import log_category_usage, get_category_statistics | |
import backend.config as config # keep for reference if needed | |
import traceback | |
# The ASGI application object every handler in this module is meant to serve.
app = FastAPI()

# Fully open CORS policy: any origin, any HTTP method and any request header
# is accepted from browsers.
# NOTE(review): allow_origins=["*"] lets any website call this API from a
# user's browser; tighten it if the service is not intended to be public.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class TextGenerateRequest(BaseModel):
    """Request body for generating content from a text prompt."""

    # Prompt text to generate from (required).
    prompt: str
    # Optional category name; None when the client omits it.
    category: Union[str, None] = None
class ImageTextGenerateRequest(BaseModel):
    """Request body for generating content from an image plus optional text."""

    # Optional accompanying text; None when the client omits it.
    text: Union[str, None] = None
    # Image payload as a string (presumably base64-encoded; confirm with callers).
    image: str
    # Optional category name; None when the client omits it.
    category: Union[str, None] = None
# Accepted payload shapes for a Part; member order is kept as str first, then dict.
_PartPayload = Union[str, Dict[str, str]]


class Part(BaseModel):
    """A single piece of generated output, tagged with its content kind."""

    # Which kind of content this part carries.
    type: Literal["text", "image"]
    # Per the original author's note: a string for images, a dict for text.
    data: _PartPayload
class GenerationResponse(BaseModel):
    """Response body carrying the generated parts, in the order produced."""

    # Each entry is validated against the Part model.
    results: List[Part]
# NOTE(review): no route decorator is visible in this view; presumably this
# handler is registered on ``app`` as the "/generate" endpoint. Confirm,
# otherwise it is unreachable.
async def generate(request: TextGenerateRequest):
    """
    Generate text and image from a text prompt with optional category.

    Every part yielded by ``async_generate_text_and_image`` for
    ``request.prompt`` / ``request.category`` is collected and returned,
    in order, inside a ``GenerationResponse``.

    Usage is always recorded via ``log_category_usage`` under "/generate".
    Its ``success`` flag is True only once a validated response was built.

    Raises:
        HTTPException: re-raised unchanged if raised during generation;
            any other failure is reported as a 500 "Internal error".
    """
    success = False
    try:
        results = [
            part
            async for part in async_generate_text_and_image(
                request.prompt, request.category
            )
        ]
        # Building the response validates every part and can still raise.
        # Mark success only afterwards so a validation failure is not logged
        # as a successful call.
        response = GenerationResponse(results=results)
        success = True
        return response
    except HTTPException:
        # Keep deliberate HTTP errors (status code and detail) intact
        # instead of collapsing them into a generic 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
    finally:
        log_category_usage(request.category, "/generate", success)
# NOTE(review): no route decorator is visible in this view; presumably this
# handler is registered on ``app`` as the "/generate_with_image" endpoint.
# Confirm, otherwise it is unreachable.
async def generate_with_image(request: ImageTextGenerateRequest):
    """
    Generate text and image given a text and base64 image with optional category.

    When ``request.text`` is empty or missing, ``config.DEFAULT_TEXT`` is used
    instead. Every part yielded by ``async_generate_with_image_input`` is
    collected and returned, in order, inside a ``GenerationResponse``.

    Usage is always recorded via ``log_category_usage`` under
    "/generate_with_image". Its ``success`` flag is True only once a
    validated response was built.

    Raises:
        HTTPException: re-raised unchanged if raised during generation;
            any other failure is reported as a 500 "Internal error".
    """
    success = False
    try:
        # Both None and "" fall back to the configured default prompt.
        text = request.text or config.DEFAULT_TEXT
        results = [
            part
            async for part in async_generate_with_image_input(
                text, request.image, request.category
            )
        ]
        # Building the response validates every part and can still raise.
        # Mark success only afterwards so a validation failure is not logged
        # as a successful call.
        response = GenerationResponse(results=results)
        success = True
        return response
    except HTTPException:
        # Keep deliberate HTTP errors (status code and detail) intact
        # instead of collapsing them into a generic 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
    finally:
        log_category_usage(request.category, "/generate_with_image", success)
async def get_categories():
    """
    Return the category configuration table.

    The module-level ``CATEGORY_CONFIGS`` object is returned as-is (no copy).
    """
    category_table = CATEGORY_CONFIGS
    return category_table
async def get_usage_statistics():
    """
    Return per-category usage statistics.

    Delegates entirely to ``get_category_statistics()`` and returns its result
    unchanged.
    """
    stats = get_category_statistics()
    return stats
async def read_root():
    """Return a fixed status payload confirming the API process is running."""
    status_text = "Image generation API is up"
    return dict(message=status_text)