File size: 6,203 Bytes
f114412
faca925
f114412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faca925
 
 
 
 
 
 
 
 
 
 
 
f114412
 
 
 
 
 
 
 
 
 
 
 
faca925
 
 
f114412
 
0c1300d
 
 
 
 
 
 
 
 
 
 
 
 
 
f114412
0c1300d
 
 
f114412
 
 
 
 
 
 
faca925
 
f114412
0c1300d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f114412
 
 
faca925
 
 
 
 
 
f114412
0c1300d
 
 
 
f114412
 
 
 
 
 
 
 
 
 
 
 
faca925
 
 
 
 
f114412
 
 
229e823
f114412
0c1300d
 
 
 
 
 
 
 
 
f114412
 
 
 
 
0c1300d
 
 
 
 
 
 
 
f114412
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import torch
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from lawchatbot.config import AppConfig
from lawchatbot.weaviate_client import initialize_weaviate_client
from lawchatbot.vectorstore import initialize_vector_store
from lawchatbot.retrievers import (
    initialize_semantic_retriever,
    initialize_bm25_retriever,
    initialize_hybrid_retriever,
    wrap_retriever_with_source
)
from lawchatbot.rag_chain import initialize_llm, build_rag_chain, run_rag_query

# GPU optimization setup
def setup_gpu_optimization() -> None:
    """Configure GPU settings for optimal performance.

    Enables cuDNN autotuning and TF32 matmul/cudnn kernels when a CUDA
    device is present, and logs the device name and total memory.
    Falls back to a CPU notice when no GPU is available.
    """
    # Guard clause: nothing to configure without CUDA.
    if not torch.cuda.is_available():
        print("⚠️ No GPU detected, using CPU")
        return

    # cuDNN benchmark picks the fastest conv algorithms for fixed shapes;
    # TF32 trades a little precision for large matmul speedups on Ampere+.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    print(f"🚀 GPU detected: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

app = FastAPI()

# Set up static and template directories (relative to this file)
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# NOTE(review): mid-file import — conventionally this belongs in the
# top-of-file import block; left here to keep the module wiring intact.
from threading import Lock

# Guards one-time initialization in the startup handler so concurrent
# startup invocations cannot initialize the system twice.
_init_lock = Lock()
# Process-wide registry of initialized components; populated at startup
# with the keys "client" (Weaviate client) and "rag_chain".
_system = {}

@app.on_event("startup")
def startup_event():
    """Initialize the RAG system exactly once at application startup.

    Configures the GPU, validates required environment variables, and
    builds the retrieval/LLM pipeline. Initialization errors are logged
    rather than raised so the app stays up (the /health endpoint will
    report "unhealthy" until initialization succeeds).

    NOTE(review): ``@app.on_event`` is deprecated in recent FastAPI in
    favor of lifespan handlers — consider migrating when convenient.
    """
    # Setup GPU optimization first
    setup_gpu_optimization()

    with _init_lock:
        # Double-checked inside the lock: skip if already initialized.
        if _system:
            return

        config = _load_config_from_env()
        if config is None:
            # Missing secrets — leave the app running but uninitialized.
            return

        _initialize_system(config)


def _load_config_from_env():
    """Read required secrets from the environment and build an AppConfig.

    Returns:
        AppConfig when WEAVIATE_URL, WEAVIATE_API_KEY, and OPENAI_API_KEY
        are all set; ``None`` (after logging guidance) when any is missing.
    """
    # Debug: Print environment variables (mask sensitive parts)
    weaviate_url = os.getenv("WEAVIATE_URL", "<not_set>")
    weaviate_key = os.getenv("WEAVIATE_API_KEY", "<not_set>")
    openai_key = os.getenv("OPENAI_API_KEY", "<not_set>")

    print(f"🔍 Debug - WEAVIATE_URL: {weaviate_url[:50]}...")
    print(f"🔍 Debug - WEAVIATE_API_KEY: {'SET' if weaviate_key != '<not_set>' else 'NOT_SET'}")
    print(f"🔍 Debug - OPENAI_API_KEY: {'SET' if openai_key != '<not_set>' else 'NOT_SET'}")

    if weaviate_url == "<not_set>" or weaviate_key == "<not_set>" or openai_key == "<not_set>":
        print("❌ Error: Missing environment variables!")
        print("Please set WEAVIATE_URL, WEAVIATE_API_KEY, and OPENAI_API_KEY in Hugging Face Secrets")
        return None

    return AppConfig(
        weaviate_url=weaviate_url,
        weaviate_api_key=weaviate_key,
        openai_api_key=openai_key,
        weaviate_class="JustiaFederalCases",
        text_key="text",
        metadata_attributes=["text"],
        semantic_k=10,
        bm25_k=10,
        alpha=0.5
    )


def _initialize_system(config):
    """Build the Weaviate client and RAG chain, then pre-warm the pipeline.

    On success, stores "client" and "rag_chain" in the module-level
    ``_system`` dict. On any failure, logs the error and returns without
    raising so the app does not crash at startup.
    """
    print("🔄 Initializing system components with GPU acceleration...")
    try:
        print("🔗 Connecting to Weaviate (cloud)...")
        client = initialize_weaviate_client(config)
        print("✅ Weaviate client connected successfully!")

        # Retrieval stack: vector store -> semantic + BM25 -> hybrid blend.
        vectorstore = initialize_vector_store(client, config)
        semantic_ret = initialize_semantic_retriever(vectorstore, config)
        bm25_ret = initialize_bm25_retriever(client, config)
        hybrid_ret = initialize_hybrid_retriever(semantic_ret, bm25_ret, alpha=config.alpha)
        hybrid_ret = wrap_retriever_with_source(hybrid_ret)
        llm = initialize_llm()
        rag_chain = build_rag_chain(llm, hybrid_ret)

        _system.update({
            "client": client,
            "rag_chain": rag_chain
        })

        # Warmup pays one-time model/connection costs before real traffic.
        print("⏳ Pre-warming system with dummy query...")
        dummy_question = "This is a warmup question."
        rag_chain.invoke({"question": dummy_question})

        # Clear GPU cache after warmup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print("✅ System pre-warmed and ready for fast GPU-accelerated responses.")
    except Exception as e:
        print(f"❌ Initialization failed: {str(e)}")
        print(f"Error type: {type(e).__name__}")
        # Don't crash the app, just log the error
        return

def get_system():
    """Return the module-level registry of initialized components.

    The dict is empty until startup succeeds; afterwards it holds the
    keys "client" and "rag_chain".
    """
    return _system

@app.on_event("shutdown")
def shutdown_event():
    """Release external resources when the application stops.

    Closes the Weaviate client (best-effort) and frees cached GPU memory
    if CUDA is available.
    """
    client = _system.get("client")
    if client:
        try:
            client.close()
        except Exception:
            # Best-effort cleanup: a failed close must not abort shutdown.
            pass

    # Clear GPU memory on shutdown
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("🧹 GPU memory cleared on shutdown")

@app.get("/", response_class=HTMLResponse)
def chat_page(request: Request):
    """Render the chat UI from the templates directory."""
    return templates.TemplateResponse(request, "chat.html")

@app.get("/health")
def health_check():
    """Health check endpoint"""
    # Healthy means the startup handler finished building the RAG chain.
    if get_system().get("rag_chain"):
        return {"status": "healthy", "message": "LawChatbot is ready"}
    return {"status": "unhealthy", "message": "System not initialized"}

@app.post("/chat")
async def chat_api(request: Request):
    """Answer a user question via the RAG chain.

    Expects a JSON body of the form ``{"question": "..."}``. Returns
    ``{"answer": str, "context": [str, ...]}`` where context holds the
    page contents of the retrieved source documents.

    Status codes: 400 malformed body, 503 system not initialized,
    500 chain invocation failure.
    """
    # Fix: a malformed body previously raised out of request.json()
    # (framework 500); return an explicit 400 instead.
    try:
        data = await request.json()
    except Exception:
        return JSONResponse({
            "answer": "Error: request body must be valid JSON.",
            "context": []
        }, status_code=400)

    # A non-dict JSON body (e.g. a list) has no .get — treat it as
    # an empty question rather than crashing.
    question = data.get("question", "") if isinstance(data, dict) else ""

    sys = get_system()
    rag_chain = sys.get("rag_chain")

    if not rag_chain:
        return JSONResponse({
            "answer": "❌ System not initialized. Please check the logs and ensure environment variables are set correctly.", 
            "context": []
        }, status_code=503)

    try:
        response = rag_chain.invoke({"question": question})
        answer = response.get("answer", "")
        documents = response.get("source_documents", [])
        context = [doc.page_content for doc in documents] if documents else []
        return JSONResponse({"answer": answer, "context": context})
    except Exception as e:
        return JSONResponse({"answer": f"Error: {str(e)}", "context": []}, status_code=500)