|
|
|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel |
|
from typing import Optional |
|
import uvicorn |
|
import logging |
|
from src.RAGSample import setup_retriever, setup_rag_chain, RAGApplication |
|
import os |
|
from dotenv import load_dotenv |
|
|
|
|
|
# Load variables from a .env file (python-dotenv) before anything below
# reads the process environment via os.getenv.
load_dotenv()

# Application object; every route decorator in this module registers on it.
app = FastAPI(
    title="RAG API",
    description="A REST API for Retrieval-Augmented Generation using local vector database",
    version="1.0.0",
)
|
|
|
|
|
# Module-level RAG components. They stay None until the startup hook has
# built them; handlers check rag_application against None before using it.
retriever = rag_chain = rag_application = None
|
|
|
|
|
class QuestionRequest(BaseModel):
    """Request body accepted by the question-answering endpoint."""

    # Question text supplied by the client; passed to the RAG application.
    question: str
|
|
|
|
|
class QuestionResponse(BaseModel):
    """Response body returned by the question-answering endpoint."""

    # The question exactly as it was received in the request.
    question: str
    # The answer returned by the RAG application for that question.
    answer: str
|
|
|
@app.on_event("startup")
async def startup_event():
    """Build the retriever, RAG chain and RAG application at server start.

    Kaggle settings are read from the KAGGLE_USERNAME, KAGGLE_KEY and
    KAGGLE_DATASET environment variables. If either credential is missing,
    a KaggleDataLoader instance is asked for them instead; that lookup is
    best effort and a failure there is only printed. The Kaggle-backed
    retriever is used when username, key and dataset are all available,
    otherwise setup_retriever() is called with its defaults.

    Raises:
        Exception: whatever initialization raised, after printing it, so the
            error propagates out of the startup hook.
    """
    global retriever, rag_chain, rag_application

    try:
        print("Initializing RAG components...")

        username = os.getenv("KAGGLE_USERNAME")
        api_key = os.getenv("KAGGLE_KEY")
        dataset = os.getenv("KAGGLE_DATASET")

        if not username or not api_key:
            # Environment did not provide both credentials: try the loader.
            # Imported lazily so a missing/broken loader cannot stop start-up.
            try:
                from src.kaggle_loader import KaggleDataLoader

                fallback = KaggleDataLoader()
                if fallback.kaggle_username and fallback.kaggle_key:
                    username, api_key = fallback.kaggle_username, fallback.kaggle_key
                    print(f"Loaded Kaggle credentials from kaggle.json: {username}")
            except Exception as e:
                print(f"Could not load Kaggle credentials from kaggle.json: {e}")

        if not (username and api_key and dataset):
            print("Loading mental health FAQ data from local file...")
            retriever = setup_retriever()
        else:
            print(f"Loading Kaggle dataset: {dataset}")
            retriever = setup_retriever(
                use_kaggle_data=True,
                kaggle_dataset=dataset,
                kaggle_username=username,
                kaggle_key=api_key,
            )

        rag_chain = setup_rag_chain()
        rag_application = RAGApplication(retriever, rag_chain)
        print("RAG components initialized successfully!")
    except Exception as e:
        print(f"Error initializing RAG components: {e}")
        raise
|
|
|
@app.get("/")
async def root():
    """Describe the API and list the routes this module registers.

    Returns:
        dict with a "message" string and an "endpoints" mapping from a short
        name to the URL path of the corresponding route.
    """
    return {
        "message": "RAG API is running",
        "endpoints": {
            # Must match the path in the ask_question route decorator; the
            # previously advertised "/ask" is not a registered route.
            "ask_question": "/medical/ask",
            "health_check": "/health",
            "load_kaggle_dataset": "/load-kaggle-dataset",
            "models": "/models",
        },
    }
|
|
|
@app.get("/health")
async def health_check():
    """Report service status and whether the RAG application object exists.

    The "status" value is always "healthy"; "rag_initialized" is the part
    that reflects whether rag_application has been set.
    """
    initialized = rag_application is not None
    return {"status": "healthy", "rag_initialized": initialized}
|
|
|
@app.post("/medical/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
    """Answer a question using the initialized RAG application.

    Args:
        request: Request body carrying the ``question`` text.

    Returns:
        QuestionResponse echoing the question together with the value
        returned by ``rag_application.run``.

    Raises:
        HTTPException: status 500 when the RAG application has not been
            initialized, or when processing the question raises; in the
            latter case the original exception is attached as the cause.
    """
    if rag_application is None:
        raise HTTPException(status_code=500, detail="RAG application not initialized")

    log = logging.getLogger(__name__)
    try:
        log.info("Processing question: %s", request.question)

        retriever_type = type(rag_application.retriever).__name__
        log.debug("DEBUG: Using retriever type: %s", retriever_type)

        answer = rag_application.run(request.question)
        return QuestionResponse(question=request.question, answer=answer)
    except Exception as e:
        # Record the full traceback server-side; the client only gets the message.
        log.exception("Error processing question: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Error processing question: {str(e)}",
        ) from e
|
|
|
@app.post("/load-kaggle-dataset")
async def load_kaggle_dataset(dataset_name: str):
    """Download a Kaggle dataset and rebuild the RAG components on top of it.

    Args:
        dataset_name: Dataset identifier passed to
            ``KaggleDataLoader.download_dataset`` and ``setup_retriever``.

    Returns:
        On success, a dict with "status" == "success", a "message" and the
        "dataset_path" returned by the loader. On failure, a dict with
        "status" == "error" and the exception text in "message"; the
        previously active RAG components are left untouched in that case.
    """
    # All three module-level objects are replaced together so they never
    # describe different datasets.
    global retriever, rag_chain, rag_application

    try:
        # Imported lazily so the module still loads without the Kaggle loader.
        from src.kaggle_loader import KaggleDataLoader

        loader = KaggleDataLoader()
        dataset_path = loader.download_dataset(dataset_name)

        # Build everything into locals first; only swap once nothing failed.
        new_retriever = setup_retriever(use_kaggle_data=True, kaggle_dataset=dataset_name)
        new_rag_chain = setup_rag_chain()
        new_application = RAGApplication(new_retriever, new_rag_chain)
    except Exception as e:
        logging.getLogger(__name__).exception(
            "Failed to load Kaggle dataset %s", dataset_name
        )
        return {"status": "error", "message": str(e)}

    retriever, rag_chain, rag_application = new_retriever, new_rag_chain, new_application

    return {
        "status": "success",
        "message": f"Dataset {dataset_name} loaded successfully",
        "dataset_path": dataset_path,
    }
|
|
|
@app.get("/models")
async def get_models():
    """Return a static description of the models and vector store.

    The values are literals defined here; they are not read from the
    running RAG components.
    """
    model_info = dict(
        llm_model="dolphin-llama3:8b",
        embedding_model="TF-IDF embeddings",
        vector_database="ChromaDB (local)",
    )
    return model_info
|
|
|
|
|
# Module logger, plus root logging configured at INFO level at import time
# so the setting applies however this module is loaded.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # using the PORT environment variable when set and 7860 otherwise.
    try:
        logger.info("Starting application...")

        configured_port = os.getenv("PORT", 7860)
        port = int(configured_port)
        logger.info(f"Starting server on port {port}")

        server_options = {"host": "0.0.0.0", "port": port, "log_level": "info"}
        uvicorn.run(app, **server_options)
    except Exception as exc:
        logger.error(f"Failed to start application: {exc}")
        raise