timeki's picture
talk_to_ipcc (#29)
711bc31 verified
from random import choices
import gradio as gr
from typing import TypedDict
from climateqa.engine.talk_to_data.main import ask_ipcc
from climateqa.engine.talk_to_data.ipcc.config import IPCC_MODELS, IPCC_SCENARIO, IPCC_UI_TEXT
import uuid
class ipccUIElements(TypedDict):
    """Typed bundle of the Gradio components that make up the IPCC tab.

    ``create_ipcc_ui`` builds and returns this mapping; ``setup_ipcc_events``
    looks components up by key to attach their event handlers.
    """

    # --- Tab shell and help section ---
    tab: gr.Tab
    details_accordion: gr.Accordion

    # --- Question entry: examples, example images and the free-text box ---
    examples_hidden: gr.Textbox
    examples: gr.Examples
    image_examples: gr.Row
    ipcc_direct_question: gr.Textbox

    # --- Result status text and selector for the generated figures ---
    result_text: gr.Textbox
    table_names_display: gr.Radio

    # --- SQL query section ---
    query_accordion: gr.Accordion
    ipcc_sql_query: gr.Textbox

    # --- Chart section ---
    chart_accordion: gr.Accordion
    plot_information: gr.Markdown
    scenario_selection: gr.Dropdown
    ipcc_display: gr.Plot

    # --- Data table section ---
    table_accordion: gr.Accordion
    ipcc_table: gr.DataFrame
async def ask_ipcc_query(query: str, index_state: int, user_id: str):
    """Forward a question to ``ask_ipcc`` and hand back its result untouched.

    Args:
        query: Text of the question to answer.
        index_state: Current value of the figure index state.
        user_id: Identifier passed through to ``ask_ipcc``.

    Returns:
        Whatever ``ask_ipcc`` returns, without any post-processing.
    """
    return await ask_ipcc(query, index_state, user_id)
def hide_outputs():
    """Reset the result area before a new query runs.

    Returns:
        A 5-tuple of Gradio updates: the first makes its target visible
        (the result text), the remaining four hide theirs (query accordion,
        table accordion, chart accordion and table names, in that order).
    """
    return (
        gr.update(visible=True),
        *(gr.update(visible=False) for _ in range(4)),
    )
def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
    """Switch between the "no result" text and the result sections.

    Args:
        sql_queries_state: SQL queries produced for the question.
        dataframes_state: Dataframes produced for the question.
        plots_state: Plot builders produced for the question.
        table_names: Labels of the generated figures, used as radio choices.

    Returns:
        A 5-tuple of Gradio updates for, in order: the result text, the query
        accordion, the table accordion, the chart accordion and the table
        names radio.
    """
    # Any missing piece means there is nothing to display. ``table_names`` is
    # part of the check because its first entry is selected below; an empty
    # list would otherwise raise IndexError.
    nothing_to_show = (
        not sql_queries_state
        or not dataframes_state
        or not plots_state
        or not table_names
    )
    if nothing_to_show:
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    # Hide the status text, reveal every result section and pre-select the
    # first figure.
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(choices=table_names, value=table_names[0], visible=True),
    )
def show_filter_by_scenario(table_names, index_state, dataframes):
    """Show the scenario dropdown only for "Map" figures that have scenarios.

    Args:
        table_names: Labels of the generated figures.
        index_state: Index of the currently selected figure.
        dataframes: Dataframes aligned with ``table_names``.

    Returns:
        A Gradio update. For a figure whose label starts with "Map" and whose
        dataframe has at least one value in its ``scenario`` column, the
        update shows the dropdown with the sorted unique scenarios and selects
        the first one. In every other case the update hides the dropdown.
    """
    if not table_names or not table_names[index_state].startswith("Map"):
        return gr.update(visible=False)

    df = dataframes[index_state]
    # ``filter_by_scenario`` tolerates a missing "scenario" column; do the
    # same here instead of raising KeyError.
    if "scenario" not in df.columns:
        return gr.update(visible=False)

    scenarios = sorted(df["scenario"].unique())
    # An empty dataframe yields no choices; ``scenarios[0]`` would raise.
    if not scenarios:
        return gr.update(visible=False)

    return gr.update(visible=True, choices=scenarios, value=scenarios[0])
def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
    """Return the selected dataframe and its figure, filtered by scenario.

    Only figures whose label starts with "Map" are filtered; other figures
    are rendered from their full dataframe.

    Args:
        dataframes: Dataframes, one per generated figure.
        figures: Callables aligned with ``dataframes``; each takes a dataframe
            and returns the figure to display.
        table_names: Labels aligned with ``dataframes``.
        index_state: Index of the currently selected figure.
        scenario: Scenario value to keep in the ``scenario`` column.

    Returns:
        A ``(dataframe, figure)`` pair. ``figure`` is ``None`` when the
        dataframe (before or after filtering) is empty. Both items are
        ``None`` when there is no dataframe or label to select from.
    """
    # With empty state lists (the case ``show_results`` treats as "no
    # result") there is nothing to index into; return empty outputs rather
    # than raising IndexError.
    if not dataframes or not table_names:
        return None, None

    df = dataframes[index_state]

    # Non-map figures have no scenario filter: render them as they are.
    if not table_names[index_state].startswith("Map"):
        return df, figures[index_state](df)

    if df.empty:
        return df, None

    # A map without a scenario column cannot be filtered; render it whole.
    if "scenario" not in df.columns:
        return df, figures[index_state](df)

    df = df[df["scenario"] == scenario]
    if df.empty:
        return df, None

    return df, figures[index_state](df)
def display_table_names(table_names, index_state):
    """Wrap every table name in its own single-item list.

    Args:
        table_names: Iterable of table labels.
        index_state: Unused; kept so the signature stays unchanged.

    Returns:
        A list of one-element lists, in the same order as ``table_names``.
    """
    rows = []
    for label in table_names:
        rows.append([label])
    return rows
def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
    """Collect everything to display for the figure the user selected.

    Args:
        selected_label: Label chosen by the user; must be in ``table_names``.
        table_names: Labels of the generated figures.
        sql_queries: SQL queries aligned with ``table_names``.
        dataframes: Dataframes aligned with ``table_names``.
        plot_informations: Plot descriptions aligned with ``table_names``.
        plots: Callables aligned with ``table_names``; each builds a figure
            from its dataframe.

    Returns:
        ``(sql_query, dataframe, figure, plot_information, index)`` for the
        selected label.

    Raises:
        ValueError: If ``selected_label`` is not in ``table_names``.
    """
    position = table_names.index(selected_label)

    build_figure = plots[position]
    selected_frame = dataframes[position]
    rendered = build_figure(selected_frame)

    return (
        sql_queries[position],
        selected_frame,
        rendered,
        plot_informations[position],
        position,
    )
def create_ipcc_ui() -> ipccUIElements:
    """Build the "(Beta) Talk to IPCC" tab and return its components.

    The tab contains, top to bottom: a help accordion, example questions, a
    free-text question box, example images, a result status text, a radio to
    pick among generated figures, and accordions for the SQL query, the chart
    and the underlying data.

    NOTE(review): the container nesting below was reconstructed from a copy
    of this file whose indentation had been lost; confirm the layout against
    the rendered tab.

    Returns:
        An ``ipccUIElements`` mapping with every component that the event
        wiring needs to reference.
    """
    example_questions = [
        ["What will the temperature be like in Paris?"],
        ["What will be the total rainfall in the USA in 2030?"],
        ["How will the average temperature evolve in China?"],
        ["What will be the average total precipitation in London ?"]
    ]
    # (image path, caption) pairs, rendered in this order.
    example_images = (
        (
            "./front/assets/talk_to_ipcc_france_example.png",
            "Total Precipitation in 2030 in France",
        ),
        (
            "./front/assets/talk_to_ipcc_new_york_example.png",
            "Yearly Evolution of Mean Temperature in New York (Historical + SSP Scenarios)",
        ),
        (
            "./front/assets/talk_to_ipcc_china_example.png",
            "Mean Temperature in 2050 in China",
        ),
    )

    with gr.Tab("(Beta) Talk to IPCC", elem_id="tab-vanna", id=7) as tab:
        # Usage instructions.
        with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
            gr.Markdown(IPCC_UI_TEXT)

        # Invisible textbox that the examples write into.
        examples_hidden = gr.Textbox(visible=False, elem_id="ipcc-examples-hidden")
        examples = gr.Examples(
            examples=example_questions,
            label="Example Questions",
            inputs=[examples_hidden],
            outputs=[examples_hidden],
        )

        # Free-text question entry.
        with gr.Row():
            ipcc_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )

        # Static gallery of sample visualizations.
        with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
            gr.Markdown("### Examples of possible visualizations")
            with gr.Row():
                for image_path, caption in example_images:
                    gr.Image(image_path, label=caption, elem_classes=["example-img"])

        # Status text.
        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )

        # Selector for the generated figures; starts hidden with no choices.
        with gr.Row():
            table_names_display = gr.Radio(
                choices=[],
                label="Relevant figures created",
                interactive=True,
                elem_id="table-names",
                visible=False
            )

        # SQL query section.
        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            ipcc_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )

        # Chart section: scenario dropdown, plot description and the plot.
        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            with gr.Row():
                scenario_selection = gr.Dropdown(
                    label="Scenario",
                    choices=IPCC_SCENARIO,
                    value=IPCC_SCENARIO[0],
                    interactive=True,
                    visible=False,
                )
            with gr.Accordion(label="Informations about the plot", open=False):
                plot_information = gr.Markdown(value="")
            ipcc_display = gr.Plot(elem_id="vanna-plot")

        # Data table section.
        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            ipcc_table = gr.DataFrame([], elem_id="vanna-table")

    return ipccUIElements(
        tab=tab,
        details_accordion=details_accordion,
        examples_hidden=examples_hidden,
        examples=examples,
        image_examples=image_examples,
        ipcc_direct_question=ipcc_direct_question,
        result_text=result_text,
        table_names_display=table_names_display,
        query_accordion=query_accordion,
        ipcc_sql_query=ipcc_sql_query,
        chart_accordion=chart_accordion,
        plot_information=plot_information,
        scenario_selection=scenario_selection,
        ipcc_display=ipcc_display,
        table_accordion=table_accordion,
        ipcc_table=ipcc_table,
    )
def setup_ipcc_events(ui_elements: ipccUIElements, share_client=None, user_id=None) -> None:
    """Attach every event handler of the IPCC tab.

    Three triggers are wired:

    * submitting the direct question copies its text into the hidden examples
      textbox;
    * a change of the hidden examples textbox runs the full query chain
      (collapse help, hide example images, reset outputs, query, reveal
      results, refresh the scenario dropdown, render the filtered view);
    * changing the scenario or the selected figure refreshes the table and
      the plot.

    Args:
        ui_elements: Components returned by ``create_ipcc_ui``.
        share_client: Accepted for interface compatibility; not used here.
        user_id: Initial value of the user id state passed to the query.
    """
    # State holders shared by the handlers below (creation order preserved).
    sql_queries_state = gr.State([])
    dataframes_state = gr.State([])
    plots_state = gr.State([])
    plot_informations_state = gr.State([])
    index_state = gr.State(0)
    table_names_list = gr.State([])
    user_id = gr.State(user_id)

    # Short aliases for the components referenced several times.
    details_accordion = ui_elements["details_accordion"]
    examples_hidden = ui_elements["examples_hidden"]
    image_examples = ui_elements["image_examples"]
    direct_question = ui_elements["ipcc_direct_question"]
    result_text = ui_elements["result_text"]
    table_names_display = ui_elements["table_names_display"]
    query_accordion = ui_elements["query_accordion"]
    sql_query_box = ui_elements["ipcc_sql_query"]
    chart_accordion = ui_elements["chart_accordion"]
    plot_information = ui_elements["plot_information"]
    scenario_selection = ui_elements["scenario_selection"]
    plot_display = ui_elements["ipcc_display"]
    table_accordion = ui_elements["table_accordion"]
    data_table = ui_elements["ipcc_table"]

    # Component groups reused by several handlers. They are kept as tuples and
    # copied into a fresh list at each use so no list is shared between events.
    visibility_targets = (
        result_text,
        query_accordion,
        table_accordion,
        chart_accordion,
        table_names_display,
    )
    scenario_choice_inputs = (table_names_list, index_state, dataframes_state)
    scenario_filter_inputs = (
        dataframes_state,
        plots_state,
        table_names_list,
        index_state,
        scenario_selection,
    )
    filtered_view_targets = (data_table, plot_display)

    # Direct question: push the text into the hidden textbox, whose change
    # event drives the query chain below.
    direct_question.submit(
        lambda x: gr.update(value=x),
        inputs=[direct_question],
        outputs=[examples_hidden],
    )

    # Query chain, triggered by an example click or a submitted question.
    # 1. Collapse the help accordion and mirror the question in the text box.
    step = examples_hidden.change(
        lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
        inputs=[examples_hidden],
        outputs=[details_accordion, direct_question]
    )
    # 2. Hide the example images.
    step = step.then(
        lambda : gr.update(visible=False),
        inputs=None,
        outputs=image_examples
    )
    # 3. Reset the result area.
    step = step.then(
        hide_outputs,
        inputs=None,
        outputs=list(visibility_targets)
    )
    # 4. Run the query and store its results in components and state.
    step = step.then(
        ask_ipcc_query,
        inputs=[examples_hidden, index_state, user_id],
        outputs=[
            sql_query_box,
            data_table,
            plot_display,
            plot_information,
            sql_queries_state,
            dataframes_state,
            plots_state,
            plot_informations_state,
            index_state,
            table_names_list,
            result_text,
        ],
    )
    # 5. Reveal either the result sections or the status text.
    step = step.then(
        show_results,
        inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
        outputs=list(visibility_targets),
    )
    # 6. Refresh the scenario dropdown for the selected figure.
    step = step.then(
        show_filter_by_scenario,
        inputs=list(scenario_choice_inputs),
        outputs=[scenario_selection],
    )
    # 7. Render the table and plot for the selected scenario.
    step.then(
        filter_by_scenario,
        inputs=list(scenario_filter_inputs),
        outputs=list(filtered_view_targets),
    )

    # Scenario selection change: re-filter the current figure.
    scenario_selection.change(
        filter_by_scenario,
        inputs=list(scenario_filter_inputs),
        outputs=list(filtered_view_targets),
    )

    # Figure selection change: load that figure's data, then refresh the
    # scenario dropdown and the filtered view.
    selection_step = table_names_display.change(
        fn=on_table_click,
        inputs=[
            table_names_display,
            table_names_list,
            sql_queries_state,
            dataframes_state,
            plot_informations_state,
            plots_state,
        ],
        outputs=[sql_query_box, data_table, plot_display, plot_information, index_state],
    )
    selection_step = selection_step.then(
        show_filter_by_scenario,
        inputs=list(scenario_choice_inputs),
        outputs=[scenario_selection],
    )
    selection_step.then(
        filter_by_scenario,
        inputs=list(scenario_filter_inputs),
        outputs=list(filtered_view_targets),
    )
def create_ipcc_tab(share_client=None, user_id=None):
    """Build the IPCC tab: create its components, then wire their events.

    Args:
        share_client: Forwarded unchanged to ``setup_ipcc_events``.
        user_id: Forwarded unchanged to ``setup_ipcc_events``.
    """
    setup_ipcc_events(
        create_ipcc_ui(),
        share_client=share_client,
        user_id=user_id,
    )