brendon-ai's picture
Update app.py
0cf14e5 verified
# app.py
import logging
import os
from typing import Optional

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from src.RAGSample import setup_retriever, setup_rag_chain, RAGApplication
# Populate os.environ from a .env file (if one exists) before any code reads it.
load_dotenv()

# The ASGI application object served by uvicorn.
app = FastAPI(
    title="RAG API",
    description="A REST API for Retrieval-Augmented Generation using local vector database",
    version="1.0.0",
)

# RAG pipeline pieces. They stay None until the startup hook builds them;
# /load-kaggle-dataset may later replace them.
retriever = rag_chain = rag_application = None
# Pydantic model for request
class QuestionRequest(BaseModel):
    """Request body for ``POST /medical/ask``."""

    # min_length=1 makes FastAPI reject an empty question with a 422
    # validation error instead of sending it through the RAG pipeline.
    question: str = Field(..., min_length=1, description="The question to answer.")
# Pydantic model for response
class QuestionResponse(BaseModel):
    """Response body returned by ``POST /medical/ask``."""

    question: str = Field(..., description="The question as submitted by the client.")
    answer: str = Field(..., description="The answer produced by the RAG pipeline.")
@app.on_event("startup")
async def startup_event():
    """Build the RAG components once, when the server starts.

    Kaggle credentials are read from the ``KAGGLE_USERNAME`` and
    ``KAGGLE_KEY`` environment variables. If either is missing, the
    credentials found by a freshly constructed ``KaggleDataLoader`` are used
    instead; that loader is expected to auto-load kaggle.json. If a username,
    a key and ``KAGGLE_DATASET`` are all available, the retriever is built
    from that Kaggle dataset. Otherwise ``setup_retriever()`` is called with
    its defaults, which presumably load the local mental-health FAQ file.

    The module-level ``retriever``, ``rag_chain`` and ``rag_application`` are
    assigned only after every component has been built.

    Raises:
        Exception: any error while building the components is logged with
            its traceback and re-raised, so the server refuses to start
            rather than running without a RAG pipeline.
    """
    global retriever, rag_chain, rag_application
    try:
        logger.info("Initializing RAG components...")

        kaggle_username = os.getenv("KAGGLE_USERNAME")
        kaggle_key = os.getenv("KAGGLE_KEY")
        kaggle_dataset = os.getenv("KAGGLE_DATASET")

        # Fall back to the loader's own credential discovery only when the
        # environment does not supply both values.
        if not (kaggle_username and kaggle_key):
            try:
                from src.kaggle_loader import KaggleDataLoader

                probe = KaggleDataLoader()
                if probe.kaggle_username and probe.kaggle_key:
                    kaggle_username = probe.kaggle_username
                    kaggle_key = probe.kaggle_key
                    logger.info("Loaded Kaggle credentials from kaggle.json: %s", kaggle_username)
            except Exception as e:
                # Best effort: without file credentials we fall back to local data.
                logger.warning("Could not load Kaggle credentials from kaggle.json: %s", e)

        if kaggle_username and kaggle_key and kaggle_dataset:
            logger.info("Loading Kaggle dataset: %s", kaggle_dataset)
            new_retriever = setup_retriever(
                use_kaggle_data=True,
                kaggle_dataset=kaggle_dataset,
                kaggle_username=kaggle_username,
                kaggle_key=kaggle_key,
            )
        else:
            logger.info("Loading mental health FAQ data from local file...")
            new_retriever = setup_retriever()

        # Both retriever branches need the same chain and application wrapper.
        new_chain = setup_rag_chain()
        new_application = RAGApplication(new_retriever, new_chain)
    except Exception:
        logger.exception("Error initializing RAG components")
        raise

    # Publish only after everything succeeded, so a failure never leaves a
    # partially updated set of globals behind.
    retriever, rag_chain, rag_application = new_retriever, new_chain, new_application
    logger.info("RAG components initialized successfully!")
@app.get("/")
async def root():
    """Return a short API banner and the paths of the available endpoints.

    The paths listed here must match the routes registered in this module.
    """
    return {
        "message": "RAG API is running",
        "endpoints": {
            # The question route is registered as POST /medical/ask, not /ask.
            "ask_question": "/medical/ask",
            "health_check": "/health",
            "load_kaggle_dataset": "/load-kaggle-dataset",
            "models": "/models",
        },
    }
@app.get("/health")
async def health_check():
    """Liveness probe that also reports whether the RAG pipeline has been built."""
    ready = rag_application is not None
    return {"status": "healthy", "rag_initialized": ready}
@app.post("/medical/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
    """Answer ``request.question`` using the RAG pipeline.

    ``RAGApplication.run`` is a synchronous call (its result is used without
    ``await``). It therefore runs in the threadpool, so other requests keep
    being served while the model generates an answer.

    Raises:
        HTTPException: 500 if the RAG application has not been initialized,
            or if answering the question fails.
    """
    # Read the global once, so a concurrent /load-kaggle-dataset swap cannot
    # change the application between the None check and the call.
    application = rag_application
    if application is None:
        raise HTTPException(status_code=500, detail="RAG application not initialized")
    try:
        logger.info("Processing question: %s", request.question)
        logger.debug("Using retriever type: %s", type(application.retriever).__name__)
        answer = await run_in_threadpool(application.run, request.question)
        return QuestionResponse(question=request.question, answer=answer)
    except Exception as e:
        logger.exception("Error processing question")
        raise HTTPException(
            status_code=500, detail=f"Error processing question: {str(e)}"
        ) from e
def _build_kaggle_application(dataset_name: str):
    """Download ``dataset_name`` and build a retriever, chain and application for it.

    This is synchronous and potentially slow (download plus index build), so
    the endpoint runs it in a worker thread. It does not modify any
    module-level state.

    Returns:
        ``(dataset_path, retriever, rag_chain, rag_application)``.
    """
    from src.kaggle_loader import KaggleDataLoader

    # Constructed without arguments: the loader is expected to find its own
    # credentials (e.g. kaggle.json).
    loader = KaggleDataLoader()
    dataset_path = loader.download_dataset(dataset_name)
    new_retriever = setup_retriever(use_kaggle_data=True, kaggle_dataset=dataset_name)
    new_chain = setup_rag_chain()
    return dataset_path, new_retriever, new_chain, RAGApplication(new_retriever, new_chain)


@app.post("/load-kaggle-dataset")
async def load_kaggle_dataset(dataset_name: str):
    """Download a Kaggle dataset and switch the live RAG pipeline over to it.

    ``dataset_name`` arrives as a query parameter. The download and rebuild
    run in a worker thread, so other requests are still served meanwhile.
    The module-level ``retriever``, ``rag_chain`` and ``rag_application`` are
    all replaced, and only after the rebuild succeeds. A failure therefore
    leaves the current pipeline in place.

    Returns:
        On success, ``{"status": "success", "message": ..., "dataset_path": ...}``.
        On failure, ``{"status": "error", "message": ...}`` with HTTP status 500.
    """
    global retriever, rag_chain, rag_application
    try:
        dataset_path, new_retriever, new_chain, new_application = await run_in_threadpool(
            _build_kaggle_application, dataset_name
        )
    except Exception as e:
        logger.exception("Failed to load Kaggle dataset %s", dataset_name)
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})

    retriever, rag_chain, rag_application = new_retriever, new_chain, new_application
    return {
        "status": "success",
        "message": f"Dataset {dataset_name} loaded successfully",
        "dataset_path": dataset_path,
    }
@app.get("/models")
async def get_models():
    """Describe the model stack behind the API.

    NOTE(review): these are hard-coded literals, not values read from the live
    pipeline. They presumably need to be kept in sync by hand with whatever
    src.RAGSample configures.
    """
    return dict(
        llm_model="dolphin-llama3:8b",
        embedding_model="TF-IDF embeddings",
        vector_database="ChromaDB (local)",
    )
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _serve():
    """Run the app under uvicorn on $PORT (default 7860); log and re-raise failures."""
    try:
        logger.info("Starting application...")
        listen_port = int(os.getenv("PORT", 7860))
        logger.info(f"Starting server on port {listen_port}")
        uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")
    except Exception as exc:
        logger.error(f"Failed to start application: {exc}")
        raise


if __name__ == "__main__":
    _serve()