"""DataViz Agent: a Gradio app that turns a natural-language question about a
table into a visualization, the SQL used, and the resulting data.

At import time the module connects to MotherDuck (DuckDB), selects the first
reachable Hugging Face model from ``models``, pulls the prompt template from
the LangChain hub, and builds the Gradio UI in ``demo``.
"""
import os

import duckdb
import gradio as gr
import matplotlib.pyplot as plt
from transformers import HfEngine, ReactCodeAgent
from transformers.agents import Tool
from langsmith import traceable
from langchain import hub

# Height of the Tabs Text Area
TAB_LINES = 8

#----------CONNECT TO DATABASE----------
md_token = os.getenv('MD_TOKEN')
# read_only=True: the agent executes model-generated SQL on this connection.
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
#---------------------------------------

#-------LOAD HUGGINGFACE MODEL-------
models = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Meta-Llama-3-70B-Instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
]

model_loaded = False
# Python deletes the `except ... as e` name when the handler exits, so the
# last failure is kept in its own variable for the warning after the loop.
last_error = None
for model in models:
    try:
        llm_engine = HfEngine(model=model)
        # Probe the endpoint; an exception here means the model is unavailable.
        llm_engine.client.get_endpoint_info()
        model_loaded = True
        break
    except Exception as e:
        print(f"Error for model {model}: {e}")
        last_error = e

if not model_loaded:
    gr.Warning(f"❌ None of the models from {models} are available. {last_error}")
#---------------------------------------

#-----LOAD PROMPT FROM LANGCHAIN HUB-----
prompt = hub.pull("viz-prompt")
#-------------------------------------


#--------------ALL UTILS----------------
def get_schemas():
    """Return the names of all schemas except the system catalogs."""
    schemas = conn.execute("""
        SELECT DISTINCT schema_name
        FROM information_schema.schemata
        WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """).fetchall()
    return [item[0] for item in schemas]


# Get Tables
def get_tables(schema_name):
    """Return the names of the tables in ``schema_name``.

    The schema name is bound as a query parameter instead of being
    interpolated, so a quote in the name cannot break or inject into the SQL.
    """
    tables = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [table[0] for table in tables]


# Update Tables
def update_tables(schema_name):
    """Gradio callback: refresh the table dropdown for the chosen schema."""
    tables = get_tables(schema_name)
    return gr.update(choices=tables)


# Get Schema
def get_table_schema(table):
    """Return ``(ddl_create, full_path)`` for the table named ``table``.

    ``full_path`` is ``database.schema.table``; ``ddl_create`` is the table's
    CREATE statement with its table reference rewritten to ``full_path``.
    If several schemas contain a table with this name, the first row returned
    by ``duckdb_tables()`` is used.

    Raises:
        ValueError: if no table with that name exists (e.g. ``table`` is None
            because nothing was selected).
    """
    row = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table {table!r} not found.")

    ddl_create, parent_database, schema_name = row
    full_path = f"{parent_database}.{schema_name}.{table}"
    # Tables in "main" appear unqualified in the DDL; others as schema.table.
    old_path = f"{schema_name}.{table}" if schema_name != "main" else table
    # Rewrite only the first occurrence (the name after CREATE TABLE): a plain
    # replace() would also rewrite column names that contain the table name.
    ddl_create = ddl_create.replace(old_path, full_path, 1)
    return ddl_create, full_path


class SQLExecutorTool(Tool):
    """Agent tool that runs a DuckDB SQL query on ``conn``."""

    name = "sql_engine"
    inputs = {
        "query": {
            "type": "text",
            "description": "The query to perform. This should be correct DuckDB SQL.",
        }
    }
    description = """Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result."""
    output_type = "pandas.core.frame.DataFrame"

    def forward(self, query: str) -> str:
        # NOTE(review): returns a pandas DataFrame (see output_type) despite
        # the `-> str` annotation; left unchanged in case the agent framework
        # inspects this signature.
        output_df = conn.sql(query).df()
        return output_df


tool = SQLExecutorTool()


def process_outputs(output):
    """Select the fields of the agent result recorded in the LangSmith trace."""
    return {
        'sql': output.get('sql', None),
        'code': output.get('code', None),
    }


@traceable(process_outputs=process_outputs)
def get_visualization(question, schema, table_name):
    """Run a ReAct code agent that answers ``question`` about ``table_name``.

    ``schema`` is the table's CREATE statement and is interpolated into the
    prompt together with the question and table name. Returns the agent's
    final answer; ``main`` reads its ``fig``, ``sql`` and ``data`` keys, so
    the hub prompt presumably asks for a dict (TODO confirm).
    """
    agent = ReactCodeAgent(
        tools=[tool],
        llm_engine=llm_engine,
        add_base_tools=True,
        additional_authorized_imports=['matplotlib.pyplot', 'pandas', 'plotly.express', 'seaborn'],
        max_iterations=10,
    )
    results = agent.run(
        task=prompt.format(question=question, schema=schema, table_name=table_name)
    )
    return results
#---------------------------------------


def main(table, text_query):
    """Gradio handler: build a visualization for ``text_query`` over ``table``.

    Returns ``(fig, generated_sql, data)``. On any failure (missing table,
    agent error, malformed agent output) a Gradio warning is shown and an
    empty figure with hidden axes is returned, with ``None`` for SQL and data.
    """
    # Pre-set so the return never hits an unbound local when a step fails.
    generated_sql = None
    data = None

    try:
        schema, table_name = get_table_schema(table)
        output = get_visualization(question=text_query, schema=schema, table_name=table_name)
        fig = output.get('fig', None)
        generated_sql = output.get('sql', None)
        data = output.get('data', None)
    except Exception as e:
        gr.Warning(f"❌ Unable to generate the visualization. {e}")
        # Empty Fig, created only when needed: pyplot keeps every figure
        # registered until closed, so an unused placeholder would accumulate.
        fig, ax = plt.subplots()
        ax.set_axis_off()

    return fig, generated_sql, data


custom_css = """
.gradio-container {
    background-color: #f0f4f8;
}
.logo {
    max-width: 200px;
    margin: 20px auto;
    display: block;
}
.gr-button {
    background-color: #4a90e2 !important;
}
.gr-button:hover {
    background-color: #3a7bc8 !important;
}
"""

with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)

    gr.Markdown("""
DataViz Agent
Visualize SQL queries based on a given text for the dataset.
""")

    with gr.Row():
        with gr.Column(scale=1):
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)

        with gr.Column(scale=2):
            query_input = gr.Textbox(lines=3, label="Text Query", placeholder="Enter your text query here...")
            with gr.Row():
                # Empty wide column pushes the button to the right.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")

            with gr.Tabs():
                with gr.Tab("Plot"):
                    result_plot = gr.Plot()
                with gr.Tab("SQL"):
                    generated_sql = gr.Textbox(lines=TAB_LINES, label="Generated SQL", value="", interactive=False, autoscroll=False)
                with gr.Tab("Data"):
                    data = gr.Dataframe(label="Data", interactive=False)

    schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[result_plot, generated_sql, data])

if __name__ == "__main__":
    demo.launch(debug=True)