# ChatFed Orchestrator — Hugging Face Space (Gradio app)
# --- Standard library ---
import ast
import configparser
import io
import logging
import os
import re
from typing import TypedDict, Optional

# --- Third-party ---
import gradio as gr
from gradio_client import Client
from langgraph.graph import StateGraph, START, END
from PIL import Image
from dotenv import load_dotenv

# OPEN QUESTION: SHOULD WE PASS ALL PARAMS FROM THE ORCHESTRATOR TO THE NODES INSTEAD OF SETTING IN EACH MODULE?

# Load the local .env file BEFORE reading HF_TOKEN. The original code read
# the token first and called load_dotenv() afterwards, so a token defined
# only in .env was silently missed.
load_dotenv()

# Token used to authenticate against the private ChatFed Spaces below;
# may be None, in which case Client() falls back to anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN")
def getconfig(configfile_path: str) -> Optional[configparser.ConfigParser]:
    """
    Read an INI-style config file.

    Params
    ----------------
    configfile_path: file path of .cfg file

    Returns
    ----------------
    The parsed ConfigParser, or None when the file cannot be read
    (a warning is logged in that case).
    """
    config = configparser.ConfigParser()
    try:
        # Context manager guarantees the file handle is closed; the
        # original left the open() handle dangling.
        with open(configfile_path) as f:
            config.read_file(f)
        return config
    except OSError:
        # Narrowed from a bare `except:` — only file-access errors mean
        # "not found"; parser errors now surface to the caller.
        logging.warning("config file not found")
        return None
def get_auth(provider: str) -> dict:
    """Return the authentication configuration for a supported provider.

    Looks up the provider's API key in the environment. Provider names are
    case-insensitive.

    Raises:
        ValueError: if the provider has no auth configuration here.
    """
    # Map each supported provider to the env var holding its API key.
    env_var_by_provider = {
        "huggingface": "HF_TOKEN",
        "qdrant": "QDRANT_API_KEY",
    }
    normalized = provider.lower()  # Normalize to lowercase
    if normalized not in env_var_by_provider:
        raise ValueError(f"Unsupported provider: {normalized}")
    api_key = os.getenv(env_var_by_provider[normalized])
    if not api_key:
        # Missing/empty key is tolerated but loudly noted, and normalized to None.
        logging.warning(f"No API key found for provider '{normalized}'. Please set the appropriate environment variable.")
        api_key = None
    return {"api_key": api_key}
# Define the state schema
class GraphState(TypedDict):
    """Shared state threaded through the LangGraph pipeline nodes."""
    query: str    # the user's question (set by the caller)
    context: str  # retrieved context, written by retrieve_node
    result: str   # generated answer, written by generate_node
    # Add orchestrator-level parameters (addressing your open question)
    # Filters are plain strings; "" means "no filter".
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
# node 2: retriever
def retrieve_node(state: GraphState) -> GraphState:
    """Fetch context for the query from the remote ChatFed retriever Space.

    Returns a partial state update containing only the "context" key;
    LangGraph merges it into the shared state.
    """
    retriever = Client("giz/chatfed_retriever", hf_token=HF_TOKEN)  # HF repo name
    fetched_context = retriever.predict(
        query=state["query"],
        reports_filter=state.get("reports_filter", ""),
        sources_filter=state.get("sources_filter", ""),
        subtype_filter=state.get("subtype_filter", ""),
        year_filter=state.get("year_filter", ""),
        api_name="/retrieve",
    )
    return {"context": fetched_context}
# node 3: generator
def generate_node(state: GraphState) -> GraphState:
    """Produce an answer via the remote ChatFed generator Space.

    Uses the query plus the context written earlier by retrieve_node and
    returns a partial state update with the "result" key.
    """
    generator = Client("giz/chatfed_generator", hf_token=HF_TOKEN)
    answer = generator.predict(
        query=state["query"],
        context=state["context"],
        api_name="/generate",
    )
    return {"result": answer}
# build the graph
workflow = StateGraph(GraphState)
# Add nodes
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)
# Add edges: a strictly linear pipeline START -> retrieve -> generate -> END
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
# Compile the graph into the runnable invoked by process_query
graph = workflow.compile()
# Single tool for processing queries
def process_query(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = ""
) -> str:
    """
    Run the ChatFed orchestration pipeline for a single user query.

    The compiled LangGraph executes two stages:
      1. "retrieve" — fetch context from the ChatFed retriever service,
         applying any of the optional metadata filters below.
      2. "generate" — produce a response with the ChatFed generator
         service using the retrieved context.

    Args:
        query (str): The user's input query/question to be processed
        reports_filter (str, optional): Filter for specific report types. Defaults to "".
        sources_filter (str, optional): Filter for specific data sources. Defaults to "".
        subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
        year_filter (str, optional): Filter for specific years. Defaults to "".

    Returns:
        str: The generated response from the ChatFed generator service
    """
    # None and "" are both treated as "no filter", so UI callers may pass either.
    filters = {
        "reports_filter": reports_filter,
        "sources_filter": sources_filter,
        "subtype_filter": subtype_filter,
        "year_filter": year_filter,
    }
    initial_state = {"query": query, "context": "", "result": ""}
    for name, value in filters.items():
        initial_state[name] = value if value else ""
    return graph.invoke(initial_state)["result"]
# Simple testing interface
# Guidance for ChatUI - can be removed later. Questionable whether front end even necessary. Maybe nice to show the graph.
with gr.Blocks(title="ChatFed Orchestrator") as demo:
    with gr.Row():
        # Left column - query input only; the metadata filters are not
        # exposed in the UI, so process_query runs with its "" defaults.
        with gr.Column():
            query_input = gr.Textbox(
                label="query",
                lines=2,
                placeholder="Enter your search query here",
                info="The query to search for in the vector database"
            )
            submit_btn = gr.Button("Submit", variant="primary")
        # Right column - generated answer
        with gr.Column():
            output = gr.Textbox(
                label="answer",
                lines=10,
                show_copy_button=True
            )
    # UI event handler; api_name also exposes this as the /process_query endpoint
    submit_btn.click(
        fn=process_query,
        inputs=query_input,
        outputs=output,
        api_name="process_query"
    )
if __name__ == "__main__":
    # mcp_server=True additionally exposes the app's API as an MCP tool server.
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (needed inside a container/Space)
        server_port=7860,
        mcp_server=True,
        show_error=True
    )