import os
import warnings
import logging
import time
from datetime import datetime
from fastapi import FastAPI, Request, HTTPException, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pdf_parser import parse_pdf_from_url_multithreaded as parse_pdf_from_url, parse_pdf_from_file_multithreaded as parse_pdf_from_file
from embedder import build_pinecone_index, preload_model
from retriever import retrieve_chunks
from llm import query_gemini
import uvicorn

# Set up cache directory for HuggingFace models
cache_dir = os.path.join(os.getcwd(), ".cache")
os.makedirs(cache_dir, exist_ok=True)
os.environ['HF_HOME'] = cache_dir
os.environ['TRANSFORMERS_CACHE'] = cache_dir

# Suppress TensorFlow warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_LOGGING_LEVEL'] = 'ERROR'
os.environ['TF_ENABLE_DEPRECATION_WARNINGS'] = '0'
warnings.filterwarnings('ignore', category=DeprecationWarning, module='tensorflow')
logging.getLogger('tensorflow').setLevel(logging.ERROR)
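# NOTE: environment variables such as TF_CPP_MIN_LOG_LEVEL and HF_HOME only
# take effect if set before TensorFlow / transformers are first imported; the
# imports above (e.g. embedder) may already have pulled those libraries in,
# so placing this block at the very top of the file is likely safer.
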
app = FastAPI(title="HackRx Insurance Policy Assistant", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
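# NOTE: wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins to known frontends in production.
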
# Preload the model at startup
@app.on_event("startup")
async def startup_event():
    print("Starting up HackRx Insurance Policy Assistant...")
    print("Preloading sentence transformer model...")
    preload_model()
    print("Model preloading completed. API is ready to serve requests!")

@app.get("/")
async def root():
    return {"message": "HackRx Insurance Policy Assistant API is running!"}

@app.get("/health")  # conventional health-check path; adjust if your deployment differs
async def health_check():
    return {"status": "healthy", "message": "API is ready to process requests"}

class QueryRequest(BaseModel):
    documents: str
    questions: list[str]

class LocalQueryRequest(BaseModel):
    document_path: str
    questions: list[str]
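# Example payloads (URL and path below are illustrative):
#   QueryRequest:      {"documents": "https://example.com/policy.pdf", "questions": ["..."]}
#   LocalQueryRequest: {"document_path": "./policy.pdf", "questions": ["..."]}
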
def verify_token(authorization: str = Header(None)):
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization header")
    token = authorization.replace("Bearer ", "")
    # For demo purposes, accept any token. In production, validate against a database
    if not token:
        raise HTTPException(status_code=401, detail="Invalid token")
    return token
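# Example request (hypothetical token; any non-empty Bearer token passes the
# demo check above):
#   curl -X POST http://localhost:7860/hackrx/run \
#     -H "Authorization: Bearer <token>" \
#     -H "Content-Type: application/json" \
#     -d '{"documents": "https://example.com/policy.pdf", "questions": ["What is the waiting period?"]}'
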
@app.post("/hackrx/run")  # HackRx webhook convention; adjust the path if your deployment differs
async def run_query(request: QueryRequest, token: str = Depends(verify_token)):
    start_time = time.time()
    timing_data = {}
    try:
        print("\n=== INPUT JSON ===")
        print(f"Documents: {request.documents}")
        print(f"Questions: {request.questions}")
        print("==================\n")
        print(f"Processing {len(request.questions)} questions...")
        # Time PDF parsing
        pdf_start = time.time()
        text_chunks = parse_pdf_from_url(request.documents)
        pdf_time = time.time() - pdf_start
        timing_data['pdf_parsing'] = round(pdf_time, 2)
        print(f"Extracted {len(text_chunks)} text chunks from PDF")

        # Time Pinecone index building/upsert
        index_start = time.time()
        pinecone_index = build_pinecone_index(text_chunks)
        index_time = time.time() - index_start
        timing_data['pinecone_index_building'] = round(index_time, 2)
        texts = text_chunks  # for retrieve_chunks
        # Time chunk retrieval for all questions; deduplicate chunks across questions
        retrieval_start = time.time()
        all_chunks = set()
        for question in request.questions:
            top_chunks = retrieve_chunks(pinecone_index, texts, question)
            all_chunks.update(top_chunks)
        retrieval_time = time.time() - retrieval_start
        timing_data['chunk_retrieval'] = round(retrieval_time, 2)
        print(f"Retrieved {len(all_chunks)} unique chunks")
        # Time LLM processing
        llm_start = time.time()
        print(f"Processing all {len(request.questions)} questions in batch...")
        response = query_gemini(request.questions, list(all_chunks))
        llm_time = time.time() - llm_start
        timing_data['llm_processing'] = round(llm_time, 2)
        # Time response processing
        response_start = time.time()
        # Extract answers from the JSON response; pad with "Not Found" and
        # truncate so the answer count always matches the question count
        if isinstance(response, dict) and "answers" in response:
            answers = response["answers"]
        else:
            answers = [response] if isinstance(response, str) else []
        while len(answers) < len(request.questions):
            answers.append("Not Found")
        answers = answers[:len(request.questions)]
        response_time = time.time() - response_start
        timing_data['response_processing'] = round(response_time, 2)
        print(f"Generated {len(answers)} answers")
        # Calculate total time
        total_time = time.time() - start_time
        timing_data['total_time'] = round(total_time, 2)

        print("\n=== TIMING BREAKDOWN ===")
        print(f"PDF Parsing: {timing_data['pdf_parsing']}s")
        print(f"Pinecone Index Building: {timing_data['pinecone_index_building']}s")
        print(f"Chunk Retrieval: {timing_data['chunk_retrieval']}s")
        print(f"LLM Processing: {timing_data['llm_processing']}s")
        print(f"Response Processing: {timing_data['response_processing']}s")
        print(f"TOTAL TIME: {timing_data['total_time']}s")
        print("=======================\n")

        result = {"answers": answers}
        print("=== OUTPUT JSON ===")
        print(result)
        print("==================\n")
        return result
    except Exception as e:
        total_time = time.time() - start_time
        print(f"Error after {total_time:.2f} seconds: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.post("/hackrx/run-local")  # path assumed for local testing; note this route has no auth dependency
async def run_local_query(request: LocalQueryRequest):
    start_time = time.time()
    timing_data = {}
    try:
        print("\n=== INPUT JSON ===")
        print(f"Document Path: {request.document_path}")
        print(f"Questions: {request.questions}")
        print("==================\n")
        print(f"Processing local document: {request.document_path}")
        print(f"Processing {len(request.questions)} questions...")
        # Time local PDF parsing
        pdf_start = time.time()
        text_chunks = parse_pdf_from_file(request.document_path)
        pdf_time = time.time() - pdf_start
        timing_data['pdf_parsing'] = round(pdf_time, 2)
        print(f"Extracted {len(text_chunks)} text chunks from local PDF")

        # Time Pinecone index building/upsert
        index_start = time.time()
        pinecone_index = build_pinecone_index(text_chunks)
        index_time = time.time() - index_start
        timing_data['pinecone_index_building'] = round(index_time, 2)
        texts = text_chunks  # for retrieve_chunks
        # Time chunk retrieval for all questions; deduplicate chunks across questions
        retrieval_start = time.time()
        all_chunks = set()
        for question in request.questions:
            top_chunks = retrieve_chunks(pinecone_index, texts, question)
            all_chunks.update(top_chunks)
        retrieval_time = time.time() - retrieval_start
        timing_data['chunk_retrieval'] = round(retrieval_time, 2)
        print(f"Retrieved {len(all_chunks)} unique chunks")
        # Time LLM processing
        llm_start = time.time()
        print(f"Processing all {len(request.questions)} questions in batch...")
        response = query_gemini(request.questions, list(all_chunks))
        llm_time = time.time() - llm_start
        timing_data['llm_processing'] = round(llm_time, 2)
        # Time response processing
        response_start = time.time()
        # Extract answers from the JSON response; pad with "Not Found" and
        # truncate so the answer count always matches the question count
        if isinstance(response, dict) and "answers" in response:
            answers = response["answers"]
        else:
            answers = [response] if isinstance(response, str) else []
        while len(answers) < len(request.questions):
            answers.append("Not Found")
        answers = answers[:len(request.questions)]
        response_time = time.time() - response_start
        timing_data['response_processing'] = round(response_time, 2)
        print(f"Generated {len(answers)} answers")
        total_time = time.time() - start_time
        timing_data['total_time'] = round(total_time, 2)

        print("\n=== TIMING BREAKDOWN ===")
        print(f"PDF Parsing: {timing_data['pdf_parsing']}s")
        print(f"Pinecone Index Building: {timing_data['pinecone_index_building']}s")
        print(f"Chunk Retrieval: {timing_data['chunk_retrieval']}s")
        print(f"LLM Processing: {timing_data['llm_processing']}s")
        print(f"Response Processing: {timing_data['response_processing']}s")
        print(f"TOTAL TIME: {timing_data['total_time']}s")
        print("=======================\n")

        result = {"answers": answers}
        print("=== OUTPUT JSON ===")
        print(result)
        print("==================\n")
        return result
    except Exception as e:
        total_time = time.time() - start_time
        print(f"Error after {total_time:.2f} seconds: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
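# Example local request (hypothetical document path on the machine running the API):
#   curl -X POST http://localhost:7860/hackrx/run-local \
#     -H "Content-Type: application/json" \
#     -d '{"document_path": "./policy.pdf", "questions": ["What is covered?"]}'
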
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)
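# Run directly (python app.py) or with the uvicorn CLI:
#   uvicorn app:app --host 0.0.0.0 --port 7860
# PORT defaults to 7860, the port Hugging Face Spaces expects.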