dixisouls commited on
Commit
eacbbc9
·
1 Parent(s): cd7dd06

Initial Commit

Browse files
app/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Package initialization for app
3
+ """
app/config.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration settings for the application
3
+ """
4
+ import os
5
+ from pydantic_settings import BaseSettings
6
+ from dotenv import load_dotenv
7
+ from pathlib import Path
8
+
9
+ # Load .env file if it exists
10
+ load_dotenv()
11
+
12
+ class Settings(BaseSettings):
13
+ """Application settings"""
14
+ # App settings
15
+ APP_NAME: str = "VizWiz VQA API"
16
+ DEBUG: bool = os.getenv("DEBUG", "False").lower() == "true"
17
+
18
+ # Model settings
19
+ MODEL_PATH: str = os.getenv("MODEL_PATH", "./models/vqa_model_best.pt")
20
+ TEXT_MODEL: str = os.getenv("TEXT_MODEL", "bert-base-uncased")
21
+ VISION_MODEL: str = os.getenv("VISION_MODEL", "google/vit-base-patch16-384")
22
+ HUGGINGFACE_TOKEN: str = os.getenv("HUGGINGFACE_TOKEN", "")
23
+
24
+ # Hugging Face model repository settings
25
+ HF_MODEL_REPO: str = os.getenv("HF_MODEL_REPO", "dixisouls/VQA")
26
+ HF_MODEL_FILENAME: str = os.getenv("HF_MODEL_FILENAME", "model.pt")
27
+
28
+ # API settings
29
+ MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024 # 10MB
30
+
31
+ # Storage settings
32
+ UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "./uploads")
33
+ MAX_SESSION_AGE: int = 60 * 30 # 30 minutes
34
+
35
+ # CORS settings
36
+ ALLOW_ORIGINS: list[str] = ["*"]
37
+
38
+ class Config:
39
+ env_file = ".env"
40
+ case_sensitive = True
41
+
42
+ # Global settings instance
43
+ settings = Settings()
44
+
45
+ # Ensure upload directory exists
46
+ Path(settings.UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
47
+
48
+ # Ensure models directory exists
49
+ Path(os.path.dirname(settings.MODEL_PATH)).mkdir(parents=True, exist_ok=True)
app/main.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main FastAPI application entry point
3
+ """
4
+ import os
5
+ import logging
6
+ from fastapi import FastAPI, HTTPException
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.staticfiles import StaticFiles
9
+ from contextlib import asynccontextmanager
10
+
11
+ from app.routers import vqa
12
+ from app.services.model_service import ModelService
13
+
14
+ # Configure logging
15
+ logging.basicConfig(
16
+ level=logging.INFO,
17
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
18
+ )
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Initialize model service in a lifespan context manager
22
+ @asynccontextmanager
23
+ async def lifespan(app: FastAPI):
24
+ # Load model on startup
25
+ logger.info("Loading VQA model...")
26
+ app.state.model_service = ModelService()
27
+ app.state.model_service.load_model()
28
+ logger.info("VQA model loaded successfully")
29
+ yield
30
+ # Clean up resources on shutdown
31
+ logger.info("Shutting down...")
32
+
33
+ # Initialize FastAPI app
34
+ app = FastAPI(
35
+ title="VizWiz VQA API",
36
+ description="API for Visual Question Answering on images",
37
+ version="1.0.0",
38
+ lifespan=lifespan
39
+ )
40
+
41
+ # Add CORS middleware
42
+ app.add_middleware(
43
+ CORSMiddleware,
44
+ allow_origins=["*"], # Allow all origins in development
45
+ allow_credentials=True,
46
+ allow_methods=["*"],
47
+ allow_headers=["*"],
48
+ )
49
+
50
+ # Mount static files directory if it exists
51
+ static_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
52
+ if os.path.exists(static_dir):
53
+ app.mount("/static", StaticFiles(directory=static_dir), name="static")
54
+
55
+ # Include routers
56
+ app.include_router(vqa.router)
57
+
58
+ # Health check endpoint
59
+ @app.get("/health")
60
+ async def health_check():
61
+ """Health check endpoint for monitoring the service"""
62
+ if not hasattr(app.state, "model_service") or not app.state.model_service.is_model_loaded():
63
+ raise HTTPException(status_code=503, detail="Model not loaded")
64
+ return {"status": "healthy", "model_loaded": True}
65
+
66
+ if __name__ == "__main__":
67
+ import uvicorn
68
+ uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
app/models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Package initialization for app
3
+ """
app/models/vqa_model.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model implementation for VQA
3
+ """
4
+ import os
5
+ import json
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import AutoTokenizer, AutoModel, AutoConfig, ViTImageProcessor, ViTModel
9
+
10
+ class VQAModel(nn.Module):
11
+ """Vision-Language model for Visual Question Answering"""
12
+ def __init__(self, config, num_answers):
13
+ super(VQAModel, self).__init__()
14
+ self.config = config
15
+ self.num_answers = num_answers
16
+
17
+ # Vision encoder
18
+ self.vision_config = AutoConfig.from_pretrained(config['vision_model'])
19
+ self.vision_encoder = ViTModel.from_pretrained(config['vision_model'])
20
+
21
+ # Text encoder
22
+ self.text_config = AutoConfig.from_pretrained(config['text_model'])
23
+ self.text_encoder = AutoModel.from_pretrained(config['text_model'])
24
+
25
+ # Projection layers
26
+ self.vision_projection = nn.Linear(
27
+ self.vision_config.hidden_size, config['hidden_size']
28
+ )
29
+ self.text_projection = nn.Linear(
30
+ self.text_config.hidden_size, config['hidden_size']
31
+ )
32
+
33
+ # Multimodal fusion
34
+ self.fusion = nn.Sequential(
35
+ nn.Linear(2 * config['hidden_size'], config['hidden_size']),
36
+ nn.LayerNorm(config['hidden_size']),
37
+ nn.GELU(),
38
+ nn.Dropout(config['dropout'])
39
+ )
40
+
41
+ # Answer prediction
42
+ self.classifier = nn.Sequential(
43
+ nn.Linear(config['hidden_size'], config['hidden_size']),
44
+ nn.LayerNorm(config['hidden_size']),
45
+ nn.GELU(),
46
+ nn.Dropout(config['dropout']),
47
+ nn.Linear(config['hidden_size'], num_answers)
48
+ )
49
+
50
+ # Answerable prediction
51
+ self.answerable_classifier = nn.Sequential(
52
+ nn.Linear(config['hidden_size'], config['hidden_size'] // 2),
53
+ nn.LayerNorm(config['hidden_size'] // 2),
54
+ nn.GELU(),
55
+ nn.Dropout(config['dropout']),
56
+ nn.Linear(config['hidden_size'] // 2, 2) # Binary classification
57
+ )
58
+
59
+ def forward(self, image_encodings, question_encodings):
60
+ """Forward pass of the model"""
61
+ # Process image
62
+ vision_outputs = self.vision_encoder(**image_encodings)
63
+ vision_embeds = vision_outputs.last_hidden_state[:, 0] # CLS token
64
+ vision_embeds = self.vision_projection(vision_embeds)
65
+
66
+ # Process text
67
+ text_outputs = self.text_encoder(**question_encodings)
68
+ text_embeds = text_outputs.last_hidden_state[:, 0] # CLS token
69
+ text_embeds = self.text_projection(text_embeds)
70
+
71
+ # Combine modalities
72
+ multimodal_features = torch.cat([vision_embeds, text_embeds], dim=1)
73
+ fused_features = self.fusion(multimodal_features)
74
+
75
+ # Predict answers and answerable
76
+ answer_logits = self.classifier(fused_features)
77
+ answerable_logits = self.answerable_classifier(fused_features)
78
+
79
+ return {
80
+ 'answer_logits': answer_logits,
81
+ 'answerable_logits': answerable_logits,
82
+ 'fused_features': fused_features
83
+ }
app/routers/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Package initialization for app
3
+ """
app/routers/vqa.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API router for VQA endpoints
3
+ """
4
+ import logging
5
+ from typing import List, Optional
6
+ from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, BackgroundTasks, Request
7
+ from fastapi.responses import JSONResponse
8
+ from pydantic import BaseModel
9
+
10
+ from app.services.session_service import SessionService
11
+ from app.services.model_service import ModelService
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Initialize router
16
+ router = APIRouter(
17
+ prefix="/api/vqa",
18
+ tags=["vqa"],
19
+ )
20
+
21
+ # Models for request/response
22
+ class QuestionRequest(BaseModel):
23
+ """Model for question request"""
24
+ session_id: str
25
+ question: str
26
+
27
+ class AnswerResponse(BaseModel):
28
+ """Model for answer response"""
29
+ answer: str
30
+ answer_confidence: float
31
+ is_answerable: bool
32
+ answerable_confidence: float
33
+
34
+ class SessionHistoryItem(BaseModel):
35
+ """Model for session history item"""
36
+ question: str
37
+ answer: AnswerResponse
38
+ timestamp: str
39
+
40
+ class SessionResponse(BaseModel):
41
+ """Model for session response"""
42
+ session_id: str
43
+ history: List[SessionHistoryItem]
44
+
45
+ # Dependency for services
46
+ session_service = SessionService()
47
+
48
+ @router.post("/upload", response_model=dict)
49
+ async def upload_image(
50
+ request: Request,
51
+ file: UploadFile = File(...),
52
+ background_tasks: BackgroundTasks = None
53
+ ):
54
+ """
55
+ Upload an image and create a new session
56
+
57
+ Args:
58
+ file (UploadFile): The image file to upload
59
+
60
+ Returns:
61
+ dict: The session ID
62
+ """
63
+ # Validate image file
64
+ if not file.content_type.startswith("image/"):
65
+ raise HTTPException(status_code=400, detail="File must be an image")
66
+
67
+ try:
68
+ # Create a new session
69
+ session_id = session_service.create_session(file)
70
+
71
+ return {"session_id": session_id}
72
+
73
+ except Exception as e:
74
+ logger.error(f"Error uploading image: {e}")
75
+ raise HTTPException(status_code=500, detail=str(e))
76
+
77
+ @router.post("/ask", response_model=AnswerResponse)
78
+ async def ask_question(
79
+ request: Request,
80
+ question_request: QuestionRequest
81
+ ):
82
+ """
83
+ Ask a question about the uploaded image
84
+
85
+ Args:
86
+ question_request (QuestionRequest): The question request
87
+
88
+ Returns:
89
+ AnswerResponse: The answer
90
+ """
91
+ # Get the model service from app state
92
+ model_service = request.app.state.model_service
93
+
94
+ # Get the session
95
+ session = session_service.get_session(question_request.session_id)
96
+ if not session:
97
+ raise HTTPException(status_code=404, detail="Session not found or expired")
98
+
99
+ try:
100
+ # Make prediction
101
+ result = model_service.predict(session.image_path, question_request.question)
102
+
103
+ # Add to session history
104
+ session.add_question(question_request.question, result)
105
+
106
+ return result
107
+
108
+ except Exception as e:
109
+ logger.error(f"Error processing question: {e}")
110
+ raise HTTPException(status_code=500, detail=str(e))
111
+
112
+ @router.get("/session/{session_id}", response_model=SessionResponse)
113
+ async def get_session(
114
+ request: Request,
115
+ session_id: str
116
+ ):
117
+ """
118
+ Get session information including question history
119
+
120
+ Args:
121
+ session_id (str): The session ID
122
+
123
+ Returns:
124
+ SessionResponse: The session information
125
+ """
126
+ # Get the session
127
+ session = session_service.get_session(session_id)
128
+ if not session:
129
+ raise HTTPException(status_code=404, detail="Session not found or expired")
130
+
131
+ return {
132
+ "session_id": session.session_id,
133
+ "history": session.questions
134
+ }
135
+
136
+ @router.post("/session/{session_id}/complete")
137
+ async def complete_session(
138
+ request: Request,
139
+ session_id: str
140
+ ):
141
+ """
142
+ Mark a session as complete and clean up resources
143
+
144
+ Args:
145
+ session_id (str): The session ID
146
+
147
+ Returns:
148
+ dict: Success message
149
+ """
150
+ # Check if session exists
151
+ session = session_service.get_session(session_id)
152
+ if not session:
153
+ raise HTTPException(status_code=404, detail="Session not found or expired")
154
+
155
+ # Complete the session (delete image but keep session data temporarily)
156
+ success = session_service.complete_session(session_id)
157
+
158
+ if success:
159
+ return {"message": "Session completed successfully, resources cleaned up"}
160
+ else:
161
+ raise HTTPException(status_code=500, detail="Failed to complete session")
162
+
163
+ @router.delete("/session/{session_id}")
164
+ async def reset_session(
165
+ request: Request,
166
+ session_id: str
167
+ ):
168
+ """
169
+ Reset (delete) a session to start fresh
170
+
171
+ Args:
172
+ session_id (str): The session ID
173
+
174
+ Returns:
175
+ dict: Success message
176
+ """
177
+ # Check if session exists
178
+ session = session_service.get_session(session_id)
179
+ if not session:
180
+ raise HTTPException(status_code=404, detail="Session not found or expired")
181
+
182
+ # Remove the session
183
+ session_service._remove_session(session_id)
184
+
185
+ return {"message": "Session reset successfully"}
app/services/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Package initialization for app
3
+ """
app/services/model_service.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model service for handling VQA model operations
3
+ """
4
+ import os
5
+ import json
6
+ import logging
7
+ import torch
8
+ from PIL import Image
9
+ from transformers import AutoTokenizer, ViTImageProcessor
10
+ from huggingface_hub import hf_hub_download, login
11
+
12
+ from app.config import settings
13
+ from app.models.vqa_model import VQAModel
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class ModelService:
18
+ """Service for loading and running the VQA model"""
19
+
20
+ def __init__(self):
21
+ """Initialize the model service"""
22
+ self.model = None
23
+ self.processor = None
24
+ self.tokenizer = None
25
+ self.config = None
26
+ self.answer_vocab = None
27
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ logger.info(f"Using device: {self.device}")
29
+
30
+ # Try to login to Hugging Face if token is provided
31
+ if settings.HUGGINGFACE_TOKEN:
32
+ try:
33
+ login(token=settings.HUGGINGFACE_TOKEN)
34
+ logger.info("Successfully logged in to Hugging Face Hub")
35
+ except Exception as e:
36
+ logger.error(f"Error logging in to Hugging Face Hub: {e}")
37
+
38
+ def _check_model_exists(self):
39
+ """Check if the model file exists locally"""
40
+ return os.path.exists(settings.MODEL_PATH)
41
+
42
+ def _download_model_from_hub(self):
43
+ """Download the model from Hugging Face Hub if not present locally"""
44
+ try:
45
+ # Create the directory if it doesn't exist
46
+ os.makedirs(os.path.dirname(settings.MODEL_PATH), exist_ok=True)
47
+
48
+ logger.info(f"Downloading model from {settings.HF_MODEL_REPO} to {settings.MODEL_PATH}")
49
+
50
+ # Download the model file from Hugging Face
51
+ hf_hub_download(
52
+ repo_id=settings.HF_MODEL_REPO,
53
+ filename=settings.HF_MODEL_FILENAME,
54
+ local_dir=os.path.dirname(settings.MODEL_PATH),
55
+ local_dir_use_symlinks=False
56
+ )
57
+
58
+ # Rename the downloaded file to match the expected path if needed
59
+ downloaded_path = os.path.join(os.path.dirname(settings.MODEL_PATH), settings.HF_MODEL_FILENAME)
60
+ if downloaded_path != settings.MODEL_PATH:
61
+ os.rename(downloaded_path, settings.MODEL_PATH)
62
+
63
+ logger.info(f"Model downloaded successfully to {settings.MODEL_PATH}")
64
+ return True
65
+ except Exception as e:
66
+ logger.error(f"Error downloading model from Hugging Face Hub: {e}")
67
+ return False
68
+
69
+ def load_model(self):
70
+ """Load the VQA model from the specified path or download it if not present"""
71
+ try:
72
+ # Check if model exists locally
73
+ if not self._check_model_exists():
74
+ logger.info(f"Model not found at {settings.MODEL_PATH}")
75
+
76
+ # Download the model from Hugging Face Hub
77
+ if not self._download_model_from_hub():
78
+ logger.error("Failed to download model from Hugging Face Hub")
79
+ return False
80
+
81
+ logger.info(f"Loading model from {settings.MODEL_PATH}")
82
+ checkpoint = torch.load(settings.MODEL_PATH, map_location=self.device)
83
+
84
+ # Extract configuration
85
+ self.config = checkpoint['config']
86
+
87
+ # Get vocabulary
88
+ if 'answer_vocab' in checkpoint:
89
+ self.answer_vocab = checkpoint['answer_vocab']
90
+ logger.info("Using vocabulary from model checkpoint")
91
+ else:
92
+ logger.error("Error: No vocabulary found in model checkpoint")
93
+ raise ValueError("No vocabulary found in model checkpoint")
94
+
95
+ # Initialize model
96
+ self.model = VQAModel(self.config, len(self.answer_vocab['answer_to_idx']))
97
+ self.model.load_state_dict(checkpoint['model_state_dict'])
98
+ self.model.to(self.device)
99
+ self.model.eval()
100
+
101
+ # Initialize preprocessors
102
+ self.processor = ViTImageProcessor.from_pretrained(self.config['vision_model'])
103
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config['text_model'])
104
+
105
+ logger.info("Model loaded successfully")
106
+ return True
107
+
108
+ except Exception as e:
109
+ logger.error(f"Error loading model: {e}")
110
+ return False
111
+
112
+ def is_model_loaded(self):
113
+ """Check if the model is loaded"""
114
+ return self.model is not None and self.processor is not None and self.tokenizer is not None
115
+
116
+ def predict(self, image_path, question):
117
+ """
118
+ Make a prediction for the given image and question
119
+
120
+ Args:
121
+ image_path (str): Path to the image file
122
+ question (str): Question about the image
123
+
124
+ Returns:
125
+ dict: Prediction results
126
+ """
127
+ if not self.is_model_loaded():
128
+ logger.error("Model not loaded")
129
+ raise RuntimeError("Model not loaded")
130
+
131
+ try:
132
+ # Preprocess image
133
+ image = Image.open(image_path).convert('RGB')
134
+ image_encoding = self.processor(images=image, return_tensors="pt")
135
+ image_encoding = {k: v.to(self.device) for k, v in image_encoding.items()}
136
+
137
+ # Preprocess question
138
+ question_encoding = self.tokenizer(
139
+ question,
140
+ padding='max_length',
141
+ truncation=True,
142
+ max_length=128,
143
+ return_tensors='pt'
144
+ )
145
+ question_encoding = {k: v.to(self.device) for k, v in question_encoding.items()}
146
+
147
+ # Get predictions
148
+ with torch.no_grad():
149
+ outputs = self.model(image_encoding, question_encoding)
150
+
151
+ answer_logits = outputs['answer_logits']
152
+ answerable_logits = outputs['answerable_logits']
153
+
154
+ answer_idx = torch.argmax(answer_logits, dim=1).item()
155
+ answerable_idx = torch.argmax(answerable_logits, dim=1).item()
156
+
157
+ # Convert string index to int for dictionary lookup
158
+ answer = self.answer_vocab['idx_to_answer'][str(answer_idx)]
159
+ is_answerable = bool(answerable_idx)
160
+
161
+ # Get confidence scores
162
+ answer_probs = torch.softmax(answer_logits, dim=1)[0]
163
+ answerable_probs = torch.softmax(answerable_logits, dim=1)[0]
164
+
165
+ answer_confidence = float(answer_probs[answer_idx].item())
166
+ answerable_confidence = float(answerable_probs[answerable_idx].item())
167
+
168
+ return {
169
+ 'answer': answer,
170
+ 'answer_confidence': answer_confidence,
171
+ 'is_answerable': is_answerable,
172
+ 'answerable_confidence': answerable_confidence
173
+ }
174
+
175
+ except Exception as e:
176
+ logger.error(f"Error during prediction: {e}")
177
+ raise
app/services/session_service.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import logging
4
+ import time
5
+ from datetime import datetime, timedelta
6
+ from typing import Dict, Optional, Tuple, List
7
+ from fastapi import UploadFile
8
+ from pathlib import Path
9
+
10
+ from app.config import settings
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class Session:
15
+ """Object representing a user session"""
16
+ def __init__(self, session_id: str, image_path: str):
17
+ self.session_id = session_id
18
+ self.image_path = image_path
19
+ self.created_at = datetime.now()
20
+ self.last_accessed = datetime.now()
21
+ self.questions = [] # History of questions for this session
22
+
23
+ def is_expired(self) -> bool:
24
+ """Check if the session has expired"""
25
+ expiry_time = self.last_accessed + timedelta(seconds=settings.MAX_SESSION_AGE)
26
+ return datetime.now() > expiry_time
27
+
28
+ def update_access_time(self):
29
+ """Update the last accessed time"""
30
+ self.last_accessed = datetime.now()
31
+
32
+ def add_question(self, question: str, answer: Dict):
33
+ """Add a question and its answer to the session history"""
34
+ self.questions.append({
35
+ "question": question,
36
+ "answer": answer,
37
+ "timestamp": datetime.now().isoformat()
38
+ })
39
+ self.update_access_time()
40
+
41
+ class SessionService:
42
+ """Service for managing user sessions"""
43
+
44
+ def __init__(self):
45
+ """Initialize the session service"""
46
+ self.sessions: Dict[str, Session] = {}
47
+ self.ensure_upload_dir()
48
+
49
+ # Start a background cleanup task
50
+ self._cleanup_sessions()
51
+
52
+ def ensure_upload_dir(self):
53
+ """Ensure the upload directory exists"""
54
+ os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
55
+
56
+ def create_session(self, file: UploadFile) -> str:
57
+ """
58
+ Create a new session for the user
59
+
60
+ Args:
61
+ file (UploadFile): The uploaded image file
62
+
63
+ Returns:
64
+ str: The session ID
65
+ """
66
+ # Generate a unique session ID
67
+ session_id = str(uuid.uuid4())
68
+
69
+ # Create a unique filename
70
+ timestamp = int(time.time())
71
+ file_extension = Path(file.filename).suffix
72
+ filename = f"{timestamp}_{session_id}{file_extension}"
73
+
74
+ # Save the uploaded file
75
+ file_path = os.path.join(settings.UPLOAD_DIR, filename)
76
+ with open(file_path, "wb") as f:
77
+ f.write(file.file.read())
78
+
79
+ # Create and store the session
80
+ self.sessions[session_id] = Session(session_id, file_path)
81
+
82
+ logger.info(f"Created new session {session_id} with image {file_path}")
83
+ return session_id
84
+
85
+ def get_session(self, session_id: str) -> Optional[Session]:
86
+ """
87
+ Get a session by ID
88
+
89
+ Args:
90
+ session_id (str): The session ID
91
+
92
+ Returns:
93
+ Optional[Session]: The session, or None if not found or expired
94
+ """
95
+ session = self.sessions.get(session_id)
96
+
97
+ if session is None:
98
+ return None
99
+
100
+ if session.is_expired():
101
+ self._remove_session(session_id)
102
+ return None
103
+
104
+ session.update_access_time()
105
+ return session
106
+
107
+ def complete_session(self, session_id: str) -> bool:
108
+ """
109
+ Mark a session as complete and remove its resources
110
+
111
+ Args:
112
+ session_id (str): The session ID
113
+
114
+ Returns:
115
+ bool: True if successful, False otherwise
116
+ """
117
+ session = self.sessions.get(session_id)
118
+ if not session:
119
+ logger.warning(f"Cannot complete nonexistent session: {session_id}")
120
+ return False
121
+
122
+ logger.info(f"Completing session {session_id}")
123
+
124
+ try:
125
+ # Remove the image file but keep session data temporarily for any final operations
126
+ if session.image_path and os.path.exists(session.image_path):
127
+ os.remove(session.image_path)
128
+ logger.info(f"Removed image file for completed session {session.image_path}")
129
+
130
+ # Set the image path to None to indicate it's been removed
131
+ session.image_path = None
132
+ return True
133
+ return True # No image to remove or already removed
134
+ except Exception as e:
135
+ logger.error(f"Error removing image file during session completion: {e}")
136
+ return False
137
+
138
+ def _remove_session(self, session_id: str):
139
+ """
140
+ Remove a session and its associated file
141
+
142
+ Args:
143
+ session_id (str): The session ID
144
+ """
145
+ session = self.sessions.pop(session_id, None)
146
+ if session:
147
+ try:
148
+ # Remove the image file
149
+ if session.image_path and os.path.exists(session.image_path):
150
+ os.remove(session.image_path)
151
+ logger.info(f"Removed session file {session.image_path}")
152
+ except Exception as e:
153
+ logger.error(f"Error removing session file: {e}")
154
+
155
+ def _cleanup_sessions(self):
156
+ """Clean up expired sessions"""
157
+ expired_sessions = [
158
+ session_id for session_id, session in self.sessions.items()
159
+ if session.is_expired()
160
+ ]
161
+
162
+ for session_id in expired_sessions:
163
+ self._remove_session(session_id)
164
+
165
+ if expired_sessions:
166
+ logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")
app/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Package initialization for app
3
+ """
app/utils/image_utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for image processing
3
+ """
4
+ import os
5
+ import logging
6
+ from PIL import Image
7
+ import io
8
+ import base64
9
+ from typing import Tuple, Optional
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ def validate_image(image_path: str) -> bool:
14
+ """
15
+ Validate if a file is a valid image
16
+
17
+ Args:
18
+ image_path (str): Path to the image file
19
+
20
+ Returns:
21
+ bool: True if valid, False otherwise
22
+ """
23
+ try:
24
+ with Image.open(image_path) as img:
25
+ img.verify()
26
+ return True
27
+ except Exception as e:
28
+ logger.error(f"Image validation failed: {e}")
29
+ return False
30
+
31
+ def resize_image(image_path: str, max_size: Tuple[int, int] = (1024, 1024)) -> Optional[str]:
32
+ """
33
+ Resize an image if it's larger than max_size
34
+
35
+ Args:
36
+ image_path (str): Path to the image file
37
+ max_size (Tuple[int, int]): Maximum width and height
38
+
39
+ Returns:
40
+ Optional[str]: Path to the resized image or None if failed
41
+ """
42
+ try:
43
+ with Image.open(image_path) as img:
44
+ # Only resize if the image is larger than max_size
45
+ if img.width > max_size[0] or img.height > max_size[1]:
46
+ # Calculate new size while maintaining aspect ratio
47
+ ratio = min(max_size[0] / img.width, max_size[1] / img.height)
48
+ new_size = (int(img.width * ratio), int(img.height * ratio))
49
+
50
+ # Resize the image
51
+ resized_img = img.resize(new_size, Image.LANCZOS)
52
+
53
+ # Save the resized image
54
+ resized_path = os.path.splitext(image_path)[0] + "_resized" + os.path.splitext(image_path)[1]
55
+ resized_img.save(resized_path)
56
+ return resized_path
57
+
58
+ # No need to resize
59
+ return image_path
60
+
61
+ except Exception as e:
62
+ logger.error(f"Image resizing failed: {e}")
63
+ return None
64
+
65
+ def image_to_base64(image_path: str) -> Optional[str]:
66
+ """
67
+ Convert an image to base64 string
68
+
69
+ Args:
70
+ image_path (str): Path to the image file
71
+
72
+ Returns:
73
+ Optional[str]: Base64 encoded image string or None if failed
74
+ """
75
+ try:
76
+ with open(image_path, "rb") as image_file:
77
+ encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
78
+ return encoded_string
79
+ except Exception as e:
80
+ logger.error(f"Base64 conversion failed: {e}")
81
+ return None