import os
import logging
import configparser

import gradio as gr
from gradio_client import Client
from langgraph.graph import StateGraph, START, END
from typing import TypedDict
from dotenv import load_dotenv

# OPEN QUESTION: should all parameters be passed from the orchestrator to the
# nodes instead of being set in each module?

# Load the local .env file, then read the Hugging Face token
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")


def getconfig(configfile_path: str):
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of the .cfg file
    """
    config = configparser.ConfigParser()
    try:
        with open(configfile_path) as f:
            config.read_file(f)
        return config
    except FileNotFoundError:
        logging.warning("config file not found: %s", configfile_path)


def get_auth(provider: str) -> dict:
    """Get the authentication configuration for a given provider."""
    auth_configs = {
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "qdrant": {"api_key": os.getenv("QDRANT_API_KEY")},
    }
    provider = provider.lower()  # normalize to lowercase
    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")
    auth_config = auth_configs[provider]
    if not auth_config.get("api_key"):
        logging.warning(
            f"No API key found for provider '{provider}'. "
            "Please set the appropriate environment variable."
        )
        auth_config["api_key"] = None
    return auth_config


# Define the state schema
class GraphState(TypedDict):
    query: str
    context: str
    result: str
    # Orchestrator-level filter parameters (see OPEN QUESTION above)
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str


# node 1: retriever
def retrieve_node(state: GraphState) -> dict:
    client = Client("giz/chatfed_retriever", hf_token=HF_TOKEN)  # HF repo name
    context = client.predict(
        query=state["query"],
        reports_filter=state.get("reports_filter", ""),
        sources_filter=state.get("sources_filter", ""),
        subtype_filter=state.get("subtype_filter", ""),
        year_filter=state.get("year_filter", ""),
        api_name="/retrieve",
    )
    return {"context": context}  # partial state update


# node 2: generator
def generate_node(state: GraphState) -> dict:
    client = Client("giz/chatfed_generator", hf_token=HF_TOKEN)
    result = client.predict(
        query=state["query"],
        context=state["context"],
        api_name="/generate",
    )
    return {"result": result}  # partial state update


# Build the graph
workflow = StateGraph(GraphState)

# Add nodes
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)

# Add edges: START -> retrieve -> generate -> END
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)

# Compile the graph
graph = workflow.compile()
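
# Optional sketch: LangGraph can render the compiled pipeline as a Mermaid
# diagram, which could back the "show the graph" idea in the UI below. This
# assumes the installed langgraph version exposes get_graph().draw_mermaid();
# the helper name graph_mermaid is ours, not part of any library.
def graph_mermaid() -> str:
    """Return the compiled graph (START -> retrieve -> generate -> END) as Mermaid markup."""
    return graph.get_graph().draw_mermaid()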

# Single tool for processing queries
def process_query(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
) -> str:
    """
    Execute the ChatFed orchestration pipeline to process a user query.

    This function orchestrates a two-step workflow:
    1. Retrieve relevant context using the ChatFed retriever service with optional filters.
    2. Generate a response using the ChatFed generator service with the retrieved context.

    Args:
        query (str): The user's input query/question to be processed.
        reports_filter (str, optional): Filter for specific report types. Defaults to "".
        sources_filter (str, optional): Filter for specific data sources. Defaults to "".
        subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
        year_filter (str, optional): Filter for specific years. Defaults to "".

    Returns:
        str: The generated response from the ChatFed generator service.
    """
    initial_state = {
        "query": query,
        "context": "",
        "result": "",
        "reports_filter": reports_filter or "",
        "sources_filter": sources_filter or "",
        "subtype_filter": subtype_filter or "",
        "year_filter": year_filter or "",
    }
    final_state = graph.invoke(initial_state)
    return final_state["result"]


# Simple testing interface. Guidance for ChatUI - can be removed later; it is
# questionable whether a front end is even necessary, though it might be nice
# to show the graph.
with gr.Blocks(title="ChatFed Orchestrator") as demo:
    with gr.Row():
        # Left column - query input
        with gr.Column():
            query_input = gr.Textbox(
                label="query",
                lines=2,
                placeholder="Enter your search query here",
                info="The query to search for in the vector database",
            )
            submit_btn = gr.Button("Submit", variant="primary")

        # Right column - generated answer
        with gr.Column():
            output = gr.Textbox(
                label="answer",
                lines=10,
                show_copy_button=True,
            )

    # UI event handler. Only the query is wired up here; the filter parameters
    # keep their "" defaults (they remain available through the API endpoint).
    submit_btn.click(
        fn=process_query,
        inputs=query_input,
        outputs=output,
        api_name="process_query",
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
        show_error=True,
    )
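
# Client-side usage sketch: once deployed, this orchestrator can be called the
# same way it calls the retriever/generator services. The space name
# "giz/chatfed_orchestrator" below is a hypothetical placeholder; substitute
# the actual Space or URL where this app is hosted.
#
#     from gradio_client import Client
#     client = Client("giz/chatfed_orchestrator", hf_token=HF_TOKEN)
#     answer = client.predict(
#         query="What is climate finance?",  # example query, for illustration only
#         api_name="/process_query",
#     )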