# CHATFED_ORCHESTRATOR
# Orchestrates the ChatFed ingest -> retrieve -> generate pipeline and exposes
# it via FastAPI/LangServe endpoints and a Gradio UI (with MCP support).
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Form
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from typing import Optional, Dict, Any
from typing_extensions import TypedDict
from pydantic import BaseModel
from gradio_client import Client, file
import uvicorn
import os
from datetime import datetime
import logging
from contextlib import asynccontextmanager
import threading
from langchain_core.runnables import RunnableLambda
import tempfile
from utils import getconfig
config = getconfig("params.cfg")
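# Expected params.cfg layout (illustrative sketch; only the sections/keys read
# below are assumed, and the MAX_CONTEXT_CHARS value shown is a placeholder):
#   [retriever]
#   RETRIEVER = https://giz-chatfed-retriever.hf.space
#   [generator]
#   GENERATOR = https://giz-chatfed-generator.hf.space
#   [ingestor]
#   INGESTOR = https://mtyrrell-chatfed-ingestor.hf.space
#   [general]
#   MAX_CONTEXT_CHARS = 20000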
RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
GENERATOR = config.get("generator", "GENERATOR", fallback="https://giz-chatfed-generator.hf.space")
INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed-ingestor.hf.space")
# Character budget for the combined context passed to the generator.
MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Models
class GraphState(TypedDict):
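    """Shared pipeline state passed between the LangGraph nodes."""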
query: str
context: str
ingestor_context: str
result: str
reports_filter: str
sources_filter: str
subtype_filter: str
year_filter: str
file_content: Optional[bytes]
filename: Optional[str]
metadata: Optional[Dict[str, Any]]
class ChatFedInput(TypedDict):
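    """Input schema for the LangServe /chatfed endpoint."""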
query: str
reports_filter: Optional[str]
sources_filter: Optional[str]
subtype_filter: Optional[str]
year_filter: Optional[str]
session_id: Optional[str]
user_id: Optional[str]
file_content: Optional[bytes]
filename: Optional[str]
class ChatFedOutput(TypedDict):
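    """Output schema: generated answer plus per-stage pipeline metadata."""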
result: str
metadata: Dict[str, Any]
class ChatUIInput(BaseModel):
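    """Input schema for the /chatfed-ui-stream ChatUI adapter."""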
text: str
# Module functions
def ingest_node(state: GraphState) -> GraphState:
"""Process file through ingestor if file is provided"""
start_time = datetime.now()
# If no file provided, skip this step
if not state.get("file_content") or not state.get("filename"):
logger.info("No file provided, skipping ingestion")
return {"ingestor_context": "", "metadata": state.get("metadata", {})}
logger.info(f"Ingesting file: {state['filename']}")
try:
client = Client(INGESTOR)
# Create a temporary file to upload
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
tmp_file.write(state["file_content"])
tmp_file_path = tmp_file.name
try:
            # Call the ingestor's /ingest endpoint; gradio_client.file()
            # wraps the local path so the upload is serialized correctly.
            ingestor_context = client.predict(
                file(tmp_file_path),
                api_name="/ingest"
            )
logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
            # The ingestor signals failures as strings prefixed with "Error:".
            if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
                raise RuntimeError(ingestor_context)
finally:
# Clean up temporary file
os.unlink(tmp_file_path)
duration = (datetime.now() - start_time).total_seconds()
metadata = state.get("metadata", {})
metadata.update({
"ingestion_duration": duration,
"ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
"ingestion_success": True
})
return {
"ingestor_context": ingestor_context,
"metadata": metadata
}
except Exception as e:
duration = (datetime.now() - start_time).total_seconds()
logger.error(f"Ingestion failed: {str(e)}")
metadata = state.get("metadata", {})
metadata.update({
"ingestion_duration": duration,
"ingestion_success": False,
"ingestion_error": str(e)
})
return {"ingestor_context": "", "metadata": metadata}
def retrieve_node(state: GraphState) -> GraphState:
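    """Query the retriever service, passing through any metadata filters."""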
start_time = datetime.now()
logger.info(f"Retrieval: {state['query'][:50]}...")
try:
client = Client(RETRIEVER)
context = client.predict(
query=state["query"],
reports_filter=state.get("reports_filter", ""),
sources_filter=state.get("sources_filter", ""),
subtype_filter=state.get("subtype_filter", ""),
year_filter=state.get("year_filter", ""),
api_name="/retrieve"
)
duration = (datetime.now() - start_time).total_seconds()
metadata = state.get("metadata", {})
metadata.update({
"retrieval_duration": duration,
"context_length": len(context) if context else 0,
"retrieval_success": True
})
return {"context": context, "metadata": metadata}
except Exception as e:
duration = (datetime.now() - start_time).total_seconds()
logger.error(f"Retrieval failed: {str(e)}")
metadata = state.get("metadata", {})
metadata.update({
"retrieval_duration": duration,
"retrieval_success": False,
"retrieval_error": str(e)
})
return {"context": "", "metadata": metadata}
def generate_node(state: GraphState) -> GraphState:
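    """Merge ingested and retrieved context, then call the generator service."""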
start_time = datetime.now()
logger.info(f"Generation: {state['query'][:50]}...")
try:
# Combine retriever context with ingestor context
retrieved_context = state.get("context", "")
ingestor_context = state.get("ingestor_context", "")
        # Limit context size to prevent token overflow. Slicing is a no-op
        # when the string already fits within the budget.
        combined_context = ""
        if ingestor_context and retrieved_context:
            # Split the budget evenly, listing the uploaded document first.
            ingestor_truncated = ingestor_context[:MAX_CONTEXT_CHARS // 2]
            retrieved_truncated = retrieved_context[:MAX_CONTEXT_CHARS // 2]
            combined_context = (
                f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_truncated}\n\n"
                f"=== RETRIEVED CONTEXT ===\n{retrieved_truncated}"
            )
        elif ingestor_context:
            combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:MAX_CONTEXT_CHARS]}"
        elif retrieved_context:
            combined_context = retrieved_context[:MAX_CONTEXT_CHARS]
client = Client(GENERATOR)
result = client.predict(
query=state["query"],
context=combined_context,
api_name="/generate"
)
duration = (datetime.now() - start_time).total_seconds()
metadata = state.get("metadata", {})
metadata.update({
"generation_duration": duration,
"result_length": len(result) if result else 0,
"combined_context_length": len(combined_context),
"generation_success": True
})
return {"result": result, "metadata": metadata}
except Exception as e:
duration = (datetime.now() - start_time).total_seconds()
logger.error(f"Generation failed: {str(e)}")
metadata = state.get("metadata", {})
metadata.update({
"generation_duration": duration,
"generation_success": False,
"generation_error": str(e)
})
return {"result": f"Error: {str(e)}", "metadata": metadata}
# Build the pipeline graph, running ingestion ahead of retrieval
workflow = StateGraph(GraphState)
workflow.add_node("ingest", ingest_node)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)
workflow.add_edge(START, "ingest")
workflow.add_edge("ingest", "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
compiled_graph = workflow.compile()
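# The compiled graph is a fixed linear pipeline:
#   START -> ingest -> retrieve -> generate -> END
# It can be invoked directly with a fully-populated GraphState dict;
# process_query_core below wraps compiled_graph.invoke() with session IDs
# and timing metadata.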
def process_query_core(
query: str,
reports_filter: str = "",
sources_filter: str = "",
subtype_filter: str = "",
year_filter: str = "",
session_id: Optional[str] = None,
user_id: Optional[str] = None,
file_content: Optional[bytes] = None,
filename: Optional[str] = None,
return_metadata: bool = False
):
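    """Run the full ingest -> retrieve -> generate pipeline for one query."""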
start_time = datetime.now()
if not session_id:
session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"
try:
initial_state = {
"query": query,
"context": "",
"ingestor_context": "",
"result": "",
"reports_filter": reports_filter or "",
"sources_filter": sources_filter or "",
"subtype_filter": subtype_filter or "",
"year_filter": year_filter or "",
"file_content": file_content,
"filename": filename,
"metadata": {
"session_id": session_id,
"user_id": user_id,
"start_time": start_time.isoformat(),
"has_file_attachment": file_content is not None
}
}
final_state = compiled_graph.invoke(initial_state)
total_duration = (datetime.now() - start_time).total_seconds()
final_metadata = final_state.get("metadata", {})
final_metadata.update({
"total_duration": total_duration,
"end_time": datetime.now().isoformat(),
"pipeline_success": True
})
if return_metadata:
return {"result": final_state["result"], "metadata": final_metadata}
else:
return final_state["result"]
except Exception as e:
total_duration = (datetime.now() - start_time).total_seconds()
logger.error(f"Pipeline failed: {str(e)}")
if return_metadata:
error_metadata = {
"session_id": session_id,
"total_duration": total_duration,
"pipeline_success": False,
"error": str(e)
}
return {"result": f"Error: {str(e)}", "metadata": error_metadata}
else:
return f"Error: {str(e)}"
def process_query_gradio(query: str, file_upload, reports_filter: str = "", sources_filter: str = "",
subtype_filter: str = "", year_filter: str = "") -> str:
"""Gradio interface function with file upload support"""
file_content = None
filename = None
if file_upload is not None:
try:
with open(file_upload.name, 'rb') as f:
file_content = f.read()
filename = os.path.basename(file_upload.name)
logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
except Exception as e:
logger.error(f"Error reading uploaded file: {str(e)}")
return f"Error reading file: {str(e)}"
return process_query_core(
query=query,
reports_filter=reports_filter,
sources_filter=sources_filter,
subtype_filter=subtype_filter,
year_filter=year_filter,
file_content=file_content,
filename=filename,
session_id=f"gradio_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
return_metadata=False
)
def chatui_adapter(data) -> str:
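    """Adapter for ChatUI payloads that carry the query in a 'text' field."""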
try:
# Handle both dict and Pydantic model input
if hasattr(data, 'text'):
text = data.text
elif isinstance(data, dict) and 'text' in data:
text = data['text']
else:
logger.error(f"Unexpected input structure: {data}")
return "Error: Invalid input format. Expected 'text' field."
result = process_query_core(
query=text,
session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
return_metadata=False
)
return result
except Exception as e:
logger.error(f"ChatUI error: {str(e)}")
return f"Error: {str(e)}"
def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
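    """LangServe entrypoint: unpack ChatFedInput and run the core pipeline."""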
result = process_query_core(
query=input_data["query"],
reports_filter=input_data.get("reports_filter", ""),
sources_filter=input_data.get("sources_filter", ""),
subtype_filter=input_data.get("subtype_filter", ""),
year_filter=input_data.get("year_filter", ""),
session_id=input_data.get("session_id"),
user_id=input_data.get("user_id"),
file_content=input_data.get("file_content"),
filename=input_data.get("filename"),
return_metadata=True
)
return ChatFedOutput(result=result["result"], metadata=result["metadata"])
def create_gradio_interface():
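    """Build the Gradio UI; it also serves the MCP endpoints when launched with mcp_server=True."""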
with gr.Blocks(title="ChatFed Orchestrator") as demo:
gr.Markdown("# ChatFed Orchestrator")
gr.Markdown("Upload documents (PDF/DOCX) alongside your queries for enhanced context. MCP endpoints available at `/gradio_api/mcp/sse`")
with gr.Row():
with gr.Column():
query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
file_input = gr.File(label="Upload Document (PDF/DOCX)", file_types=[".pdf", ".docx"])
with gr.Accordion("Filters (Optional)", open=False):
reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
submit_btn = gr.Button("Submit", variant="primary")
with gr.Column():
output = gr.Textbox(label="Response", lines=15, show_copy_button=True)
submit_btn.click(
fn=process_query_gradio,
inputs=[query_input, file_input, reports_filter_input, sources_filter_input,
subtype_filter_input, year_filter_input],
outputs=output
)
return demo
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("ChatFed Orchestrator starting up...")
yield
logger.info("Orchestrator shutting down...")
app = FastAPI(
title="ChatFed Orchestrator",
version="1.0.0",
lifespan=lifespan,
docs_url=None,
redoc_url=None
)
@app.get("/health")
async def health_check():
return {"status": "healthy"}
@app.get("/")
async def root():
return {
"message": "ChatFed Orchestrator API",
"endpoints": {
"health": "/health",
"chatfed": "/chatfed",
"chatfed-ui-stream": "/chatfed-ui-stream",
"chatfed-with-file": "/chatfed-with-file"
}
}
# Additional endpoint for file uploads via API
@app.post("/chatfed-with-file")
async def chatfed_with_file(
query: str = Form(...),
file: Optional[UploadFile] = File(None),
reports_filter: Optional[str] = Form(""),
sources_filter: Optional[str] = Form(""),
subtype_filter: Optional[str] = Form(""),
year_filter: Optional[str] = Form(""),
session_id: Optional[str] = Form(None),
user_id: Optional[str] = Form(None)
):
"""Endpoint for queries with optional file attachments"""
file_content = None
filename = None
if file:
file_content = await file.read()
filename = file.filename
result = process_query_core(
query=query,
reports_filter=reports_filter,
sources_filter=sources_filter,
subtype_filter=subtype_filter,
year_filter=year_filter,
file_content=file_content,
filename=filename,
session_id=session_id,
user_id=user_id,
return_metadata=True
)
return ChatFedOutput(result=result["result"], metadata=result["metadata"])
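# Example request (illustrative; assumes the default port from __main__ below):
#   curl -X POST http://localhost:7860/chatfed-with-file \
#     -F "query=Summarize the attached report" \
#     -F "file=@report.pdf" \
#     -F "year_filter=2024"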
# LangServe routes (these are the main endpoints)
add_routes(
app,
RunnableLambda(process_query_langserve),
path="/chatfed",
input_type=ChatFedInput,
output_type=ChatFedOutput
)
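# add_routes exposes the standard LangServe routes under /chatfed
# (/chatfed/invoke, /chatfed/batch, /chatfed/stream). Example /chatfed/invoke
# payload (illustrative):
#   {"input": {"query": "What are the key findings?", "year_filter": "2024"}}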
add_routes(
app,
RunnableLambda(chatui_adapter),
path="/chatfed-ui-stream",
input_type=ChatUIInput,
output_type=str,
enable_feedback_endpoint=True,
enable_public_trace_link_endpoint=True,
)
def run_gradio_server():
demo = create_gradio_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7861,
mcp_server=True,
show_error=True,
share=False,
quiet=True
)
if __name__ == "__main__":
gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
gradio_thread.start()
logger.info("Gradio MCP server started on port 7861")
host = os.getenv("HOST", "0.0.0.0")
port = int(os.getenv("PORT", "7860"))
logger.info(f"Starting FastAPI server on {host}:{port}")
uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)