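"""Query router helpers for document Q&A.

Parses a document (remote URL or local path), builds a FAISS index over its
chunks, retrieves relevant context per question, and answers batches of
questions with an LLM (OpenAI or Gemini), logging each query in the background.
"""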
from fastapi import APIRouter, HTTPException, Query, Request, BackgroundTasks
from pydantic import BaseModel
from services.ip_utils import get_client_ip
from services.db_logger import log_query
from services.embedder import build_faiss_index
from services.retriever import retrieve_chunks
from services.llm_service import query_gemini, query_openai
from Extraction_Models import parse_document_url, parse_document_file
from threading import Lock
from concurrent.futures import ThreadPoolExecutor
import hashlib
import time

router = APIRouter()


class QueryRequest(BaseModel):
    url: str
    questions: list[str]


class LocalQueryRequest(BaseModel):
    document_path: str
    questions: list[str]
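
# Illustrative request payloads (values are placeholders, not taken from this codebase):
#   QueryRequest:      {"url": "https://example.com/policy.pdf", "questions": ["What is covered?"]}
#   LocalQueryRequest: {"document_path": "docs/policy.pdf", "questions": ["What is covered?"]}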


def get_document_id(url: str):
    return hashlib.md5(url.encode()).hexdigest()
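

# Process-local, in-memory cache of parsed chunks and FAISS indexes, keyed by the
# MD5 hash of the document URL and guarded by a lock; entries live until
# clear_cache() removes them.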
doc_cache = {}
doc_cache_lock = Lock()


async def clear_cache(doc_id: str = Query(None), url: str = Query(None), doc_only: bool = Query(False)):
    cleared = {}
    if url:
        doc_id = get_document_id(url)
    if doc_id:
        with doc_cache_lock:
            if doc_id in doc_cache:
                del doc_cache[doc_id]
                cleared["doc_cache"] = f"Cleared document {doc_id}"
    else:
        with doc_cache_lock:
            doc_cache.clear()
            cleared["doc_cache"] = "Cleared ALL documents"
    return {"status": "success", "cleared": cleared}


def print_timings(timings: dict):
    print("\n=== TIMINGS ===")
    for k, v in timings.items():
        if isinstance(v, float):
            print(f"[TIMER] {k}: {v:.4f}s")
        elif isinstance(v, list):
            print(f"[TIMER] {k}: {', '.join(f'{x:.4f}s' for x in v)}")
        else:
            print(f"[TIMER] {k}: {v}")
    print("================\n")


async def run_query(request: QueryRequest, fastapi_request: Request, background_tasks: BackgroundTasks):
    timings = {}
    try:
        user_ip = get_client_ip(fastapi_request)
        user_agent = fastapi_request.headers.get("user-agent", "Unknown")
        doc_id = get_document_id(request.url)
        print("Input :", request.url, request.questions)
        # Parsing
        t_parse_start = time.time()
        with doc_cache_lock:
            if doc_id in doc_cache:
                cached = doc_cache[doc_id]
                text_chunks, index, texts = cached["chunks"], cached["index"], cached["texts"]
                timings["parse_time"] = 0
                timings["index_time"] = 0
            else:
                text_chunks = parse_document_url(request.url)
                t_parse_end = time.time()
                timings["parse_time"] = t_parse_end - t_parse_start
                # Indexing
                t_index_start = time.time()
                index, texts = build_faiss_index(text_chunks)
                t_index_end = time.time()
                timings["index_time"] = t_index_end - t_index_start
                doc_cache[doc_id] = {"chunks": text_chunks, "index": index, "texts": texts}
        timings["cache_check_time"] = time.time() - t_parse_start
        # Retrieval
        t_retrieve_start = time.time()
        all_chunks = set()
        for question in request.questions:
            all_chunks.update(retrieve_chunks(index, texts, question))
        context_chunks = list(all_chunks)
        timings["retrieval_time"] = time.time() - t_retrieve_start
        # LLM query
        t_llm_start = time.time()
        batch_size = 10
        results_dict = {}
        llm_batch_timings = []
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = []
            for i in range(0, len(request.questions), batch_size):
                batch = request.questions[i:i + batch_size]
                futures.append(executor.submit(query_openai, batch, context_chunks))
            for i, future in enumerate(futures):
                t_batch_start = time.time()
                result = future.result()
                t_batch_end = time.time()
                llm_batch_timings.append(t_batch_end - t_batch_start)
                if "answers" in result:
                    for j, ans in enumerate(result["answers"]):
                        results_dict[i * batch_size + j] = ans
        timings["llm_time"] = time.time() - t_llm_start
        timings["llm_batch_times"] = llm_batch_timings
        responses = [results_dict.get(i, "Not Found") for i in range(len(request.questions))]
        # Logging
        total_float_time = sum(v for v in timings.values() if isinstance(v, (int, float)))
        for q, a in zip(request.questions, responses):
            background_tasks.add_task(log_query, request.url, q, a, user_ip, total_float_time, user_agent)
        # Print timings in console
        print_timings(timings)
        # Return ONLY answers
        print("answers : ", responses)
        return {"answers": responses}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")


async def run_local_query(request: LocalQueryRequest, fastapi_request: Request, background_tasks: BackgroundTasks):
    timings = {}
    try:
        user_ip = get_client_ip(fastapi_request)
        user_agent = fastapi_request.headers.get("user-agent", "Unknown")
        # Parsing
        t_parse_start = time.time()
        text_chunks = parse_document_file(request.document_path)
        t_parse_end = time.time()
        timings["parse_time"] = t_parse_end - t_parse_start
        # Indexing
        t_index_start = time.time()
        index, texts = build_faiss_index(text_chunks)
        t_index_end = time.time()
        timings["index_time"] = t_index_end - t_index_start
        # Retrieval
        t_retrieve_start = time.time()
        all_chunks = set()
        for question in request.questions:
            all_chunks.update(retrieve_chunks(index, texts, question))
        context_chunks = list(all_chunks)
        timings["retrieval_time"] = time.time() - t_retrieve_start
        # LLM query
        t_llm_start = time.time()
        batch_size = 20
        results_dict = {}
        llm_batch_timings = []
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = []
            for i in range(0, len(request.questions), batch_size):
                batch = request.questions[i:i + batch_size]
                futures.append(executor.submit(query_gemini, batch, context_chunks))
            for i, future in enumerate(futures):
                t_batch_start = time.time()
                result = future.result()
                t_batch_end = time.time()
                llm_batch_timings.append(t_batch_end - t_batch_start)
                if "answers" in result:
                    for j, ans in enumerate(result["answers"]):
                        results_dict[i * batch_size + j] = ans
        timings["llm_time"] = time.time() - t_llm_start
        timings["llm_batch_times"] = llm_batch_timings
        responses = [results_dict.get(i, "Not Found") for i in range(len(request.questions))]
        # Logging
        total_float_time = sum(v for v in timings.values() if isinstance(v, (int, float)))
        for q, a in zip(request.questions, responses):
            background_tasks.add_task(log_query, request.document_path, q, a, user_ip, total_float_time, user_agent)
        # Print timings in console
        print_timings(timings)
        # Return ONLY answers
        return {"answers": responses}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")


async def run_query_openai(request: QueryRequest, fastapi_request: Request, background_tasks: BackgroundTasks):
    timings = {}
    try:
        user_ip = get_client_ip(fastapi_request)
        user_agent = fastapi_request.headers.get("user-agent", "Unknown")
        doc_id = get_document_id(request.url)
        # Parsing
        t_parse_start = time.time()
        with doc_cache_lock:
            if doc_id in doc_cache:
                cached = doc_cache[doc_id]
                text_chunks, index, texts = cached["chunks"], cached["index"], cached["texts"]
                timings["parse_time"] = 0
                timings["index_time"] = 0
            else:
                text_chunks = parse_document_url(request.url)
                t_parse_end = time.time()
                timings["parse_time"] = t_parse_end - t_parse_start
                # Indexing
                t_index_start = time.time()
                index, texts = build_faiss_index(text_chunks)
                t_index_end = time.time()
                timings["index_time"] = t_index_end - t_index_start
                doc_cache[doc_id] = {"chunks": text_chunks, "index": index, "texts": texts}
        timings["cache_check_time"] = time.time() - t_parse_start
        # Retrieval
        t_retrieve_start = time.time()
        all_chunks = set()
        for question in request.questions:
            all_chunks.update(retrieve_chunks(index, texts, question))
        context_chunks = list(all_chunks)
        timings["retrieval_time"] = time.time() - t_retrieve_start
        # OpenAI LLM query
        t_llm_start = time.time()
        batch_size = 10
        results_dict = {}
        llm_batch_timings = []
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = []
            for i in range(0, len(request.questions), batch_size):
                batch = request.questions[i:i + batch_size]
                futures.append(executor.submit(query_openai, batch, context_chunks))
            for i, future in enumerate(futures):
                t_batch_start = time.time()
                result = future.result()
                t_batch_end = time.time()
                llm_batch_timings.append(t_batch_end - t_batch_start)
                if "answers" in result:
                    for j, ans in enumerate(result["answers"]):
                        results_dict[i * batch_size + j] = ans
        timings["llm_time"] = time.time() - t_llm_start
        timings["llm_batch_times"] = llm_batch_timings
        responses = [results_dict.get(i, "Not Found") for i in range(len(request.questions))]
        # Logging
        total_float_time = sum(v for v in timings.values() if isinstance(v, (int, float)))
        for q, a in zip(request.questions, responses):
            background_tasks.add_task(log_query, request.url, q, a, user_ip, total_float_time, user_agent)
        # Print timings in console
        print_timings(timings)
        return {"answers": responses}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
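

# None of the coroutines above carry route decorators, so registration is assumed to
# happen elsewhere (e.g. in the application entry point). A minimal sketch of one
# possible wiring -- the paths below are placeholders, not taken from this codebase:
#
#     router.post("/query")(run_query)
#     router.post("/query/local")(run_local_query)
#     router.post("/query/openai")(run_query_openai)
#     router.delete("/cache")(clear_cache)
#
#     app = FastAPI()
#     app.include_router(router)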