import os
import logging
import configparser
from typing import TypedDict

import gradio as gr
from gradio_client import Client
from langgraph.graph import StateGraph, START, END
from dotenv import load_dotenv

# OPEN QUESTION: SHOULD WE PASS ALL PARAMS FROM THE ORCHESTRATOR TO THE NODES INSTEAD OF SETTING IN EACH MODULE?

# Load the local .env file before reading any environment variables from it
load_dotenv()

HF_TOKEN = os.environ.get("HF_TOKEN")
def getconfig(configfile_path: str):
    """
    Read the config file

    Params
    ----------------
    configfile_path: file path of .cfg file
    """
    config = configparser.ConfigParser()
    try:
        with open(configfile_path) as f:
            config.read_file(f)
        return config
    except FileNotFoundError:
        logging.warning("Config file not found: %s", configfile_path)
        return None
def get_auth(provider: str) -> dict:
    """Get authentication configuration for different providers"""
    auth_configs = {
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "qdrant": {"api_key": os.getenv("QDRANT_API_KEY")},
    }
    provider = provider.lower()  # normalize to lowercase
    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")
    auth_config = auth_configs[provider]
    api_key = auth_config.get("api_key")
    if not api_key:
        logging.warning(
            f"No API key found for provider '{provider}'. "
            "Please set the appropriate environment variable."
        )
        auth_config["api_key"] = None
    return auth_config
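
# Illustrative sketch of how get_auth could back the Client calls below.
# This file currently passes HF_TOKEN directly, so this wiring is an
# assumption, not existing behavior:
#
#   hf_auth = get_auth("huggingface")
#   client = Client("giz/chatfed_retriever", hf_token=hf_auth["api_key"])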
# Define the state schema
class GraphState(TypedDict):
    query: str
    context: str
    result: str
    # Orchestrator-level parameters (see the open question at the top of this file)
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
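
# Note: each node below returns only the keys it updates; LangGraph merges
# that partial dict back into the shared GraphState.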
# Node 1: retriever
def retrieve_node(state: GraphState) -> GraphState:
    # The retriever runs as a separate HF Space, called via its Gradio API
    client = Client("giz/chatfed_retriever", hf_token=HF_TOKEN)  # HF repo name
    context = client.predict(
        query=state["query"],
        reports_filter=state.get("reports_filter", ""),
        sources_filter=state.get("sources_filter", ""),
        subtype_filter=state.get("subtype_filter", ""),
        year_filter=state.get("year_filter", ""),
        api_name="/retrieve"
    )
    return {"context": context}
# Node 2: generator
def generate_node(state: GraphState) -> GraphState:
    client = Client("giz/chatfed_generator", hf_token=HF_TOKEN)
    result = client.predict(
        query=state["query"],
        context=state["context"],
        api_name="/generate"
    )
    return {"result": result}
# build the graph
workflow = StateGraph(GraphState)
# Add nodes
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)
# Add edges
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
# Compile the graph
graph = workflow.compile()
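
# Optional sketch for rendering the compiled graph (relevant to the "show the
# graph" idea in the UI section below). draw_mermaid() exists in recent
# langgraph releases, but treat its availability as an assumption here:
#
#   mermaid_src = graph.get_graph().draw_mermaid()
#   print(mermaid_src)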
# Single tool for processing queries
def process_query(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = ""
) -> str:
    """
    Execute the ChatFed orchestration pipeline to process a user query.

    This function orchestrates a two-step workflow:
    1. Retrieve relevant context using the ChatFed retriever service with optional filters.
    2. Generate a response using the ChatFed generator service with the retrieved context.

    Args:
        query (str): The user's input query/question to be processed.
        reports_filter (str, optional): Filter for specific report types. Defaults to "".
        sources_filter (str, optional): Filter for specific data sources. Defaults to "".
        subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
        year_filter (str, optional): Filter for specific years. Defaults to "".

    Returns:
        str: The generated response from the ChatFed generator service.
    """
    initial_state = {
        "query": query,
        "context": "",
        "result": "",
        "reports_filter": reports_filter or "",
        "sources_filter": sources_filter or "",
        "subtype_filter": subtype_filter or "",
        "year_filter": year_filter or ""
    }
    final_state = graph.invoke(initial_state)
    return final_state["result"]
# Simple testing interface (guidance for ChatUI; can be removed later).
# A front end may not even be necessary, though it could be nice to show the graph.
with gr.Blocks(title="ChatFed Orchestrator") as demo:
    with gr.Row():
        # Left column - query input
        with gr.Column():
            query_input = gr.Textbox(
                label="query",
                lines=2,
                placeholder="Enter your search query here",
                info="The query to search for in the vector database"
            )
            submit_btn = gr.Button("Submit", variant="primary")

        # Right column - generated answer
        with gr.Column():
            output = gr.Textbox(
                label="answer",
                lines=10,
                show_copy_button=True
            )

    # UI event handler
    submit_btn.click(
        fn=process_query,
        inputs=query_input,
        outputs=output,
        api_name="process_query"
    )
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
mcp_server=True,
show_error=True
)