File size: 5,036 Bytes
eacbbc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
API router for the visual question answering (VQA) endpoints
"""
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, BackgroundTasks, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from app.services.session_service import SessionService
from app.services.model_service import ModelService

logger = logging.getLogger(__name__)

# Router for every VQA endpoint: mounted under /api/vqa and grouped under the
# "vqa" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/api/vqa", tags=["vqa"])

# Request/response models for the endpoints below
class QuestionRequest(BaseModel):
    """Request body for POST /api/vqa/ask: a question about a session's image."""
    # Session ID as returned by POST /api/vqa/upload
    session_id: str
    # Free-text question; passed unchanged to the model service's predict()
    question: str

class AnswerResponse(BaseModel):
    """Answer payload returned by POST /api/vqa/ask and nested in session history."""
    # Predicted answer text
    answer: str
    # Confidence score for `answer` (no range is enforced by this model)
    answer_confidence: float
    # Presumably whether the model judged the question answerable from the
    # image — confirm against ModelService.predict
    is_answerable: bool
    # Confidence score for `is_answerable` (no range is enforced by this model)
    answerable_confidence: float

class SessionHistoryItem(BaseModel):
    """One question/answer entry in a session's history."""
    # The question text as asked
    question: str
    # The answer produced for that question
    answer: AnswerResponse
    # Kept as a plain string; its format is produced by the session layer
    # (presumably ISO 8601 — TODO confirm)
    timestamp: str

class SessionResponse(BaseModel):
    """Response body for GET /api/vqa/session/{session_id}."""
    # ID of the requested session
    session_id: str
    # Validated from the session's stored questions; each stored entry must
    # match the SessionHistoryItem shape or response validation fails
    history: List[SessionHistoryItem]

# Single SessionService instance created at import time and shared by every
# endpoint in this module.
# NOTE(review): all requests in this process use this one instance; whether it
# is safe under concurrent requests depends on SessionService — confirm.
session_service: SessionService = SessionService()

@router.post("/upload", response_model=dict)
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None
):
    """
    Upload an image and create a new session

    The upload is accepted only when its client-declared content type starts
    with ``image/``. A missing content type is rejected the same way rather
    than crashing on ``None.startswith``.

    Args:
        request (Request): The incoming request (not used by this handler)
        file (UploadFile): The image file to upload
        background_tasks (BackgroundTasks): Not used by this handler; kept so
            the endpoint signature stays stable

    Returns:
        dict: ``{"session_id": <id>}`` for the newly created session

    Raises:
        HTTPException: 400 if the upload is not declared as an image; any
            HTTPException raised by the session service is propagated
            unchanged; 500 for any other error while creating the session
    """
    # content_type may be None when the multipart part has no Content-Type
    # header; normalise it so that case becomes a 400 instead of a 500.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        session_id = session_service.create_session(file)
    except HTTPException:
        # Already a deliberate HTTP error; don't rewrap it as a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Error uploading image: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"session_id": session_id}

@router.post("/ask", response_model=AnswerResponse)
async def ask_question(
    request: Request,
    question_request: QuestionRequest
):
    """
    Ask a question about the image uploaded for a session

    Runs the model on the session's image and question, and appends the
    question and result to the session's history. It then returns the result,
    which FastAPI validates against AnswerResponse.

    Args:
        request (Request): The incoming request; used to reach
            ``app.state.model_service``
        question_request (QuestionRequest): Session ID and question text

    Returns:
        AnswerResponse: The model's answer (whatever ``predict`` returns must
        be coercible to AnswerResponse)

    Raises:
        HTTPException: 503 if no model service is attached to the app state;
            404 if the session is not found or expired; any HTTPException
            raised during prediction is propagated unchanged; 500 for any
            other prediction/history error
    """
    # The model service lives on app.state (set up elsewhere, presumably at
    # startup — confirm). If it is absent, report that clearly instead of
    # failing with an unhandled AttributeError.
    model_service = getattr(request.app.state, "model_service", None)
    if model_service is None:
        raise HTTPException(status_code=503, detail="Model service is not available")

    session = session_service.get_session(question_request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found or expired")

    try:
        # NOTE(review): predict() is called synchronously inside an async
        # endpoint. If it performs heavy inference, it blocks the event loop
        # while it runs. Consider offloading it once the model service is
        # known to be thread-safe.
        result = model_service.predict(session.image_path, question_request.question)

        # Record the exchange only after a successful prediction.
        session.add_question(question_request.question, result)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error processing question: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return result

@router.get("/session/{session_id}", response_model=SessionResponse)
async def get_session(
    request: Request,
    session_id: str
):
    """
    Look up a session and return its ID together with its Q&A history

    Args:
        session_id (str): ID of the session to look up

    Returns:
        SessionResponse: ``session_id`` plus ``history`` (taken from the
        session's ``questions`` attribute)

    Raises:
        HTTPException: 404 when the session service returns no session for
            this ID (not found or expired)
    """
    current = session_service.get_session(session_id)
    if not current:
        raise HTTPException(status_code=404, detail="Session not found or expired")

    return dict(session_id=current.session_id, history=current.questions)

@router.post("/session/{session_id}/complete")
async def complete_session(
    request: Request,
    session_id: str
):
    """
    Finish a session and release its resources

    Delegates to ``SessionService.complete_session``. That call is intended
    to delete the session's image while keeping its data temporarily.

    Args:
        session_id (str): ID of the session to complete

    Returns:
        dict: A confirmation message

    Raises:
        HTTPException: 404 when the session is not found or expired; 500 when
            the service reports that completion failed
    """
    if not session_service.get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found or expired")

    # A falsy return from the service means cleanup did not succeed.
    if not session_service.complete_session(session_id):
        raise HTTPException(status_code=500, detail="Failed to complete session")

    return {"message": "Session completed successfully, resources cleaned up"}

@router.delete("/session/{session_id}")
async def reset_session(
    request: Request,
    session_id: str
):
    """
    Delete a session outright so the client can start over

    Args:
        session_id (str): ID of the session to delete

    Returns:
        dict: A confirmation message

    Raises:
        HTTPException: 404 when the session is not found or expired
    """
    if not session_service.get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found or expired")

    # NOTE(review): this calls a private SessionService method; a public
    # removal method on the service would be a cleaner contract.
    session_service._remove_session(session_id)
    return {"message": "Session reset successfully"}