# Spaces: Paused  (page-header residue from extraction; not part of the program)
"""
FastAPI server for OpenAI Realtime API integration with the RAG system.

Provides endpoints for session management and RAG tool calls.

Directory structure:
    /data/        # Original PDFs, HTML
    /embeddings/  # FAISS, Chroma, DPR vector stores
    /graph/       # Graph database files
    /metadata/    # Image metadata (SQLite or MongoDB)
"""
import html
import json
import logging
import os
import time
from typing import Dict, Any, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from openai import OpenAI
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException as StarletteHTTPException

# Import all query modules
from query_graph import query as graph_query
from query_vanilla import query as vanilla_query
from query_dpr import query as dpr_query
from query_bm25 import query as bm25_query
from query_context import query as context_query
from query_vision import query as vision_query
from config import OPENAI_API_KEY, OPENAI_CHAT_MODEL, OPENAI_REALTIME_MODEL, REALTIME_VOICE, REALTIME_INSTRUCTIONS, DEFAULT_METHOD
from analytics_db import log_query
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="SIGHT Realtime API Server", version="1.0.0")

# CORS middleware for frontend integration.
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable and default to "*" (any origin). In production, set it
# to your frontend's origin(s), e.g. "https://app.example.com".
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentials: it would let
    # any website make credentialed cross-origin calls. Credentials are
    # therefore only enabled once the origins are listed explicitly.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log every incoming request and the status of its response.

    Registered as HTTP middleware so that it actually runs for each request.

    Args:
        request: The incoming request.
        call_next: Callable that passes the request on to the rest of the
            application and returns its response.

    Returns:
        The downstream response, or a generic 500 JSON response if the
        downstream processing raised.
    """
    logger.info("Incoming request: %s %s", request.method, request.url)
    try:
        response = await call_next(request)
    except Exception as exc:
        # logger.exception keeps the traceback, which the plain error message
        # alone would lose.
        logger.exception("Request processing error: %s", exc)
        return JSONResponse(
            content={"error": "Internal server error"},
            status_code=500,
        )
    logger.info("Response status: %s", response.status_code)
    return response
# Exception handlers
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn a request validation failure into a 422 JSON error response.

    Args:
        request: The request whose payload failed validation.
        exc: The validation error raised for that request.

    Returns:
        A JSONResponse with status 422 whose body holds a fixed "error"
        message and the stringified exception under "details".
    """
    logger.warning("Validation error for %s: %s", request.url, exc)
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={"error": "Invalid request format", "details": str(exc)},
    )
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render an HTTP exception as a JSON body of the form {"error": detail}.

    Args:
        request: The request during which the exception was raised.
        exc: The HTTP exception carrying the status code and detail.

    Returns:
        A JSONResponse with the exception's status code. Any headers attached
        to the exception (for example WWW-Authenticate) are forwarded instead
        of being silently dropped.
    """
    logger.warning("HTTP error for %s: %s - %s", request.url, exc.status_code, exc.detail)
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail},
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the failure and answer with a generic 500.

    Args:
        request: The request during which the exception was raised.
        exc: The unhandled exception.

    Returns:
        A JSONResponse with status 500 and a generic error message; the
        exception text is deliberately not sent to the client.
    """
    # exc_info attaches the traceback so the failure can be diagnosed from logs.
    logger.error("Unhandled error for %s: %s", request.url, exc, exc_info=exc)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "Internal server error"},
    )
# Shared OpenAI client, authenticated with the configured API key
client = OpenAI(api_key=OPENAI_API_KEY)

# Dispatch table: RAG method name -> query function that implements it
QUERY_DISPATCH = {
    "graph": graph_query,
    "vanilla": vanilla_query,
    "dpr": dpr_query,
    "bm25": bm25_query,
    "context": context_query,
    "vision": vision_query,
}

# Realtime settings: an environment variable, when set, takes precedence over
# the corresponding value from config.py
REALTIME_MODEL = os.environ.get("REALTIME_MODEL", OPENAI_REALTIME_MODEL)
VOICE = os.environ.get("REALTIME_VOICE", REALTIME_VOICE)
INSTRUCTIONS = os.environ.get("REALTIME_INSTRUCTIONS", REALTIME_INSTRUCTIONS)
# Pydantic models for request/response
class SessionRequest(BaseModel):
    """Request model for creating ephemeral sessions.

    Every field is optional. A field that is omitted (or null) is replaced by
    the server-side configured value when the session is created.
    """

    # Defaults to None rather than a hard-coded model name: a non-empty
    # default would always win in `request.model or REALTIME_MODEL` and the
    # configured / environment-provided model would never be used.
    model: Optional[str] = None
    instructions: Optional[str] = None
    voice: Optional[str] = None
class RAGRequest(BaseModel):
    """Payload of a RAG query sent by the realtime frontend."""

    # The user's question
    query: str
    # Name of the retrieval method to run
    method: str = "graph"
    # Forwarded to the query function as `top_k`
    top_k: int = 5
    # Optional image path forwarded to the query function
    image_path: Optional[str] = None
class RAGResponse(BaseModel):
    """Result of a RAG query as returned to the caller."""

    # Generated answer text
    answer: str
    # Citations as returned by the query function
    citations: list
    # Retrieval method that actually produced the answer
    method: str
    # Optional pre-rendered HTML version of the citations
    citations_html: Optional[str] = None
@app.post("/session")
async def create_ephemeral_session(request: SessionRequest) -> JSONResponse:
    """
    Create an ephemeral session token for OpenAI Realtime API.

    This token will be used by the frontend WebRTC client.

    Args:
        request: Optional overrides for model, voice and instructions; any
            value that is missing falls back to the server configuration.

    Returns:
        On success, a 200 JSONResponse with "client_secret", "model" and
        "session_id". If OpenAI answers with a non-200 status, a JSONResponse
        carrying that status and an "error" message. On a network or other
        failure, a 500 JSONResponse with an "error" message.
    """
    # Imported before the try block: if this import failed inside it, the
    # `except requests.exceptions.RequestException` clause would itself raise
    # UnboundLocalError while being evaluated and hide the real problem.
    import requests

    model = request.model or REALTIME_MODEL
    session_data = {
        "model": model,
        "voice": request.voice or VOICE,
        "modalities": ["audio", "text"],
        "instructions": request.instructions or INSTRUCTIONS,
    }
    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json",
    }

    try:
        logger.info("Creating ephemeral session with model: %s", model)

        # Make a direct HTTP request to OpenAI's realtime sessions endpoint.
        # `requests` is synchronous, so the call runs in a worker thread to
        # avoid blocking the event loop for up to the 30 s timeout.
        response = await run_in_threadpool(
            requests.post,
            "https://api.openai.com/v1/realtime/sessions",
            json=session_data,
            headers=headers,
            timeout=30,
        )

        if response.status_code != 200:
            logger.error(f"OpenAI API error: {response.status_code} - {response.text}")
            return JSONResponse(
                content={"error": f"OpenAI API error: {response.status_code} - {response.text}"},
                status_code=response.status_code,
            )

        session_result = response.json()

        # client_secret is normally an object holding the token under "value";
        # a plain (non-dict) value is passed through unchanged instead of
        # crashing on `.get`.
        client_secret = session_result.get("client_secret")
        if isinstance(client_secret, dict):
            client_secret = client_secret.get("value") or client_secret

        response_data = {
            "client_secret": client_secret,
            "model": model,
            "session_id": session_result.get("id"),
        }
        logger.info("Ephemeral session created successfully")
        return JSONResponse(content=response_data, status_code=200)

    except requests.exceptions.RequestException as e:
        logger.exception("Network error creating ephemeral session: %s", e)
        return JSONResponse(
            content={"error": f"Network error: {str(e)}"},
            status_code=500,
        )
    except Exception as e:
        logger.exception("Error creating ephemeral session: %s", e)
        return JSONResponse(
            content={"error": f"Session creation failed: {str(e)}"},
            status_code=500,
        )
@app.post("/rag", response_model=RAGResponse)
async def rag_query(request: RAGRequest) -> RAGResponse:
    """
    Handle RAG queries from the realtime interface.

    This endpoint is called by the JavaScript frontend when the model
    requests the ask_rag function.

    Args:
        request: The query text, retrieval method, top_k and optional image
            path. An unknown method is replaced by DEFAULT_METHOD.

    Returns:
        The answer, its citations, the method actually used and an HTML
        rendering of the citations.

    Raises:
        HTTPException: With status 500 if the query or response building
            fails. A failure to write the analytics record is only logged.
    """
    import uuid

    try:
        logger.info(f"RAG query: {request.query} using method: {request.method}")

        # Validate and default method if needed
        method = request.method
        if method not in QUERY_DISPATCH:
            logger.warning(f"Invalid method '{method}', using default '{DEFAULT_METHOD}'")
            method = DEFAULT_METHOD

        # Get the appropriate query function
        query_func = QUERY_DISPATCH[method]

        # Execute the query. perf_counter is monotonic, so the measured
        # duration cannot be distorted by system clock adjustments.
        # NOTE(review): query_func is called synchronously on the event loop;
        # moving it to a worker thread needs confirmation that the query
        # modules are thread-safe.
        start_time = time.perf_counter()
        answer, citations = query_func(
            question=request.query,
            image_path=request.image_path,
            top_k=request.top_k,
        )
        response_time = (time.perf_counter() - start_time) * 1000  # Convert to ms

        # Format citations for HTML display (optional)
        citations_html = format_citations_html(citations, method)

        # Log to analytics database (mark as voice interaction). Best-effort:
        # an analytics failure must not fail the user's query.
        try:
            # Generate unique session ID for each voice interaction
            voice_session_id = f"voice_{uuid.uuid4().hex[:8]}"
            log_query(
                user_query=request.query,
                method=method,
                answer=answer,
                citations=citations,
                response_time=response_time,
                image_path=request.image_path,
                top_k=request.top_k,
                session_id=voice_session_id,
                additional_settings={'voice_interaction': True, 'interaction_type': 'speech_to_speech'},
            )
            logger.info(f"Voice interaction logged: {request.query[:50]}...")
        except Exception as log_error:
            logger.error(f"Failed to log voice query: {log_error}", exc_info=log_error)

        logger.info(f"RAG query completed: {len(answer)} chars, {len(citations)} citations")
        return RAGResponse(
            answer=answer,
            citations=citations,
            method=method,
            citations_html=citations_html,
        )
    except Exception as e:
        logger.exception("Error processing RAG query: %s", e)
        raise HTTPException(status_code=500, detail=f"RAG query failed: {str(e)}") from e
def _format_score(value) -> Optional[str]:
    """Return value formatted to three decimals, or None if it is not numeric."""
    try:
        return f"{float(value):.3f}"
    except (TypeError, ValueError):
        return None


def format_citations_html(citations: list, method: str) -> str:
    """Format citations as HTML for display.

    Two citation shapes are rendered:
      * dicts containing a 'source' key, with optional 'type', 'url', 'page',
        'relevance_score' and 'score' entries;
      * legacy sequences of at least four items laid out as
        (header, score, text, source).
    Anything else is skipped.

    All citation-provided text is HTML-escaped before it is embedded, so
    markup inside a source name, URL or page value cannot inject HTML. Scores
    that are not numeric are omitted instead of raising.

    Args:
        citations: Citations to render; empty or None yields a placeholder.
        method: Retrieval method name (currently not used for rendering).

    Returns:
        An HTML fragment listing the sources, or a "no citations" paragraph.
    """
    if not citations:
        return "<p><em>No citations available</em></p>"

    html_parts = ["<div style='margin-top: 1em;'><strong>Sources:</strong><ul>"]

    for citation in citations:
        if isinstance(citation, dict) and 'source' in citation:
            source = html.escape(str(citation['source']))
            cite_type = citation.get('type', 'unknown')

            # Build citation text based on type
            if cite_type == 'pdf':
                cite_text = f"π {source} (PDF)"
            elif cite_type == 'html':
                url = citation.get('url', '')
                if url:
                    # quote=True also escapes quotes so the value cannot break
                    # out of the single-quoted href attribute.
                    href = html.escape(str(url), quote=True)
                    cite_text = f"π <a href='{href}' target='_blank'>{source}</a> (Web)"
                else:
                    cite_text = f"π {source} (Web)"
            elif cite_type == 'image':
                page = html.escape(str(citation.get('page', 'N/A')))
                cite_text = f"πΌοΈ {source} (Image, page {page})"
            else:
                cite_text = f"π {source}"

            # Add scores if available (and numeric)
            scores = []
            for key, label in (('relevance_score', 'relevance'), ('score', 'score')):
                if key in citation:
                    formatted = _format_score(citation[key])
                    if formatted is not None:
                        scores.append(f"{label}: {formatted}")
            if scores:
                cite_text += f" <small>({', '.join(scores)})</small>"

            html_parts.append(f"<li>{cite_text}</li>")

        elif isinstance(citation, (list, tuple)) and len(citation) >= 4:
            # Handle legacy citation format (header, score, text, source);
            # only the score and source are displayed.
            _header, score, _text, source = citation[:4]
            cite_text = f"π {html.escape(str(source))}"
            formatted = _format_score(score)
            if formatted is not None:
                cite_text += f" <small>(score: {formatted})</small>"
            html_parts.append(f"<li>{cite_text}</li>")

    html_parts.append("</ul></div>")
    return "".join(html_parts)
@app.get("/")
async def root():
    """Root endpoint to prevent invalid HTTP request warnings.

    Returns:
        A dict with the service name, version, a fixed "running" status and a
        short description of each endpoint.
    """
    return {
        "service": "SIGHT Realtime API Server",
        "version": "1.0.0",
        "status": "running",
        "endpoints": {
            "session": "POST /session - Create realtime session",
            "rag": "POST /rag - Query RAG system",
            "health": "GET /health - Health check",
            "methods": "GET /methods - List available RAG methods",
        },
    }
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        A static dict reporting the service as healthy; no dependencies are
        probed.
    """
    return {"status": "healthy", "service": "SIGHT Realtime API Server"}
@app.get("/methods")
async def list_methods():
    """List available RAG methods.

    Returns:
        A dict with "methods" (the keys of QUERY_DISPATCH) and "descriptions"
        (a hard-coded description per method name).
    """
    return {
        "methods": list(QUERY_DISPATCH.keys()),
        "descriptions": {
            'graph': "Graph-based RAG using NetworkX with relationship-aware retrieval",
            'vanilla': "Standard vector search with FAISS and OpenAI embeddings",
            'dpr': "Dense Passage Retrieval with bi-encoder and cross-encoder re-ranking",
            'bm25': "BM25 keyword search with neural re-ranking for exact term matching",
            'context': "Context stuffing with full document loading and heuristic selection",
            'vision': "Vision-based search using GPT-5 Vision for image analysis",
        },
    }
async def options_handler(request: Request, response: Response):
    """Attach permissive CORS headers to the given response and return it.

    Args:
        request: The incoming request (not inspected).
        response: The response object whose headers are set.

    Returns:
        The same response object, with the three CORS headers assigned.
    """
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    for header_name, header_value in cors_headers.items():
        response.headers[header_name] = header_value
    return response
if __name__ == "__main__":
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="SIGHT Realtime API Server")
    parser.add_argument("--https", action="store_true", help="Enable HTTPS with self-signed certificate")
    parser.add_argument("--port", type=int, default=5050, help="Port to run the server on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind the server to")
    parser.add_argument("--ssl-keyfile", default="key.pem", help="Path to the TLS private key used with --https")
    parser.add_argument("--ssl-certfile", default="cert.pem", help="Path to the TLS certificate used with --https")
    # Auto-reload stays on by default (as before) but can now be switched off,
    # which is what a production deployment wants.
    parser.add_argument("--no-reload", dest="reload", action="store_false", help="Disable uvicorn auto-reload")
    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Suppress uvicorn access logs for cleaner output
    uvicorn_logger = logging.getLogger("uvicorn.access")
    uvicorn_logger.setLevel(logging.WARNING)

    # Prepare uvicorn configuration.
    # The app is passed as an import string, which uvicorn requires for reload;
    # NOTE(review): this presumes the module is named realtime_server - confirm.
    uvicorn_config = {
        "app": "realtime_server:app",
        "host": args.host,
        "port": args.port,
        "reload": args.reload,
        "log_level": "warning",
        "access_log": False,
    }

    # Add SSL configuration if HTTPS is requested
    if args.https:
        # Fail early with an actionable message rather than letting uvicorn
        # crash on a missing key or certificate file.
        missing = [path for path in (args.ssl_keyfile, args.ssl_certfile) if not os.path.isfile(path)]
        if missing:
            parser.error(
                f"--https requires certificate files; not found: {', '.join(missing)}\n"
                "To generate self-signed certificates, run:\n"
                "  openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes"
            )

        logger.info("Starting server with HTTPS (self-signed certificate)")
        logger.warning("β οΈ Self-signed certificate will show security warnings in browser")
        logger.info("For production, use a proper SSL certificate from a CA")

        uvicorn_config.update({
            "ssl_keyfile": args.ssl_keyfile,
            "ssl_certfile": args.ssl_certfile,
        })
        print(f"π Starting HTTPS server on https://{args.host}:{args.port}")
        print("π To generate self-signed certificates, run:")
        print(" openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes")
    else:
        print(f"π Starting HTTP server on http://{args.host}:{args.port}")
        print("β οΈ HTTP only works for localhost. Use --https for production deployment.")

    # Run the server
    uvicorn.run(**uvicorn_config)