File size: 5,930 Bytes
88208e2 51a0302 a71106f 766487e d75dc74 51a0302 d75dc74 88208e2 d75dc74 88208e2 d75dc74 ac3f3ed d75dc74 88208e2 d75dc74 88208e2 d75dc74 88208e2 d75dc74 88208e2 d75dc74 88208e2 d75dc74 88208e2 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 0cf14e5 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 d75dc74 51a0302 766487e 51a0302 766487e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
import logging
from src.RAGSample import setup_retriever, setup_rag_chain, RAGApplication
import os
from dotenv import load_dotenv
# Populate os.environ from a .env file (when one exists) so the os.getenv
# lookups further down in this module can see those values.
load_dotenv()

# The ASGI application object served by uvicorn.
app = FastAPI(
    version="1.0.0",
    title="RAG API",
    description="A REST API for Retrieval-Augmented Generation using local vector database",
)

# Pipeline singletons. They stay None until the startup hook builds them;
# rag_application is also swapped out when a Kaggle dataset is loaded via the API.
retriever = None
rag_chain = None
rag_application: Optional[RAGApplication] = None
class QuestionRequest(BaseModel):
    """JSON body accepted by the question-answering endpoint."""

    # The user's free-text question to answer with the RAG pipeline.
    question: str
class QuestionResponse(BaseModel):
    """JSON body returned by the question-answering endpoint."""

    # Echo of the question that was asked.
    question: str
    # Text produced by RAGApplication.run for that question.
    answer: str
def _resolve_kaggle_credentials():
    """Return ``(username, key)`` for Kaggle, preferring environment variables.

    Reads KAGGLE_USERNAME / KAGGLE_KEY. If either is missing, falls back to the
    credentials a freshly constructed ``KaggleDataLoader`` exposes (it loads them
    from kaggle.json on construction). Either returned value may be None.
    """
    username = os.getenv("KAGGLE_USERNAME")
    key = os.getenv("KAGGLE_KEY")
    if username and key:
        return username, key
    try:
        from src.kaggle_loader import KaggleDataLoader
        # Constructing the loader is what makes it pick up kaggle.json.
        probe = KaggleDataLoader()
        if probe.kaggle_username and probe.kaggle_key:
            username, key = probe.kaggle_username, probe.kaggle_key
            print(f"Loaded Kaggle credentials from kaggle.json: {username}")
    except Exception as e:
        # Best effort: missing/invalid credentials just mean we use local data.
        print(f"Could not load Kaggle credentials from kaggle.json: {e}")
    return username, key


@app.on_event("startup")
async def startup_event():
    """Build the retriever, RAG chain and RAGApplication when the server starts.

    When KAGGLE_DATASET is set and Kaggle credentials can be found (environment
    variables first, then kaggle.json via KaggleDataLoader), the retriever is
    built from that Kaggle dataset. Otherwise the default local data is used
    (the mental health FAQ file, per the log message below).

    Raises:
        Exception: any error while building the pipeline is printed and
            re-raised so the failure surfaces during server startup.
    """
    global retriever, rag_chain, rag_application
    try:
        print("Initializing RAG components...")
        kaggle_dataset = os.getenv("KAGGLE_DATASET")
        kaggle_username = kaggle_key = None
        if kaggle_dataset:
            # Credentials are only needed for a configured dataset, so the
            # kaggle.json probe is skipped entirely otherwise.
            kaggle_username, kaggle_key = _resolve_kaggle_credentials()

        if kaggle_dataset and kaggle_username and kaggle_key:
            print(f"Loading Kaggle dataset: {kaggle_dataset}")
            retriever = setup_retriever(
                use_kaggle_data=True,
                kaggle_dataset=kaggle_dataset,
                kaggle_username=kaggle_username,
                kaggle_key=kaggle_key,
            )
        else:
            print("Loading mental health FAQ data from local file...")
            # Default behaviour: build the retriever from the local FAQ data.
            retriever = setup_retriever()

        rag_chain = setup_rag_chain()
        rag_application = RAGApplication(retriever, rag_chain)
        print("RAG components initialized successfully!")
    except Exception as e:
        print(f"Error initializing RAG components: {e}")
        raise
@app.get("/")
async def root():
    """Describe the API and list its endpoints.

    Returns:
        A dict with a status ``message`` and an ``endpoints`` map from a
        descriptive key to the route path registered in this module.
    """
    return {
        "message": "RAG API is running",
        "endpoints": {
            # Must match the path in ask_question's @app.post decorator.
            "ask_question": "/medical/ask",
            "health_check": "/health",
            "load_kaggle_dataset": "/load-kaggle-dataset",
            "models": "/models",
        },
    }
@app.get("/health")
async def health_check():
    """Report liveness plus whether the RAG pipeline has been built yet."""
    initialized = rag_application is not None
    return {"status": "healthy", "rag_initialized": initialized}
@app.post("/medical/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
    """Answer a question using the RAG pipeline.

    Args:
        request: Body carrying the user's ``question``.

    Returns:
        QuestionResponse echoing the question with the generated answer.

    Raises:
        HTTPException: 500 if the pipeline has not been initialized, or if
            answering the question raises for any reason.
    """
    # Local import; fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # Snapshot the global so a concurrent /load-kaggle-dataset swap cannot
    # change the pipeline between the None check and the call below.
    app_instance = rag_application
    if app_instance is None:
        raise HTTPException(status_code=500, detail="RAG application not initialized")
    try:
        print(f"Processing question: {request.question}")
        retriever_type = type(app_instance.retriever).__name__
        print(f"DEBUG: Using retriever type: {retriever_type}")
        # RAGApplication.run is synchronous (the original called it without
        # await); running it in the threadpool keeps the event loop free for
        # other requests. NOTE(review): assumes run() is safe to call from
        # worker threads concurrently - confirm for the retriever/LLM backend.
        answer = await run_in_threadpool(app_instance.run, request.question)
        return QuestionResponse(question=request.question, answer=answer)
    except Exception as e:
        print(f"Error processing question: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error processing question: {str(e)}"
        ) from e
@app.post("/load-kaggle-dataset")
async def load_kaggle_dataset(dataset_name: str):
    """Download a Kaggle dataset and rebuild the RAG pipeline on top of it.

    The new retriever/chain/application are fully built before the module
    globals are swapped, so in-flight requests keep using the old pipeline.

    Args:
        dataset_name: Kaggle dataset identifier, passed to the loader's
            ``download_dataset`` and to ``setup_retriever``.

    Returns:
        On success, a dict with ``status`` "success", a message and the
        ``dataset_path`` returned by the downloader. On failure, a JSON
        response with HTTP status 500 and body ``{"status": "error",
        "message": <exception text>}``.
    """
    # Keep all three module globals in sync, as startup_event does.
    global retriever, rag_chain, rag_application
    from fastapi.concurrency import run_in_threadpool
    from fastapi.responses import JSONResponse

    def _rebuild():
        # Blocking download + index build; executed off the event loop.
        from src.kaggle_loader import KaggleDataLoader
        # No arguments: the loader picks up credentials from kaggle.json.
        loader = KaggleDataLoader()
        path = loader.download_dataset(dataset_name)
        new_retriever = setup_retriever(use_kaggle_data=True, kaggle_dataset=dataset_name)
        new_chain = setup_rag_chain()
        return path, new_retriever, new_chain, RAGApplication(new_retriever, new_chain)

    try:
        dataset_path, new_retriever, new_chain, new_app = await run_in_threadpool(_rebuild)
    except Exception as e:
        print(f"Error loading Kaggle dataset: {e}")
        # Same body shape as before, but no longer reported as 200 OK.
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)},
        )

    retriever, rag_chain, rag_application = new_retriever, new_chain, new_app
    return {
        "status": "success",
        "message": f"Dataset {dataset_name} loaded successfully",
        "dataset_path": dataset_path,
    }
@app.get("/models")
async def get_models():
    """Describe the LLM, embedding and vector store this service reports.

    The values are fixed strings; they are not read from the live pipeline.
    """
    info = dict(
        llm_model="dolphin-llama3:8b",
        embedding_model="TF-IDF embeddings",
        vector_database="ChromaDB (local)",
    )
    return info
# Module-level logger, plus a root handler at INFO level (basicConfig is a
# no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
if __name__ == "__main__":
    # Script entry point (e.g. `python app.py`): serve the app with uvicorn.
    try:
        logger.info("Starting application...")
        # PORT can be supplied by the environment; 7860 is the fallback.
        listen_port = int(os.environ.get("PORT", 7860))
        logger.info("Starting server on port %s", listen_port)
        uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")
    except Exception as exc:
        logger.error("Failed to start application: %s", exc)
        raise