|
from fastapi import APIRouter, HTTPException

from app.config import get_settings
from app.core.errors import VendorError
from app.schemas.requests import FollowSchemaRequest
from app.services.factory import AIServiceFactory
from app.utils.logger import exception_to_str, setup_logger

logger = setup_logger(__name__)
settings = get_settings()

|
async def handle_follow(request: FollowSchemaRequest):
    """Dispatch a FollowSchemaRequest to the matching vendor service, retrying the
    schema-following call up to ``request.max_attempts`` times."""
    # Clamp the retry budget to the range [1, 5].
    request.max_attempts = min(max(request.max_attempts, 1), 5)

    for attempt in range(1, request.max_attempts + 1):
        try:
            logger.info(f"Schema-follow attempt {attempt}/{request.max_attempts}")

            # Map the requested model to its vendor.
            if request.ai_model in settings.OPENAI_MODELS:
                ai_vendor = "openai"
            elif request.ai_model in settings.ANTHROPIC_MODELS:
                ai_vendor = "anthropic"
            else:
                raise ValueError(
                    f"Invalid AI model: {request.ai_model}; "
                    f"supported models: {settings.SUPPORTED_MODELS}"
                )

            service = AIServiceFactory.get_service(ai_vendor)
            json_attributes = await service.follow_schema_with_validation(
                request.data_schema, request.data
            )
            # Success: stop retrying.
            break
        except ValueError as e:
            # Covers the invalid-model branch above (and any ValueError the
            # service raises); treated as a client error.
            if attempt == request.max_attempts:
                raise HTTPException(
                    status_code=400,
                    detail=exception_to_str(e),
                    headers={"attempt": str(attempt)},
                ) from e
        except VendorError as e:
            # Upstream vendor failure: surfaced as a server-side error.
            if attempt == request.max_attempts:
                raise HTTPException(
                    status_code=500,
                    detail=exception_to_str(e),
                    headers={"attempt": str(attempt)},
                ) from e
        except Exception as e:
            # Unexpected failure: log it, but return only a generic error to the client.
            logger.error(f"Attempt {attempt} failed: {exception_to_str(e)}")
            if attempt == request.max_attempts:
                raise HTTPException(
                    status_code=500,
                    detail="Internal server error",
                    headers={"attempt": str(attempt)},
                ) from e

    # Reached only after a successful break out of the retry loop.
    return json_attributes, attempt
|
|
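# Hedged sketch (assumption): `APIRouter` is imported above but no router is
# declared in this module, so the registration below only illustrates how
# handle_follow might be exposed as an endpoint. The router object, path, and
# endpoint name are hypothetical placeholders, not the project's actual routing.
router = APIRouter()


@router.post("/follow-schema")
async def follow_schema(request: FollowSchemaRequest):
    # Delegate to handle_follow, which returns (json_attributes, attempt).
    json_attributes, attempt = await handle_follow(request)
    return {"attributes": json_attributes, "attempt": attempt}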