File size: 2,927 Bytes
7c7ef49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Literal, Optional, Dict, Any, Union
from backend.utils import async_generate_text_and_image, async_generate_with_image_input
from backend.category_config import CATEGORY_CONFIGS
from backend.logging_utils import log_category_usage, get_category_statistics
import backend.config as config  # keep for reference if needed
import traceback

# CORS is fully open: every origin, HTTP method and request header is allowed.
# NOTE(review): wide-open CORS; consider restricting origins for production.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

class TextGenerateRequest(BaseModel):
    # Request body for POST /generate.
    prompt: str  # prompt text handed to the text+image generator
    category: Union[str, None] = None  # optional category name; None when omitted

class ImageTextGenerateRequest(BaseModel):
    # Request body for POST /generate_with_image.
    # Optional prompt text; the endpoint falls back to a default when it is
    # missing or empty.
    text: Union[str, None] = None
    image: str  # image payload (described as base64 by the endpoint docstring)
    category: Union[str, None] = None  # optional category name; None when omitted

class Part(BaseModel):
    # One piece of generated output, tagged with the kind of content it holds.
    type: Literal["text", "image"]
    # Either a plain string or a str -> str mapping. The original author's note
    # says the string form is for images and the dict form is for text.
    # NOTE(review): confirm against the parts yielded by backend.utils.
    # Pydantic union validation can depend on member order, so keep the order.
    data: Union[str, Dict[str, str]]

class GenerationResponse(BaseModel):
    # Response body shared by /generate and /generate_with_image: the parts
    # collected from the generator, in the order they were yielded.
    results: List[Part]

@app.post("/generate", response_model=GenerationResponse)
async def generate(request: TextGenerateRequest):
    """
    Generate text and image from a text prompt with optional category.

    Every part yielded by ``async_generate_text_and_image`` is collected into
    a ``GenerationResponse``. The call is recorded via ``log_category_usage``
    under ``"/generate"``. ``success`` is True only once the response model
    has been built.

    Raises:
        HTTPException: any ``HTTPException`` raised while generating is
            propagated unchanged. Every other failure becomes a 500 whose
            detail includes the original exception message, chained as the
            cause.
    """
    success = False
    try:
        results = [
            part
            async for part in async_generate_text_and_image(
                request.prompt, request.category
            )
        ]
        # Build (and thereby validate) the response before flagging success,
        # so a malformed part is logged as a failure, not a success.
        response = GenerationResponse(results=results)
        success = True
        return response
    except HTTPException:
        # Already carries a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
    finally:
        try:
            log_category_usage(request.category, "/generate", success)
        except Exception:
            # Usage logging is best-effort. An exception escaping `finally`
            # would discard the response (or the original error) and surface
            # as an unrelated 500.
            traceback.print_exc()

@app.post("/generate_with_image", response_model=GenerationResponse)
async def generate_with_image(request: ImageTextGenerateRequest):
    """
    Generate text and image given a text and base64 image with optional category.

    A missing or empty ``text`` is replaced by ``config.DEFAULT_TEXT``. Every
    part yielded by ``async_generate_with_image_input`` is collected into a
    ``GenerationResponse``. The call is recorded via ``log_category_usage``
    under ``"/generate_with_image"``. ``success`` is True only once the
    response model has been built.

    Raises:
        HTTPException: any ``HTTPException`` raised while generating is
            propagated unchanged. Every other failure becomes a 500 whose
            detail includes the original exception message, chained as the
            cause.
    """
    success = False
    try:
        # Falsy text (None or "") falls back to the configured default prompt.
        text = request.text if request.text else config.DEFAULT_TEXT
        results = [
            part
            async for part in async_generate_with_image_input(
                text, request.image, request.category
            )
        ]
        # Build (and thereby validate) the response before flagging success,
        # so a malformed part is logged as a failure, not a success.
        response = GenerationResponse(results=results)
        success = True
        return response
    except HTTPException:
        # Already carries a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
    finally:
        try:
            log_category_usage(request.category, "/generate_with_image", success)
        except Exception:
            # Usage logging is best-effort. An exception escaping `finally`
            # would discard the response (or the original error) and surface
            # as an unrelated 500.
            traceback.print_exc()

@app.get("/categories")
async def get_categories():
    """
    List every engineering category with its description and configuration.

    Responds with ``CATEGORY_CONFIGS`` exactly as imported from
    ``backend.category_config``; serialization is left to FastAPI.
    """
    category_configs = CATEGORY_CONFIGS
    return category_configs

@app.get("/category-stats")
async def get_usage_statistics():
    """
    Report usage statistics for every category.

    Delegates to ``get_category_statistics()`` and returns its result as-is.
    """
    stats = get_category_statistics()
    return stats

@app.get("/")
async def read_root():
    """Return a static message indicating the API is up."""
    status = {"message": "Image generation API is up"}
    return status