File size: 5,930 Bytes
88208e2
51a0302
 
 
a71106f
766487e
d75dc74
51a0302
d75dc74
88208e2
d75dc74
 
88208e2
d75dc74
 
 
 
 
 
ac3f3ed
d75dc74
 
 
 
88208e2
d75dc74
 
 
88208e2
d75dc74
 
 
 
88208e2
d75dc74
 
 
 
 
 
88208e2
d75dc74
 
 
 
88208e2
d75dc74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88208e2
d75dc74
 
51a0302
 
 
d75dc74
51a0302
d75dc74
51a0302
d75dc74
 
 
 
51a0302
 
 
 
d75dc74
51a0302
d75dc74
 
51a0302
 
0cf14e5
d75dc74
 
 
 
 
51a0302
d75dc74
 
 
 
 
51a0302
d75dc74
51a0302
d75dc74
 
 
51a0302
 
d75dc74
 
51a0302
d75dc74
 
 
51a0302
d75dc74
 
 
 
 
 
 
 
 
 
 
 
 
51a0302
 
d75dc74
 
 
51a0302
 
d75dc74
 
 
 
 
 
 
 
 
 
51a0302
766487e
 
 
 
51a0302
766487e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
import logging
from src.RAGSample import setup_retriever, setup_rag_chain, RAGApplication
import os
from dotenv import load_dotenv

# Load environment variables via python-dotenv before any os.getenv lookups
# further down in this module (Kaggle settings, PORT).
load_dotenv()

# Create FastAPI app; every route and the startup hook below are registered
# on this instance, and it is the object handed to uvicorn.run().
app = FastAPI(
    title="RAG API",
    description="A REST API for Retrieval-Augmented Generation using local vector database",
    version="1.0.0"
)

# Module-level RAG state. All three start as None and are populated by
# startup_event() when the server starts. rag_application is also replaced
# by the /load-kaggle-dataset handler, and /health reports whether it is set.
retriever = None
rag_chain = None
rag_application = None

# Pydantic model for request
# Request body accepted by the ask_question endpoint.
class QuestionRequest(BaseModel):
    question: str  # question text; forwarded unchanged to rag_application.run()

# Pydantic model for response
# Response body returned by the ask_question endpoint.
class QuestionResponse(BaseModel):
    question: str  # the question from the request, echoed back
    answer: str  # value returned by rag_application.run() for that question

@app.on_event("startup")
async def startup_event():
    """Build the retriever, RAG chain and RAG application at server start.

    Kaggle credentials are read from the KAGGLE_USERNAME / KAGGLE_KEY
    environment variables. When either is missing, the credentials exposed
    by ``KaggleDataLoader`` (presumably read from kaggle.json by the loader)
    are tried instead; a failure of that fallback is printed and ignored.

    A Kaggle-backed retriever is built only when a username, a key and
    KAGGLE_DATASET are all available; otherwise ``setup_retriever()`` is
    called with its defaults.

    Raises:
        Exception: any error raised while building the components is
            printed and re-raised unchanged.
    """
    global retriever, rag_chain, rag_application
    try:
        print("Initializing RAG components...")

        # Environment variables take priority over kaggle.json.
        username = os.getenv("KAGGLE_USERNAME")
        api_key = os.getenv("KAGGLE_KEY")
        dataset = os.getenv("KAGGLE_DATASET")

        have_env_credentials = bool(username and api_key)
        if not have_env_credentials:
            # Best-effort fallback: a loader built without arguments may
            # carry credentials of its own.
            try:
                from src.kaggle_loader import KaggleDataLoader
                loader = KaggleDataLoader()
                if loader.kaggle_username and loader.kaggle_key:
                    username = loader.kaggle_username
                    api_key = loader.kaggle_key
                    print(f"Loaded Kaggle credentials from kaggle.json: {username}")
            except Exception as load_error:
                print(f"Could not load Kaggle credentials from kaggle.json: {load_error}")

        if not (username and api_key and dataset):
            # Default behavior: no complete Kaggle configuration available.
            print("Loading mental health FAQ data from local file...")
            retriever = setup_retriever()
        else:
            print(f"Loading Kaggle dataset: {dataset}")
            retriever = setup_retriever(
                use_kaggle_data=True,
                kaggle_dataset=dataset,
                kaggle_username=username,
                kaggle_key=api_key,
            )

        rag_chain = setup_rag_chain()
        rag_application = RAGApplication(retriever, rag_chain)
        print("RAG components initialized successfully!")
    except Exception as init_error:
        print(f"Error initializing RAG components: {init_error}")
        raise

@app.get("/")
async def root():
    """Root endpoint with API information.

    Returns:
        A dict with a status message and an ``endpoints`` mapping from a
        short name to the path of each route registered in this module.
    """
    return {
        "message": "RAG API is running",
        "endpoints": {
            # Must match the path registered on ask_question.
            "ask_question": "/medical/ask",
            "health_check": "/health",
            "load_kaggle_dataset": "/load-kaggle-dataset",
            "models": "/models",
        }
    }

@app.get("/health")
async def health_check():
    """Report service status and whether the RAG application has been built.

    ``status`` is always "healthy"; ``rag_initialized`` is True once the
    module-level ``rag_application`` is no longer None.
    """
    initialized = rag_application is not None
    return {"status": "healthy", "rag_initialized": initialized}

@app.post("/medical/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
    """Answer a question with the RAG application.

    Args:
        request: body carrying the question text.

    Returns:
        A QuestionResponse holding the question and the generated answer.

    Raises:
        HTTPException: status 500 when the RAG application has not been
            initialized, or when anything fails while producing the answer.
    """
    if rag_application is None:
        raise HTTPException(status_code=500, detail="RAG application not initialized")

    try:
        question = request.question
        print(f"Processing question: {question}")

        # Debug aid: show which retriever implementation is in use.
        print(f"DEBUG: Using retriever type: {type(rag_application.retriever).__name__}")

        # Building the response stays inside the try so that a failure there
        # is also reported as a 500 with the error text.
        return QuestionResponse(question=question, answer=rag_application.run(question))
    except Exception as err:
        failure = f"Error processing question: {err}"
        print(failure)
        raise HTTPException(status_code=500, detail=failure)

@app.post("/load-kaggle-dataset")
async def load_kaggle_dataset(dataset_name: str):
    """Download a Kaggle dataset and rebuild the RAG components on top of it.

    Args:
        dataset_name: dataset identifier passed to
            ``KaggleDataLoader.download_dataset`` and to ``setup_retriever``.

    Returns:
        On success, a dict with ``status`` "success", a message and the
        ``dataset_path`` returned by the loader. On any failure, a dict with
        ``status`` "error" and the exception text; no exception propagates.
    """
    # All three module-level components are replaced together so that
    # `retriever` and `rag_chain` never lag behind `rag_application`.
    global retriever, rag_chain, rag_application
    try:
        from src.kaggle_loader import KaggleDataLoader

        # Create loader without parameters - it will auto-load from kaggle.json
        loader = KaggleDataLoader()

        # Download the dataset
        dataset_path = loader.download_dataset(dataset_name)

        # Build the new components into locals first; if any step fails the
        # currently serving components are left untouched.
        new_retriever = setup_retriever(use_kaggle_data=True, kaggle_dataset=dataset_name)
        new_rag_chain = setup_rag_chain()
        new_rag_application = RAGApplication(new_retriever, new_rag_chain)

        retriever = new_retriever
        rag_chain = new_rag_chain
        rag_application = new_rag_application

        return {
            "status": "success",
            "message": f"Dataset {dataset_name} loaded successfully",
            "dataset_path": dataset_path
        }
    except Exception as e:
        # Keep the existing contract (error reported in the body), but also
        # surface the failure in the server output like the other handlers.
        print(f"Error loading Kaggle dataset: {e}")
        return {"status": "error", "message": str(e)}

@app.get("/models")
async def get_models():
    """Return a static description of the LLM, embedding and vector-store stack."""
    model_info = {
        "llm_model": "dolphin-llama3:8b",
        "embedding_model": "TF-IDF embeddings",
        "vector_database": "ChromaDB (local)",
    }
    return model_info


# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # using the PORT environment variable when set and 7860 otherwise.
    try:
        logger.info("Starting application...")
        # Add any initialization code here with try/except blocks

        listen_port = int(os.getenv("PORT", 7860))
        logger.info(f"Starting server on port {listen_port}")

        uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")
    except Exception as startup_error:
        # Record the failure, then let it propagate.
        logger.error(f"Failed to start application: {startup_error}")
        raise