"""FastAPI entry point for the LawChatbot RAG service.

Wires a Weaviate-backed hybrid retriever (semantic + BM25) into an
LLM RAG chain and exposes a minimal chat UI plus a JSON chat API.
Components are initialized once at startup and shared via ``_system``.
"""

import os
from threading import Lock

import torch
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from lawchatbot.config import AppConfig
from lawchatbot.weaviate_client import initialize_weaviate_client
from lawchatbot.vectorstore import initialize_vector_store
from lawchatbot.retrievers import (
    initialize_semantic_retriever,
    initialize_bm25_retriever,
    initialize_hybrid_retriever,
    wrap_retriever_with_source,
)
from lawchatbot.rag_chain import initialize_llm, build_rag_chain, run_rag_query


def setup_gpu_optimization() -> None:
    """Configure GPU settings for optimal performance.

    Enables cuDNN autotuning and TF32 matmuls when CUDA is available;
    otherwise just reports that the CPU will be used.
    """
    if torch.cuda.is_available():
        # Let cuDNN benchmark kernels for the input shapes seen at runtime.
        torch.backends.cudnn.benchmark = True
        # TF32 trades a little precision for a large matmul speedup on Ampere+.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print(f"๐Ÿš€ GPU detected: {torch.cuda.get_device_name(0)}")
        print(f"๐Ÿ’พ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    else:
        print("โš ๏ธ No GPU detected, using CPU")


app = FastAPI()

# Static assets and Jinja templates live next to this file.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Guards the one-time initialization performed in startup_event().
_init_lock = Lock()
# Populated at startup with {"client": ..., "rag_chain": ...}; empty until then.
_system = {}


def _read_env_config():
    """Read required secrets from the environment.

    Returns a populated :class:`AppConfig`, or ``None`` (after logging
    guidance) when any of the three required variables is unset.
    """
    weaviate_url = os.getenv("WEAVIATE_URL", "")
    weaviate_key = os.getenv("WEAVIATE_API_KEY", "")
    openai_key = os.getenv("OPENAI_API_KEY", "")
    # Debug: print environment variables (mask sensitive parts).
    print(f"๐Ÿ” Debug - WEAVIATE_URL: {weaviate_url[:50]}...")
    print(f"๐Ÿ” Debug - WEAVIATE_API_KEY: {'SET' if weaviate_key != '' else 'NOT_SET'}")
    print(f"๐Ÿ” Debug - OPENAI_API_KEY: {'SET' if openai_key != '' else 'NOT_SET'}")
    if not (weaviate_url and weaviate_key and openai_key):
        print("โŒ Error: Missing environment variables!")
        print("Please set WEAVIATE_URL, WEAVIATE_API_KEY, and OPENAI_API_KEY in Hugging Face Secrets")
        return None
    return AppConfig(
        weaviate_url=weaviate_url,
        weaviate_api_key=weaviate_key,
        openai_api_key=openai_key,
        weaviate_class="JustiaFederalCases",
        text_key="text",
        metadata_attributes=["text"],
        semantic_k=10,
        bm25_k=10,
        alpha=0.5,
    )


# NOTE(review): on_event is deprecated in recent FastAPI in favor of
# lifespan handlers; kept here for compatibility with the current setup.
@app.on_event("startup")
def startup_event():
    """Initialize retrieval + RAG components once, then pre-warm the chain.

    On any failure the error is logged and the app keeps running so that
    /health can report the unhealthy state instead of the process crashing.
    """
    # Setup GPU optimization first.
    setup_gpu_optimization()
    with _init_lock:
        if _system:
            # Already initialized (e.g. repeated startup hooks) — nothing to do.
            return
        config = _read_env_config()
        if config is None:
            # Missing secrets; endpoints will answer with 503 until fixed.
            return
        print("๐Ÿ”„ Initializing system components with GPU acceleration...")
        try:
            print("๐Ÿ”— Connecting to Weaviate (cloud)...")
            client = initialize_weaviate_client(config)
            print("โœ… Weaviate client connected successfully!")
            vectorstore = initialize_vector_store(client, config)
            semantic_ret = initialize_semantic_retriever(vectorstore, config)
            bm25_ret = initialize_bm25_retriever(client, config)
            hybrid_ret = initialize_hybrid_retriever(semantic_ret, bm25_ret, alpha=config.alpha)
            hybrid_ret = wrap_retriever_with_source(hybrid_ret)
            llm = initialize_llm()
            rag_chain = build_rag_chain(llm, hybrid_ret)
            _system.update({
                "client": client,
                "rag_chain": rag_chain,
            })
            print("โณ Pre-warming system with dummy query...")
            dummy_question = "This is a warmup question."
            rag_chain.invoke({"question": dummy_question})
            # Clear GPU cache after warmup so steady-state memory is clean.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("โœ… System pre-warmed and ready for fast GPU-accelerated responses.")
        except Exception as e:
            print(f"โŒ Initialization failed: {str(e)}")
            print(f"Error type: {type(e).__name__}")
            # Don't crash the app, just log the error.
            return


def get_system():
    """Return the shared component registry (empty dict if startup failed)."""
    return _system


@app.on_event("shutdown")
def shutdown_event():
    """Close the Weaviate client and release GPU memory on shutdown."""
    client = _system.get("client")
    if client:
        try:
            client.close()
        except Exception:
            # Best-effort close; a failure here must not block shutdown.
            pass
    # Clear GPU memory on shutdown.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("๐Ÿงน GPU memory cleared on shutdown")


@app.get("/", response_class=HTMLResponse)
def chat_page(request: Request):
    """Serve the chat UI."""
    return templates.TemplateResponse(request, "chat.html")


@app.get("/health")
def health_check():
    """Health check endpoint"""
    sys = get_system()
    if sys.get("rag_chain"):
        return {"status": "healthy", "message": "LawChatbot is ready"}
    else:
        return {"status": "unhealthy", "message": "System not initialized"}


@app.post("/chat")
async def chat_api(request: Request):
    """Answer a question via the RAG chain.

    Returns JSON ``{"answer": ..., "context": [...]}``; 400 on a
    malformed body, 503 before initialization, 500 on chain errors.
    """
    try:
        data = await request.json()
    except Exception:
        # Malformed / non-JSON body: answer with a client error, not a 500.
        return JSONResponse({
            "answer": "Error: request body must be valid JSON",
            "context": []
        }, status_code=400)
    question = data.get("question", "")
    sys = get_system()
    rag_chain = sys.get("rag_chain")
    if not rag_chain:
        return JSONResponse({
            "answer": "โŒ System not initialized. Please check the logs and ensure environment variables are set correctly.",
            "context": []
        }, status_code=503)
    try:
        response = rag_chain.invoke({"question": question})
        answer = response.get("answer", "")
        documents = response.get("source_documents", [])
        context = [doc.page_content for doc in documents] if documents else []
        return JSONResponse({"answer": answer, "context": context})
    except Exception as e:
        return JSONResponse({"answer": f"Error: {str(e)}", "context": []}, status_code=500)