import os
from threading import Lock

import torch
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from lawchatbot.config import AppConfig
from lawchatbot.weaviate_client import initialize_weaviate_client
from lawchatbot.vectorstore import initialize_vector_store
from lawchatbot.retrievers import (
    initialize_semantic_retriever,
    initialize_bm25_retriever,
    initialize_hybrid_retriever,
    wrap_retriever_with_source,
)
from lawchatbot.rag_chain import initialize_llm, build_rag_chain, run_rag_query
# GPU optimization setup
def setup_gpu_optimization():
    """Configure GPU settings for optimal performance."""
    if torch.cuda.is_available():
        # cuDNN autotuning and TF32 matmuls trade a little precision for speed
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print(f"🚀 GPU detected: {torch.cuda.get_device_name(0)}")
        print(f"💾 GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    else:
        print("⚠️ No GPU detected, using CPU")
app = FastAPI()

# Set up static and template directories (resolved against the working directory)
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# One-time initialization state shared across requests
_init_lock = Lock()
_system = {}
@app.on_event("startup")
def startup_event():
    # Configure GPU settings before any model work
    setup_gpu_optimization()
    with _init_lock:
        if not _system:
            # Debug: report environment variables (mask sensitive parts)
            weaviate_url = os.getenv("WEAVIATE_URL", "<not_set>")
            weaviate_key = os.getenv("WEAVIATE_API_KEY", "<not_set>")
            openai_key = os.getenv("OPENAI_API_KEY", "<not_set>")
            print(f"🔍 Debug - WEAVIATE_URL: {weaviate_url[:50]}...")
            print(f"🔍 Debug - WEAVIATE_API_KEY: {'SET' if weaviate_key != '<not_set>' else 'NOT_SET'}")
            print(f"🔍 Debug - OPENAI_API_KEY: {'SET' if openai_key != '<not_set>' else 'NOT_SET'}")
            if "<not_set>" in (weaviate_url, weaviate_key, openai_key):
                print("❌ Error: Missing environment variables!")
                print("Please set WEAVIATE_URL, WEAVIATE_API_KEY, and OPENAI_API_KEY in Hugging Face Secrets")
                return

            config = AppConfig(
                weaviate_url=weaviate_url,
                weaviate_api_key=weaviate_key,
                openai_api_key=openai_key,
                weaviate_class="JustiaFederalCases",
                text_key="text",
                metadata_attributes=["text"],
                semantic_k=10,
                bm25_k=10,
                alpha=0.5,
            )

            print("🔄 Initializing system components with GPU acceleration...")
            try:
                print("🔗 Connecting to Weaviate (cloud)...")
                client = initialize_weaviate_client(config)
                print("✅ Weaviate client connected successfully!")

                # Build the retrieval stack: vector store -> semantic + BM25 -> hybrid
                vectorstore = initialize_vector_store(client, config)
                semantic_ret = initialize_semantic_retriever(vectorstore, config)
                bm25_ret = initialize_bm25_retriever(client, config)
                hybrid_ret = initialize_hybrid_retriever(semantic_ret, bm25_ret, alpha=config.alpha)
                hybrid_ret = wrap_retriever_with_source(hybrid_ret)

                llm = initialize_llm()
                rag_chain = build_rag_chain(llm, hybrid_ret)
                _system.update({
                    "client": client,
                    "rag_chain": rag_chain,
                })

                print("⏳ Pre-warming system with dummy query...")
                dummy_question = "This is a warmup question."
                rag_chain.invoke({"question": dummy_question})
                # Clear GPU cache after warmup
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                print("✅ System pre-warmed and ready for fast GPU-accelerated responses.")
            except Exception as e:
                print(f"❌ Initialization failed: {e}")
                print(f"Error type: {type(e).__name__}")
                # Don't crash the app; just log the error
                return
def get_system():
    return _system
@app.on_event("shutdown")
def shutdown_event():
    client = _system.get("client")
    if client:
        try:
            client.close()
        except Exception:
            pass
    # Clear GPU memory on shutdown
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("🧹 GPU memory cleared on shutdown")
# NOTE: the route decorators were missing from this listing; the paths below
# ("/", "/health", "/chat") are assumptions chosen to make the app runnable.
@app.get("/", response_class=HTMLResponse)
def chat_page(request: Request):
    return templates.TemplateResponse(request, "chat.html")
@app.get("/health")
def health_check():
    """Health check endpoint."""
    sys = get_system()
    if sys.get("rag_chain"):
        return {"status": "healthy", "message": "LawChatbot is ready"}
    else:
        return {"status": "unhealthy", "message": "System not initialized"}
@app.post("/chat")
async def chat_api(request: Request):
    data = await request.json()
    question = data.get("question", "")
    sys = get_system()
    rag_chain = sys.get("rag_chain")
    if not rag_chain:
        return JSONResponse({
            "answer": "❌ System not initialized. Please check the logs and ensure environment variables are set correctly.",
            "context": []
        }, status_code=503)
    try:
        response = rag_chain.invoke({"question": question})
        answer = response.get("answer", "")
        documents = response.get("source_documents", [])
        context = [doc.page_content for doc in documents] if documents else []
        return JSONResponse({"answer": answer, "context": context})
    except Exception as e:
        return JSONResponse({"answer": f"Error: {e}", "context": []}, status_code=500)