|
""" |
|
API router for VQA endpoints |
|
""" |
|
import logging |
|
from typing import List, Optional |
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, BackgroundTasks, Request |
|
from fastapi.responses import JSONResponse |
|
from pydantic import BaseModel |
|
|
|
from app.services.session_service import SessionService |
|
from app.services.model_service import ModelService |
|
|
|
# Module-scoped logger, named after this module so records can be filtered.
logger: logging.Logger = logging.getLogger(__name__)

# Every route declared below is prefixed with /api/vqa and grouped under the
# "vqa" tag in the generated OpenAPI documentation.
router = APIRouter(
    tags=["vqa"],
    prefix="/api/vqa",
)
|
|
|
|
|
class QuestionRequest(BaseModel):
    """Request body for ``POST /api/vqa/ask``: a question about a session's image."""

    # ID returned by ``POST /api/vqa/upload``; used to look up the session.
    session_id: str
    # Free-text question passed unchanged to the model service.
    question: str
|
|
|
class AnswerResponse(BaseModel):
    """Response body for ``POST /api/vqa/ask``: the model's answer and confidences."""

    # Answer text produced by the model service.
    answer: str
    # Confidence in ``answer``; presumably in [0, 1] -- confirm against ModelService.predict.
    answer_confidence: float
    # Whether the model judged the question answerable from the image.
    is_answerable: bool
    # Confidence in ``is_answerable``; presumably in [0, 1] -- TODO confirm.
    answerable_confidence: float
|
|
|
class SessionHistoryItem(BaseModel):
    """One question/answer exchange recorded in a session's history."""

    # The question as submitted by the client.
    question: str
    # The answer returned for that question.
    answer: AnswerResponse
    # When the exchange was recorded; string format is set by the session
    # layer (presumably ISO-8601 -- NOTE(review): confirm).
    timestamp: str
|
|
|
class SessionResponse(BaseModel):
    """Response body for ``GET /api/vqa/session/{session_id}``."""

    # Identifier of the session being described.
    session_id: str
    # Question/answer exchanges for the session. NOTE(review): populated from
    # ``session.questions``; assumes each entry matches SessionHistoryItem.
    history: List[SessionHistoryItem]
|
|
|
|
|
# Single SessionService instance, created at import time and shared by every
# endpoint in this module. NOTE(review): if SessionService keeps sessions in
# memory, they are per-process and not shared across workers -- confirm.
session_service: SessionService = SessionService()
|
|
|
@router.post("/upload", response_model=dict)
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None,
):
    """
    Upload an image and create a new session.

    Args:
        request (Request): The incoming request (not used by this handler).
        file (UploadFile): The image file to upload. Its declared content
            type must start with ``image/``.
        background_tasks (BackgroundTasks): Injected by FastAPI; not used.

    Returns:
        dict: ``{"session_id": <id>}`` for the newly created session.

    Raises:
        HTTPException: 400 if the upload is not declared as an image;
            500 if session creation fails.
    """
    # ``content_type`` comes from the multipart part's headers and is None when
    # the client omits it. Normalise it so a missing header yields the intended
    # 400 instead of an AttributeError surfacing as an unhandled 500.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        session_id = session_service.create_session(file)
    except Exception as e:
        # logger.exception records the traceback (logger.error dropped it);
        # chaining keeps the original cause attached to the HTTP error.
        logger.exception("Error uploading image: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"session_id": session_id}
|
|
|
@router.post("/ask", response_model=AnswerResponse)
async def ask_question(
    request: Request,
    question_request: QuestionRequest,
):
    """
    Ask a question about the image uploaded for a session.

    Args:
        request (Request): The incoming request; used to reach the model
            service stored on ``app.state``.
        question_request (QuestionRequest): The session ID and question.

    Returns:
        AnswerResponse: The prediction returned by the model service. The same
        result is also appended to the session's history.

    Raises:
        HTTPException: 503 if no model service is attached to the app state;
            404 if the session does not exist or has expired;
            500 if prediction or recording the answer fails.
    """
    # The model service is attached to ``app.state`` outside this module.
    # Starlette's State raises AttributeError for unset attributes, so look it
    # up defensively and report 503 instead of an unhandled 500.
    model_service = getattr(request.app.state, "model_service", None)
    if model_service is None:
        raise HTTPException(status_code=503, detail="Model service not available")

    session = session_service.get_session(question_request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found or expired")

    try:
        # NOTE(review): predict() is called without await, so it runs on the
        # event-loop thread and blocks other requests while inferring. Consider
        # fastapi.concurrency.run_in_threadpool if the model is thread-safe.
        result = model_service.predict(session.image_path, question_request.question)
        session.add_question(question_request.question, result)
    except Exception as e:
        logger.exception("Error processing question: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return result
|
|
|
@router.get("/session/{session_id}", response_model=SessionResponse)
async def get_session(
    request: Request,
    session_id: str,
):
    """
    Return a session's ID together with its question/answer history.

    Args:
        session_id (str): Identifier of the session to look up.

    Returns:
        SessionResponse: ``session_id`` plus the session's ``questions`` list
        exposed as ``history``.

    Raises:
        HTTPException: 404 when the session cannot be found.
    """
    found = session_service.get_session(session_id)
    if found:
        payload = {
            "session_id": found.session_id,
            "history": found.questions,
        }
        return payload

    raise HTTPException(status_code=404, detail="Session not found or expired")
|
|
|
@router.post("/session/{session_id}/complete")
async def complete_session(
    request: Request,
    session_id: str,
):
    """
    Mark a session as complete via the session service.

    Args:
        session_id (str): Identifier of the session to complete.

    Returns:
        dict: A confirmation message.

    Raises:
        HTTPException: 404 when the session cannot be found;
            500 when the session service reports failure.
    """
    # Guard: the session must exist before it can be completed.
    if not session_service.get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found or expired")

    # complete_session signals success with a truthy return value.
    if not session_service.complete_session(session_id):
        raise HTTPException(status_code=500, detail="Failed to complete session")

    return {"message": "Session completed successfully, resources cleaned up"}
|
|
|
@router.delete("/session/{session_id}")
async def reset_session(
    request: Request,
    session_id: str,
):
    """
    Delete a session so the client can start over.

    Args:
        session_id (str): Identifier of the session to delete.

    Returns:
        dict: A confirmation message.

    Raises:
        HTTPException: 404 when the session cannot be found.
    """
    existing = session_service.get_session(session_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Session not found or expired")

    # NOTE(review): this reaches into SessionService's private _remove_session;
    # whether it also releases on-disk resources (as complete_session claims
    # to) is not visible here -- consider exposing a public removal method.
    session_service._remove_session(session_id)

    return {"message": "Session reset successfully"}