import asyncio
import json
from uuid import uuid4

from celery.result import AsyncResult

from app.core.tasks import process_documents, redis_client
from app.core.tasks import generate_response
from app.backend.controllers.messages import register_message
from app.core.document_validator import path_is_valid
from app.core.response_parser import add_links
from app.settings import BASE_DIR, settings, logger, app
from app.backend.controllers.chats import (
    get_chat_with_messages,
    create_new_chat,
    update_title,
)
from app.backend.controllers.users import (
    extract_user_from_context,
    get_current_user,
    get_latest_chat,
    refresh_cookie,
    authorize_user,
    check_cookie,
    create_user,
)
from app.core.utils import (
    construct_collection_name,
    create_collection,
    extend_context,
    initialize_rag,
    save_documents,
    protect_chat,
    TextHandler,
    PDFHandler,
)

from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi import (
    HTTPException,
    UploadFile,
    Request,
    FastAPI,
    Form,
    File,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import (
    StreamingResponse,
    RedirectResponse,
    FileResponse,
    JSONResponse,
)
from typing import Optional
import aiofiles
import os


# <------------------------------------- API ------------------------------------->

api = FastAPI()
# Single RAG instance shared by every route; handed to create_collection when a chat is created.
rag = initialize_rag()

api.mount(
    "/chats_storage",
    StaticFiles(directory=os.path.join(BASE_DIR, "chats_storage")),
    name="chats_storage",
)
api.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "app", "frontend", "static")),
    name="static",
)

templates = Jinja2Templates(
    directory=os.path.join(BASE_DIR, "app", "frontend", "templates")
)

# Extensions (lower-case, without the dot) that /viewer renders through TextHandler.
_TEXT_VIEWER_EXTENSIONS = frozenset({"txt", "csv", "md", "json", "docx", "doc"})


# <--------------------------------- Middleware --------------------------------->

# NOTE(review): allow_origins=["*"] together with allow_credentials=True means
# Starlette reflects the caller's Origin for credentialed requests, i.e. any site
# can make cookie-bearing requests. Restrict allow_origins before public deployment.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@api.middleware("http")
async def require_user(request: Request, call_next):
    """Resolve the current user, expose it on ``request.state`` and manage the cookie.

    Paths starting with ``pdfs`` and requests for ``static/styles.css`` or
    ``favicon.ico`` are passed straight through. For every other request the
    user is looked up with ``get_current_user``; if none is found a new user is
    created. After the downstream handler runs, an existing user's cookie is
    refreshed (``refresh_cookie``) while a newly created user is authorized on
    the response (``authorize_user``).
    """
    if settings.debug:
        await logger.info("START MIDDLEWARE")
    try:
        if settings.debug:
            await logger.info(f"Path ----> {request.url.path}, Method ----> {request.method}, Port ----> {request.url.port}")

        stripped_path = request.url.path.strip("/")
        if (
            stripped_path.startswith("pdfs")
            or "static/styles.css" in stripped_path
            or "favicon.ico" in stripped_path
        ):
            return await call_next(request)

        user = await get_current_user(request)
        if settings.debug:
            await logger.info(f"User: {user}")

        # No user means the client has no valid session yet: create one and
        # authorize it on the outgoing response instead of refreshing a cookie.
        authorized = user is not None
        if not authorized:
            user = await create_user()

        if settings.debug:
            await logger.info(f"User in Context ----> {user.id}")
        request.state.current_user = user

        response = await call_next(request)
        if authorized:
            await refresh_cookie(request=request, response=response)
        else:
            await authorize_user(response, user)
        return response
    finally:
        if settings.debug:
            await logger.info("END MIDDLEWARE")


# <--------------------------------- Helpers --------------------------------->

async def _read_response_chunks(task_id: str, start: int = 0) -> list[str]:
    """Return the decoded ``chunk`` texts stored in Redis for ``task_id``, from index ``start`` on.

    Each list entry under ``response:{task_id}:chunks`` is a JSON object whose
    ``"chunk"`` field holds the text fragment.
    """
    raw_chunks = await redis_client.lrange(f"response:{task_id}:chunks", start, -1)
    return [json.loads(raw)["chunk"] for raw in raw_chunks]


async def _create_chat_with_collection(title, user) -> str:
    """Create a chat titled ``title`` for ``user`` plus its collection; return the chat URL.

    Raises:
        HTTPException(500): if the chat record lacks a url/chat_id or
            ``create_collection`` fails.
    """
    new_chat = await create_new_chat(title, user)
    url = new_chat.get("url")
    chat_id = new_chat.get("chat_id")
    if url is None or chat_id is None:
        raise HTTPException(500, "New chat was not created")
    try:
        await create_collection(user, chat_id, rag)
    except Exception as e:
        # str(e): HTTPException.detail must be JSON-serializable.
        raise HTTPException(500, str(e)) from e
    return url


# <--------------------------------- Common routes --------------------------------->

@api.post("/message_with_docs")
async def send_message(
    request: Request,
    files: list[UploadFile] = File(None),
    prompt: str = Form(...),
    chat_id: str = Form(None),
) -> JSONResponse:
    """Register a user message, enqueue document processing and response generation.

    The prompt is stored via ``register_message``; uploaded files are saved and,
    if any were saved, processed by the ``process_documents`` Celery task. The
    answer is produced by ``generate_response``, whose chunks the client reads
    from ``/ws/response/{resp_task_id}``.

    Returns:
        JSON with ``doc_task_id`` (or ``None``), ``resp_task_id`` and a message.

    Raises:
        HTTPException(401): if ``chat_id`` is given and the user may not use it.
        HTTPException(500): on any other failure (the error is logged).
    """
    try:
        await logger.info("Start processing the message")
        user = await extract_user_from_context(request)
        if settings.debug:
            await logger.info(f" User ----> {user}")

        # Same ownership rule as show_chat: do not write into someone else's chat.
        if chat_id is not None and not await protect_chat(user, chat_id):
            raise HTTPException(401, "You do not have rights to use this chat!")

        collection_name = await construct_collection_name(user, chat_id)
        if settings.debug:
            await logger.info(f"Received message -------> {prompt}")

        await register_message(content=prompt, sender="user", chat_id=chat_id)
        docs = await save_documents(files=files, user=user, chat_id=chat_id)

        doc_task = None
        if docs:
            doc_task = process_documents.delay(
                collection_name=collection_name,
                files=docs,
                chat_id=chat_id,
            )

        task_id = str(uuid4())
        # Use the same id for the Celery task so /task_status can fall back to
        # the Celery state when Redis has no status for it yet.
        generate_response.apply_async(
            kwargs={
                "collection_name": collection_name,
                "prompt": prompt,
                "chat_id": chat_id,
                "task_id": task_id,
            },
            task_id=task_id,
        )

        return JSONResponse({
            "doc_task_id": doc_task.id if doc_task else None,
            "resp_task_id": task_id,
            "message": "Tasks enqueued, connect to WebSocket for streaming response",
        })
    except HTTPException:
        raise
    except Exception as e:
        await logger.error(f"Error in send_message: {str(e)}")
        raise HTTPException(500, "Failed to process the message") from e


@api.websocket("/ws/response/{task_id}")
async def websocket_response(websocket: WebSocket, task_id: str):
    """Forward the chunks of response task ``task_id`` to the client as they appear.

    Polls ``response:{task_id}:status`` every 0.1 s:
    - ``streaming``: sends newly appended chunks as text frames;
    - ``completed``: sends the remaining chunks, then ``{"status": "completed"}``;
    - ``failed``: sends ``{"status": "failed", "error": ...}``.
    The socket is closed after a terminal status.
    """
    await websocket.accept()
    try:
        last_index = 0
        while True:
            status = await redis_client.get(f"response:{task_id}:status")
            if status == "completed":
                for chunk in await _read_response_chunks(task_id, last_index):
                    await websocket.send_text(chunk)
                await websocket.send_json({"status": "completed"})
                break
            elif status == "failed":
                error = await redis_client.get(f"response:{task_id}:error") or "Unknown error"
                await websocket.send_json({"status": "failed", "error": error})
                break
            elif status == "streaming":
                chunks = await _read_response_chunks(task_id, last_index)
                for chunk in chunks:
                    await websocket.send_text(chunk)
                last_index += len(chunks)
            # NOTE(review): nothing is read from the client, so a disconnect is only
            # noticed on the next send; while the status is unset this keeps polling.
            await asyncio.sleep(0.1)
        await websocket.close()
    except WebSocketDisconnect:
        if settings.debug:
            await logger.info(f"WebSocket for task {task_id} disconnected")
    except Exception as e:
        await logger.error(f"Error at websocket: {e}")


@api.get("/task_status/{task_id}")
async def get_task_status(task_id: str):
    """Report the state of a task, preferring the Redis status over the Celery state.

    Returns ``pending``, ``streaming`` (with chunks so far), ``success`` (with all
    chunks), or the raw status together with an ``error`` description.
    """
    task = AsyncResult(task_id, app=app)
    status = await redis_client.get(f"response:{task_id}:status") or task.state

    if status in ["PENDING", "STARTED"]:
        return JSONResponse({"task_id": task_id, "status": "pending"})
    if status == "streaming":
        # Still generating: report progress rather than an error.
        return JSONResponse({
            "task_id": task_id,
            "status": "streaming",
            "chunks": await _read_response_chunks(task_id),
        })
    if status in ["SUCCESS", "completed"]:
        return JSONResponse({
            "task_id": task_id,
            "status": "success",
            "chunks": await _read_response_chunks(task_id),
        })

    error = await redis_client.get(f"response:{task_id}:error") or str(task.info)
    return JSONResponse({"task_id": task_id, "status": status, "error": error})


@api.post("/replace_message")
async def replace_message(request: Request):
    """Write ``message`` from the JSON body to models_io/response.txt and return it with links.

    Raises:
        HTTPException(400): if the body is not a JSON object.
    """
    data = await request.json()
    if not isinstance(data, dict):
        raise HTTPException(400, "Expected a JSON object")
    message = data.get("message", "")

    # NOTE(review): all requests overwrite the same file, so concurrent users race on it.
    output_path = os.path.join(BASE_DIR, "models_io", "response.txt")
    async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
        await f.write(message)
    return JSONResponse({"updated_message": await add_links(message)})


@api.get("/viewer")
async def show_document(
    request: Request,
    path: str,
    page: Optional[int] = 1,
    lines: Optional[str] = "1-1",
    start: Optional[int] = 0,
):
    """Render the document at ``path``: PDFs via PDFHandler, text-like files via TextHandler.

    Any other extension is returned as a raw file. ``start`` is accepted but unused.

    Raises:
        HTTPException(404): if ``path_is_valid`` rejects the path.
    """
    # NOTE(review): ``path`` comes from the query string; path_is_valid is the only
    # guard against serving arbitrary files — confirm it confines paths to storage.
    if not await path_is_valid(path):
        raise HTTPException(status_code=404, detail="Document not found")

    ext = os.path.splitext(path)[1].lstrip(".").lower()
    if ext == "pdf":
        return await PDFHandler(request, path=path, page=page, templates=templates)
    if ext in _TEXT_VIEWER_EXTENSIONS:
        return await TextHandler(request, path=path, lines=lines, templates=templates)
    return FileResponse(path=path)


# <--------------------------------- Get --------------------------------->

@api.get("/cookie_test")
async def test_cookie(request: Request):
    """Debug endpoint: return whatever ``check_cookie`` reports for this request."""
    return await check_cookie(request)


@api.get("/test")
async def test(request: Request):
    """Debug endpoint: return the id of the user resolved by ``get_current_user``, or null."""
    user = await get_current_user(request)
    return {"user": {"id": user.id} if user is not None else None}


@api.get("/chats/id={chat_id}")
async def show_chat(request: Request, chat_id: str):
    """Render the chat page for ``chat_id`` after checking the user may access it.

    Raises:
        HTTPException(401): if ``protect_chat`` denies access.
    """
    current_template = "pages/chat.html"
    user = await extract_user_from_context(request)
    if settings.debug:
        await logger.info(f"User in chats ----------------> {user.id}")

    # Check rights before touching the chat so a foreign chat's title is never updated.
    if not await protect_chat(user, chat_id):
        raise HTTPException(401, "You do not have rights to use this chat!")

    chat = await get_chat_with_messages(chat_id)
    await update_title(chat["chat_id"])

    context = await extend_context({"request": request, "user": user}, selected=chat_id)
    context.update(chat)
    return templates.TemplateResponse(current_template, context)


@api.get("/")
async def last_user_chat(request: Request):
    """Redirect to the user's latest chat, creating a new one if the user has none."""
    user = await extract_user_from_context(request)
    chat = await get_latest_chat(user)
    if chat is not None:
        return RedirectResponse(f"/chats/id={chat.id}", status_code=303)

    if settings.debug:
        await logger.info("Creating new chat")
    url = await _create_chat_with_collection("new chat", user)
    return RedirectResponse(url, status_code=303)


# <--------------------------------- Post --------------------------------->

@api.post("/new_chat")
async def create_chat(request: Request, title: Optional[str] = "new chat"):
    """Create a new chat (and its collection) for the current user and redirect to it."""
    user = await extract_user_from_context(request)
    url = await _create_chat_with_collection(title, user)
    return RedirectResponse(url, status_code=303)