# rag-hackrx/api/routes.py
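"""FastAPI routes for the rag-hackrx service: parse a document from a URL or a
local path, build a FAISS index over its chunks, retrieve relevant context per
question, and answer the questions in batches via Gemini or OpenAI."""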
from fastapi import APIRouter, HTTPException, Query, Request, BackgroundTasks
from pydantic import BaseModel
from services.ip_utils import get_client_ip
from services.db_logger import log_query
from services.embedder import build_faiss_index
from services.retriever import retrieve_chunks
from services.llm_service import query_gemini, query_openai
from Extraction_Models import parse_document_url, parse_document_file
from threading import Lock
import hashlib, time
from concurrent.futures import ThreadPoolExecutor
router = APIRouter()
class QueryRequest(BaseModel):
    url: str
    questions: list[str]

class LocalQueryRequest(BaseModel):
    document_path: str
    questions: list[str]
def get_document_id(url: str) -> str:
    # MD5 of the URL serves as a stable cache key, not as a security measure.
    return hashlib.md5(url.encode()).hexdigest()
# In-memory cache of parsed chunks and FAISS indexes, keyed by doc_id.
doc_cache = {}
doc_cache_lock = Lock()
@router.delete("/cache/clear")
async def clear_cache(
    doc_id: str = Query(None),
    url: str = Query(None),
    doc_only: bool = Query(False),  # currently unused
):
    cleared = {}
    if url:
        doc_id = get_document_id(url)
    if doc_id:
        with doc_cache_lock:
            if doc_id in doc_cache:
                del doc_cache[doc_id]
                cleared["doc_cache"] = f"Cleared document {doc_id}"
    else:
        with doc_cache_lock:
            doc_cache.clear()
            cleared["doc_cache"] = "Cleared ALL documents"
    return {"status": "success", "cleared": cleared}
def print_timings(timings: dict):
    print("\n=== TIMINGS ===")
    for k, v in timings.items():
        if isinstance(v, float):
            print(f"[TIMER] {k}: {v:.4f}s")
        elif isinstance(v, list):
            print(f"[TIMER] {k}: {', '.join(f'{x:.4f}s' for x in v)}")
        else:
            print(f"[TIMER] {k}: {v}")
    print("================\n")
@router.post("/hackrx/run")
async def run_query(request: QueryRequest, fastapi_request: Request, background_tasks: BackgroundTasks):
    timings = {}
    try:
        user_ip = get_client_ip(fastapi_request)
        user_agent = fastapi_request.headers.get("user-agent", "Unknown")
        doc_id = get_document_id(request.url)
        print("Input:", request.url, request.questions)
        # Parsing (skipped entirely on a cache hit)
        t_parse_start = time.time()
        with doc_cache_lock:
            if doc_id in doc_cache:
                cached = doc_cache[doc_id]
                text_chunks, index, texts = cached["chunks"], cached["index"], cached["texts"]
                timings["parse_time"] = 0
                timings["index_time"] = 0
            else:
                text_chunks = parse_document_url(request.url)
                t_parse_end = time.time()
                timings["parse_time"] = t_parse_end - t_parse_start
                # Indexing
                t_index_start = time.time()
                index, texts = build_faiss_index(text_chunks)
                t_index_end = time.time()
                timings["index_time"] = t_index_end - t_index_start
                doc_cache[doc_id] = {"chunks": text_chunks, "index": index, "texts": texts}
        timings["cache_check_time"] = time.time() - t_parse_start
        # Retrieval: union of top chunks across all questions
        t_retrieve_start = time.time()
        all_chunks = set()
        for question in request.questions:
            all_chunks.update(retrieve_chunks(index, texts, question))
        context_chunks = list(all_chunks)
        timings["retrieval_time"] = time.time() - t_retrieve_start
        # LLM query: batches of questions answered in parallel worker threads
        t_llm_start = time.time()
        batch_size = 10
        results_dict = {}
        llm_batch_timings = []
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = []
            for i in range(0, len(request.questions), batch_size):
                batch = request.questions[i:i + batch_size]
                futures.append(executor.submit(query_openai, batch, context_chunks))
            for i, future in enumerate(futures):
                t_batch_start = time.time()
                result = future.result()
                t_batch_end = time.time()
                llm_batch_timings.append(t_batch_end - t_batch_start)
                if "answers" in result:
                    for j, ans in enumerate(result["answers"]):
                        results_dict[i * batch_size + j] = ans
        timings["llm_time"] = time.time() - t_llm_start
        timings["llm_batch_times"] = llm_batch_timings
        responses = [results_dict.get(i, "Not Found") for i in range(len(request.questions))]
        # Logging (runs after the response is sent)
        total_float_time = sum(v for v in timings.values() if isinstance(v, (int, float)))
        for q, a in zip(request.questions, responses):
            background_tasks.add_task(log_query, request.url, q, a, user_ip, total_float_time, user_agent)
        # Print timings in console
        print_timings(timings)
        # Return ONLY answers
        print("answers:", responses)
        return {"answers": responses}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
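# Example request (hypothetical document and question):
#   POST /hackrx/run
#   {"url": "https://example.com/policy.pdf",
#    "questions": ["What is the grace period for premium payment?"]}
# Response shape: {"answers": ["..."]}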
@router.post("/hackrx/local")
async def run_local_query(request: LocalQueryRequest, fastapi_request: Request, background_tasks: BackgroundTasks):
    timings = {}
    try:
        user_ip = get_client_ip(fastapi_request)
        user_agent = fastapi_request.headers.get("user-agent", "Unknown")
        # Parsing
        t_parse_start = time.time()
        text_chunks = parse_document_file(request.document_path)
        t_parse_end = time.time()
        timings["parse_time"] = t_parse_end - t_parse_start
        # Indexing
        t_index_start = time.time()
        index, texts = build_faiss_index(text_chunks)
        t_index_end = time.time()
        timings["index_time"] = t_index_end - t_index_start
        # Retrieval
        t_retrieve_start = time.time()
        all_chunks = set()
        for question in request.questions:
            all_chunks.update(retrieve_chunks(index, texts, question))
        context_chunks = list(all_chunks)
        timings["retrieval_time"] = time.time() - t_retrieve_start
        # LLM query (Gemini backend)
        t_llm_start = time.time()
        batch_size = 20
        results_dict = {}
        llm_batch_timings = []
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = []
            for i in range(0, len(request.questions), batch_size):
                batch = request.questions[i:i + batch_size]
                futures.append(executor.submit(query_gemini, batch, context_chunks))
            for i, future in enumerate(futures):
                t_batch_start = time.time()
                result = future.result()
                t_batch_end = time.time()
                llm_batch_timings.append(t_batch_end - t_batch_start)
                if "answers" in result:
                    for j, ans in enumerate(result["answers"]):
                        results_dict[i * batch_size + j] = ans
        timings["llm_time"] = time.time() - t_llm_start
        timings["llm_batch_times"] = llm_batch_timings
        responses = [results_dict.get(i, "Not Found") for i in range(len(request.questions))]
        # Logging
        total_float_time = sum(v for v in timings.values() if isinstance(v, (int, float)))
        for q, a in zip(request.questions, responses):
            background_tasks.add_task(log_query, request.document_path, q, a, user_ip, total_float_time, user_agent)
        # Print timings in console
        print_timings(timings)
        # Return ONLY answers
        return {"answers": responses}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
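# Example request (hypothetical local path):
#   POST /hackrx/local
#   {"document_path": "docs/policy.pdf", "questions": ["What is covered?"]}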
@router.post("/hackrx/run_openai")
async def run_query_openai(request: QueryRequest, fastapi_request: Request, background_tasks: BackgroundTasks):
    timings = {}
    try:
        user_ip = get_client_ip(fastapi_request)
        user_agent = fastapi_request.headers.get("user-agent", "Unknown")
        doc_id = get_document_id(request.url)
        # Parsing (skipped entirely on a cache hit)
        t_parse_start = time.time()
        with doc_cache_lock:
            if doc_id in doc_cache:
                cached = doc_cache[doc_id]
                text_chunks, index, texts = cached["chunks"], cached["index"], cached["texts"]
                timings["parse_time"] = 0
                timings["index_time"] = 0
            else:
                text_chunks = parse_document_url(request.url)
                t_parse_end = time.time()
                timings["parse_time"] = t_parse_end - t_parse_start
                # Indexing
                t_index_start = time.time()
                index, texts = build_faiss_index(text_chunks)
                t_index_end = time.time()
                timings["index_time"] = t_index_end - t_index_start
                doc_cache[doc_id] = {"chunks": text_chunks, "index": index, "texts": texts}
        timings["cache_check_time"] = time.time() - t_parse_start
        # Retrieval
        t_retrieve_start = time.time()
        all_chunks = set()
        for question in request.questions:
            all_chunks.update(retrieve_chunks(index, texts, question))
        context_chunks = list(all_chunks)
        timings["retrieval_time"] = time.time() - t_retrieve_start
        # OpenAI LLM query
        t_llm_start = time.time()
        batch_size = 10
        results_dict = {}
        llm_batch_timings = []
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = []
            for i in range(0, len(request.questions), batch_size):
                batch = request.questions[i:i + batch_size]
                futures.append(executor.submit(query_openai, batch, context_chunks))
            for i, future in enumerate(futures):
                t_batch_start = time.time()
                result = future.result()
                t_batch_end = time.time()
                llm_batch_timings.append(t_batch_end - t_batch_start)
                if "answers" in result:
                    for j, ans in enumerate(result["answers"]):
                        results_dict[i * batch_size + j] = ans
        timings["llm_time"] = time.time() - t_llm_start
        timings["llm_batch_times"] = llm_batch_timings
        responses = [results_dict.get(i, "Not Found") for i in range(len(request.questions))]
        # Logging
        total_float_time = sum(v for v in timings.values() if isinstance(v, (int, float)))
        for q, a in zip(request.questions, responses):
            background_tasks.add_task(log_query, request.url, q, a, user_ip, total_float_time, user_agent)
        # Print timings in console
        print_timings(timings)
        return {"answers": responses}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
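# Minimal sketch of mounting this router (assumes the standard FastAPI
# pattern; the actual app entry point in this repo may differ):
#
#   from fastapi import FastAPI
#   from api.routes import router
#
#   app = FastAPI()
#   app.include_router(router, prefix="/api/v1")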