Spaces:
Running
Running
# app/routers/chat.py | |
import logging | |
import os | |
import json | |
import tempfile | |
from datetime import datetime | |
from bson import ObjectId | |
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Header | |
from fastapi.responses import JSONResponse, FileResponse | |
from pydantic import BaseModel | |
from app.database.database_query import DatabaseQuery | |
from app.middleware.auth import get_current_user, get_optional_user | |
from app.services import ChatProcessor | |
from app.services.image_processor import ImageProcessor | |
from app.services.report_process import Report | |
from app.services.skincare_scheduler import SkinCareScheduler | |
from app.services.wheel import EnvironmentalConditions | |
from app.services.RAG_evaluation import RAGEvaluation | |
# Router object for this module's chat endpoints.
# NOTE(review): none of the handlers below carry a visible route decorator
# (e.g. @router.post(...)); confirm they are registered elsewhere or that the
# decorators were lost.
router = APIRouter()
# Module-level DatabaseQuery instance used by the chat-session handlers below.
# NOTE(review): web_search (local variable), analyze_report and disease_search
# (parameters) each define their own `query`, which shadows this object inside
# those functions.
query = DatabaseQuery()
class ChatSessionTitleUpdate(BaseModel):
    """Request body for renaming a chat session.

    Read by ``update_chat_title``, which stores ``title`` as the session's
    new title.
    """

    # New title to assign to the chat session.
    title: str
async def serve_image(filename: str):
    """Serve a previously uploaded image from the local ``uploads`` directory.

    Args:
        filename: Name of the requested file relative to ``uploads``. It comes
            from the client and is therefore untrusted.

    Returns:
        A ``FileResponse`` for the resolved file.

    Raises:
        HTTPException: 404 when the name resolves outside ``uploads``, does not
            exist, or is not a regular file.
    """
    # Resolve relative to the process working directory, then canonicalise so
    # ".." segments and symlinks cannot hide the file's real location.
    upload_dir = os.path.realpath('uploads')
    file_path = os.path.realpath(os.path.join(upload_dir, filename))
    logging.debug("Attempting to serve file from: %s", file_path)

    try:
        inside_uploads = os.path.commonpath([upload_dir, file_path]) == upload_dir
    except ValueError:
        # commonpath raises for paths on different drives (Windows); such a
        # path is certainly not inside the uploads directory.
        inside_uploads = False

    # Traversal attempts get the same 404 as a missing file so they reveal
    # nothing about the filesystem; isfile() also rejects directories.
    if not inside_uploads or not os.path.isfile(file_path):
        logging.debug("File not found: %s", file_path)
        raise HTTPException(status_code=404, detail="Image not found")

    return FileResponse(file_path)
async def create_chat_session(username: str = Depends(get_current_user)):
    """Create a new, empty chat session owned by the authenticated user.

    Args:
        username: Identity resolved by ``get_current_user``; stored as the
            session's ``user_id``.

    Returns:
        A dict with a confirmation message and the new ``session_id``.

    Raises:
        HTTPException: 500 if building or persisting the session fails.
    """
    try:
        session_id = str(ObjectId())
        # Read the clock once so created_at and last_accessed are identical for
        # a brand-new session instead of differing by the gap between two calls.
        # Naive UTC (utcnow) is kept to match the other timestamps in this module.
        now = datetime.utcnow()
        chat_session = {
            "user_id": username,
            "session_id": session_id,
            "created_at": now,
            "last_accessed": now,
            "title": "New Chat"
        }
        query.create_chat_session(chat_session)
        return {"message": "Chat session created", "session_id": session_id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def get_user_chat_sessions(username: str = Depends(get_current_user)):
    """Return whatever ``query.get_user_chat_sessions`` reports for the caller.

    Any exception from the database layer is surfaced as an HTTP 500 whose
    detail is the exception's text.
    """
    try:
        return query.get_user_chat_sessions(username)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
async def delete_chat_session(session_id: str, username: str = Depends(get_current_user)):
    """Delete one of the caller's chat sessions together with its chats.

    Args:
        session_id: Session to delete.
        username: Identity resolved by ``get_current_user``; passed to the
            database layer alongside ``session_id``.

    Returns:
        A dict with a confirmation message and the ``chats_deleted`` value
        reported by ``query.delete_chat_session``.

    Raises:
        HTTPException: 404 when the database layer reports the session was not
            deleted; 500 for any other failure.
    """
    try:
        result = query.delete_chat_session(session_id, username)
        if result["session_deleted"]:
            return {
                "message": "Chat session and associated chats deleted successfully",
                "chats_deleted": result["chats_deleted"]
            }
        raise HTTPException(status_code=404, detail="Chat session not found or unauthorized")
    except HTTPException:
        # Let the deliberate 404 above propagate unchanged; without this clause
        # the generic handler below would rewrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def update_chat_title(
    session_id: str,
    title_data: ChatSessionTitleUpdate,
    username: str = Depends(get_current_user)
):
    """Rename one of the caller's chat sessions.

    Args:
        session_id: Session to rename.
        title_data: Request body carrying the new ``title``.
        username: Identity resolved by ``get_current_user``; checked together
            with ``session_id`` via ``query.verify_session``.

    Returns:
        A dict with a confirmation message, the ``session_id`` and ``new_title``.

    Raises:
        HTTPException: 404 if ``verify_session`` rejects the session/user pair;
            500 if the title update reports failure or anything else raises.
    """
    try:
        new_title = title_data.title
        if not query.verify_session(session_id, username):
            raise HTTPException(status_code=404, detail="Chat session not found or unauthorized")
        if query.update_chat_session_title(session_id, new_title):
            return {
                'message': 'Chat session title updated successfully',
                'session_id': session_id,
                'new_title': new_title
            }
        raise HTTPException(status_code=500, detail="Failed to update chat session title")
    except HTTPException:
        # Propagate the deliberate 404/500 above as-is instead of letting the
        # generic handler rewrap them as a 500 with a stringified detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def delete_all_sessions_and_chats(username: str = Depends(get_current_user)):
    """Remove every chat session and chat belonging to the caller.

    The counts reported by the database layer are echoed back under
    ``deleted_chats`` and ``deleted_sessions``. Any failure, including a
    missing count key, becomes an HTTP 500.
    """
    try:
        outcome = query.delete_all_user_sessions_and_chats(username)
        chats_removed = outcome["deleted_chats"]
        sessions_removed = outcome["deleted_sessions"]
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    return {
        "message": "Successfully deleted all chat sessions and chats",
        "deleted_chats": chats_removed,
        "deleted_sessions": sessions_removed
    }
async def get_session_chats(session_id: str, username: str = Depends(get_current_user)):
    """Return the chats stored for ``session_id`` as reported by the database layer.

    ``username`` is forwarded to ``query.get_session_chats``; presumably that
    call scopes the lookup to the caller (verify in DatabaseQuery). Any
    exception becomes an HTTP 500 with the exception text as detail.
    """
    try:
        return query.get_session_chats(session_id, username)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
async def export_chat(session_id: str, username: str = Depends(get_current_user)):
    """Export all chats of one of the caller's sessions as a plain dict.

    Args:
        session_id: Session to export.
        username: Identity resolved by ``get_current_user``; checked together
            with ``session_id`` via ``query.verify_session``.

    Returns:
        ``{'session_id', 'export_date', 'chats'}``, where each chat carries
        ``query``, ``response``, ``references``, ``page_no``, ``date`` and
        ``chat_id``. Missing fields default to ``''`` (``[]`` for references).

    Raises:
        HTTPException: 404 if ``verify_session`` rejects the session/user pair;
            500 for any other failure.
    """
    try:
        if not query.verify_session(session_id, username):
            raise HTTPException(status_code=404, detail="Chat session not found or unauthorized")
        chats = query.get_session_chats(session_id, username)
        formatted_chats = [
            {
                'query': chat.get('query', ''),
                'response': chat.get('response', ''),
                'references': chat.get('references', []),
                'page_no': chat.get('page_no', ''),
                # The stored 'timestamp' is exported under the key 'date' here
                # (export_all_chats keeps the name 'timestamp').
                'date': chat.get('timestamp', ''),
                'chat_id': chat.get('chat_id', '')
            }
            for chat in chats
        ]
        return {
            'session_id': session_id,
            'export_date': datetime.utcnow().isoformat(),
            'chats': formatted_chats
        }
    except HTTPException:
        # Keep the deliberate 404 above; the generic handler below would
        # otherwise turn it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def export_all_chats(username: str = Depends(get_current_user)):
    """Export every session and chat belonging to the caller as a plain dict.

    Each session from ``query.get_all_user_chats`` must provide ``chats``,
    ``session_id``, ``title``, ``created_at`` and ``last_accessed``. A missing
    key surfaces as an HTTP 500, as does any other failure.
    """
    def _format_chat(chat):
        # Flatten one stored chat into the export shape, defaulting absent fields.
        return {
            'query': chat.get('query', ''),
            'response': chat.get('response', ''),
            'references': chat.get('references', []),
            'page_no': chat.get('page_no', ''),
            'timestamp': chat.get('timestamp', ''),
            'chat_id': chat.get('chat_id', '')
        }

    try:
        sessions_out = []
        for session in query.get_all_user_chats(username):
            chats_out = [_format_chat(chat) for chat in session['chats']]
            sessions_out.append({
                'session_id': session['session_id'],
                'title': session['title'],
                'created_at': session['created_at'],
                'last_accessed': session['last_accessed'],
                'chats': chats_out
            })
        return {
            'user': username,
            'export_date': datetime.utcnow().isoformat(),
            'sessions': sessions_out
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
async def web_search(
    data: dict,
    authorization: str = Header(None),
    username: str = Depends(get_current_user)
):
    """Run a web search through ``ChatProcessor`` for a chat session.

    ``data`` must contain ``session_id`` and ``query``. ``num_results`` and
    ``num_images`` are optional and both default to 3. Missing required
    fields produce a 400 JSON error; any other failure becomes an HTTP 500.
    """
    try:
        # Token is the second space-separated part of the header,
        # presumably "Bearer <token>".
        bearer_token = authorization.split(" ")[1]
        session_id = data.get("session_id")
        # Named search_query to avoid shadowing the module-level `query` object.
        search_query = data.get("query")
        result_count = data.get("num_results", 3)
        image_count = data.get("num_images", 3)

        if not (session_id and search_query):
            return JSONResponse(
                status_code=400,
                content={"error": "session_id and query are required"}
            )

        processor = ChatProcessor(
            token=bearer_token,
            session_id=session_id,
            num_results=result_count,
            num_images=image_count,
        )
        return {"response": processor.web_search(query=search_query)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
async def analyze_report(
    file: UploadFile = File(...),
    query: str = Form(...),
    session_id: str = Form(...),
    authorization: str = Header(None),
    username: str = Depends(get_current_user)
):
    """Analyze an uploaded report file against a user query.

    The upload is written into a private temporary directory and handed to
    ``Report.process_chat`` together with ``query`` and a file type derived
    from the extension. The directory is removed afterwards.

    Args:
        file: Uploaded report; its extension selects the ``file_type``.
        query: Question to ask about the report; must not be blank.
        session_id: Chat session the analysis belongs to.
        authorization: ``Authorization`` header; the token is its second
            space-separated part (presumably "Bearer <token>").
        username: Identity resolved by ``get_current_user`` (not used directly).

    Returns:
        On success, a dict with ``status``, ``message`` and ``analysis``.
        A 400 ``JSONResponse`` for an empty filename or blank query.
        A 200 ``JSONResponse`` with an explanatory message for an
        unsupported extension.

    Raises:
        HTTPException: 500 with a structured detail for any other failure.
    """
    try:
        token = authorization.split(" ")[1]
        if not file.filename:
            return JSONResponse(
                status_code=400,
                content={"status": "error", "error": "Empty file provided"}
            )
        if not query.strip():
            return JSONResponse(
                status_code=400,
                content={"status": "error", "error": "Query is required"}
            )
        file_extension = file.filename.rsplit('.', 1)[1].lower() if '.' in file.filename else ''
        # Maps accepted extensions to the file_type understood by Report.
        allowed_extensions = {
            'pdf': 'pdf',
            'xlsx': 'excel',
            'xls': 'excel',
            'csv': 'csv',
            'jpg': 'image',
            'jpeg': 'image',
            'png': 'image',
            'doc': 'word',
            'docx': 'word',
            'ppt': 'ppt',
            'txt': 'text',
            'html': 'html'
        }
        if file_extension not in allowed_extensions:
            message = f"Unsupported file type. Allowed types: {', '.join(allowed_extensions.keys())}"
            # No analysis exists for an unsupported file, so `analysis` repeats
            # the message; clients that render that field still show why.
            # NOTE(review): confirm clients expect 200 here rather than a 400
            # like the validation checks above.
            return JSONResponse(
                status_code=200,
                content={
                    "status": "success",
                    "message": message,
                    "analysis": message
                }
            )

        # Keep only the final path component of the client-supplied name, so a
        # value like "../../evil.pdf" cannot escape the temporary directory.
        safe_name = os.path.basename(file.filename)
        content = await file.read()
        # TemporaryDirectory deletes the directory and everything in it on exit,
        # even if processing raises or leaves extra files behind.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, safe_name)
            with open(temp_file_path, "wb") as f:
                f.write(content)
            processor = Report(token=token, session_id=session_id)
            result = processor.process_chat(
                query=query,
                report_file=temp_file_path,
                file_type=allowed_extensions[file_extension]
            )
        return {
            "status": "success",
            "message": "Report analyzed successfully",
            "analysis": result
        }
    except Exception as e:
        # logging.exception also records the traceback for diagnosis.
        logging.exception("Error in analyze_report: %s", e)
        raise HTTPException(
            status_code=500,
            detail={
                "status": "error",
                "error": "Internal server error",
                "details": str(e)
            }
        )
async def get_skin_care_schedule(
    authorization: str = Header(None),
    username: str = Depends(get_current_user)
):
    """Build a skin-care schedule and return it decoded from JSON.

    ``SkinCareScheduler.createTable`` returns a JSON document, which is
    parsed with ``json.loads`` before being returned. Any failure is logged
    and reported as a generic HTTP 500.
    """
    try:
        bearer_token = authorization.split(" ")[1]
        # NOTE(review): the literal string "session_id" is passed as the session
        # identifier rather than a real session id; confirm this is intended.
        schedule_json = SkinCareScheduler(bearer_token, "session_id").createTable()
        return json.loads(schedule_json)
    except Exception as exc:
        logging.error(f"Error generating skin care schedule: {str(exc)}")
        raise HTTPException(
            status_code=500,
            detail={"error": "Failed to generate skin care schedule"}
        )
async def get_skin_care_wheel(
    authorization: str = Header(...),
    username: str = Depends(get_current_user)
):
    """Return the environmental-condition data from ``EnvironmentalConditions``.

    The result of ``get_conditon()`` (sic; the method name is part of that
    class's API) is returned as-is. Any failure is logged and reported as a
    generic HTTP 500.
    """
    try:
        bearer_token = authorization.split(" ")[1]
        # NOTE(review): the bearer token is passed as `session_id`; confirm this
        # is what EnvironmentalConditions expects.
        return EnvironmentalConditions(session_id=bearer_token).get_conditon()
    except Exception as exc:
        logging.error(f"Error generating skin care wheel: {str(exc)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Failed to generate skin care wheel",
                "message": "An unexpected error occurred"
            }
        )
async def disease_search(
    session_id: str = Form(...),
    query: str = Form(...),
    num_results: int = Form(3),
    num_images: int = Form(3),
    image: UploadFile = File(...),
    authorization: str = Header(...),
    username: str = Depends(get_current_user)
):
    """Run an image-assisted web search through ``ImageProcessor``.

    The uploaded ``image`` and the search parameters are handed to
    ``ImageProcessor``, and ``query`` is then searched. Any failure becomes an
    HTTP 500 with the exception text as detail.
    """
    try:
        processor = ImageProcessor(
            token=authorization.split(" ")[1],
            session_id=session_id,
            num_results=num_results,
            num_images=num_images,
            image=image,
        )
        return {"response": processor.web_search(query=query)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
async def rag_evaluation(
    page: int = Form(3),
    page_size: int = Form(3),
    authorization: str = Header(...),
    username: str = Depends(get_current_user)
):
    """Generate a RAG evaluation report for the requested page of results.

    NOTE(review): ``page`` defaults to 3 rather than a first page; confirm
    this default is intentional.

    Any failure becomes an HTTP 500 with the exception text as detail.
    """
    try:
        evaluator = RAGEvaluation(
            token=authorization.split(" ")[1],
            page=page,
            page_size=page_size,
        )
        return {"response": evaluator.generate_evaluation_report()}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))