File size: 11,338 Bytes
711bc31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 |
from random import choices
import gradio as gr
from typing import TypedDict
from climateqa.engine.talk_to_data.main import ask_ipcc
from climateqa.engine.talk_to_data.ipcc.config import IPCC_MODELS, IPCC_SCENARIO, IPCC_UI_TEXT
import uuid
class ipccUIElements(TypedDict):
    """Typed mapping of the Gradio components that make up the IPCC tab.

    ``create_ipcc_ui`` builds and returns this mapping, and
    ``setup_ipcc_events`` looks components up by key to attach handlers.
    """

    # Tab container labelled "(Beta) Talk to IPCC".
    tab: gr.Tab
    # "How to use?" accordion wrapping the help markdown.
    details_accordion: gr.Accordion
    # Invisible textbox that receives the clicked example or submitted question.
    examples_hidden: gr.Textbox
    # Example questions; they write into `examples_hidden`.
    examples: gr.Examples
    # Row showing the example visualization images.
    image_examples: gr.Row
    # Textbox labelled "Direct Question" where the user types a question.
    ipcc_direct_question: gr.Textbox
    # Read-only textbox with elem_id "no-result-label".
    result_text: gr.Textbox
    # Radio labelled "Relevant figures created", created hidden and empty.
    table_names_display: gr.Radio
    # "SQL Query Used" accordion, created hidden.
    query_accordion: gr.Accordion
    # Read-only textbox that displays the SQL query.
    ipcc_sql_query: gr.Textbox
    # "Chart" accordion, created hidden.
    chart_accordion: gr.Accordion
    # Markdown shown under "Informations about the plot".
    plot_information: gr.Markdown
    # "Scenario" dropdown, created hidden.
    scenario_selection: gr.Dropdown
    # Plot component that displays the figure.
    ipcc_display: gr.Plot
    # "Data used" accordion, created hidden and closed.
    table_accordion: gr.Accordion
    # DataFrame component that displays the data.
    ipcc_table: gr.DataFrame
async def ask_ipcc_query(query: str, index_state: int, user_id: str):
    """Await ``ask_ipcc`` with the given arguments and return its result as-is.

    Args:
        query: Question text forwarded to ``ask_ipcc``.
        index_state: Current index value forwarded to ``ask_ipcc``.
        user_id: User identifier forwarded to ``ask_ipcc``.

    Returns:
        Whatever ``ask_ipcc`` returns, unchanged.
    """
    return await ask_ipcc(query, index_state, user_id)
def hide_outputs():
    """Return the visibility updates applied before a new query runs.

    Returns:
        A 5-tuple of Gradio updates: the first one (result text) is made
        visible, the remaining four (query accordion, table accordion,
        chart accordion, table names) are hidden.
    """
    updates = [gr.update(visible=True)]
    # One independent update object per hidden component.
    updates.extend(gr.update(visible=False) for _ in range(4))
    return tuple(updates)
def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
    """Choose which result widgets are visible after a query has run.

    Args:
        sql_queries_state: SQL queries produced by the query.
        dataframes_state: Dataframes produced by the query.
        plots_state: Plot builders produced by the query.
        table_names: Labels offered in the table-name radio.

    Returns:
        A 5-tuple of Gradio updates, in order: result text, query accordion,
        table accordion, chart accordion, table-name radio.
    """
    has_results = bool(sql_queries_state and dataframes_state and plots_state)
    if not has_results:
        # At least one of the three result lists is empty: keep only the
        # result text visible and hide every data widget.
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    # Select the first table by default; fall back to no selection instead of
    # raising IndexError when no table name is available.
    default_table = table_names[0] if table_names else None
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(choices=table_names, value=default_table, visible=True),
    )
def show_filter_by_scenario(table_names, index_state, dataframes):
    """Show the scenario dropdown only when the selected table can use it.

    The dropdown is shown when the selected table name starts with "Map" and
    its dataframe has a non-empty ``scenario`` column; otherwise it is hidden.

    Args:
        table_names: Labels of the available tables.
        index_state: Index of the currently selected table.
        dataframes: Dataframes aligned with ``table_names``.

    Returns:
        A Gradio update for the scenario dropdown.
    """
    if not table_names or not table_names[index_state].startswith("Map"):
        return gr.update(visible=False)

    df = dataframes[index_state]
    # Same guard as filter_by_scenario: without a "scenario" column there is
    # nothing to filter on, so do not offer the dropdown (avoids KeyError).
    if "scenario" not in df.columns:
        return gr.update(visible=False)

    scenarios = sorted(df["scenario"].unique())
    # An empty dataframe yields no choices; hide instead of indexing [0].
    if not scenarios:
        return gr.update(visible=False)

    return gr.update(visible=True, choices=scenarios, value=scenarios[0])
def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
    """Return the selected dataframe, filtered by scenario, with its figure.

    Args:
        dataframes: Dataframes, one per table.
        figures: Callables that build a figure from a dataframe, one per table.
        table_names: Table labels; only those starting with "Map" are filtered.
        index_state: Index of the selected table.
        scenario: Scenario value to keep in the ``scenario`` column.

    Returns:
        A ``(dataframe, figure)`` pair. The figure is ``None`` when a "Map"
        table has no rows, either before or after filtering.
    """
    frame = dataframes[index_state]
    is_map = table_names[index_state].startswith("Map")

    # A map with no data cannot be drawn.
    if is_map and frame.empty:
        return frame, None

    # Non-map tables, and maps without a scenario column, are plotted as-is.
    if not is_map or "scenario" not in frame.columns:
        return frame, figures[index_state](frame)

    selected = frame[frame["scenario"] == scenario]
    if selected.empty:
        return selected, None
    return selected, figures[index_state](selected)
def display_table_names(table_names, index_state):
    """Wrap each table name in its own single-element list.

    Args:
        table_names: Iterable of table labels.
        index_state: Not used by this function.

    Returns:
        A list of one-item lists, one per table name, in the same order.
    """
    return [[label] for label in table_names]
def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
    """Collect everything to display for the table the user selected.

    Args:
        selected_label: Label chosen by the user; must be in ``table_names``.
        table_names: Labels of the available tables.
        sql_queries: SQL queries aligned with ``table_names``.
        dataframes: Dataframes aligned with ``table_names``.
        plot_informations: Plot descriptions aligned with ``table_names``.
        plots: Callables building a figure from a dataframe, aligned likewise.

    Returns:
        ``(sql_query, dataframe, figure, plot_information, index)`` for the
        selected table.

    Raises:
        ValueError: If ``selected_label`` is not in ``table_names``.
    """
    position = table_names.index(selected_label)
    build_figure = plots[position]
    frame = dataframes[position]
    rendered = build_figure(frame)
    return sql_queries[position], frame, rendered, plot_informations[position], position
def create_ipcc_ui() -> ipccUIElements:
    """Create and return all UI elements for the ipcc tab.

    Builds the "(Beta) Talk to IPCC" tab: help accordion, example questions,
    direct-question textbox, example images, and the result widgets (result
    text, table-name radio, SQL query, chart, data table). The result
    accordions and the radio are created hidden.

    Returns:
        An ``ipccUIElements`` mapping holding every component that needs to
        be referenced later.
    """
    # NOTE(review): the nesting of the `with` blocks below was reconstructed
    # from a copy of this file whose indentation had been lost - confirm the
    # layout against the repository before relying on it.
    with gr.Tab("(Beta) Talk to IPCC", elem_id="tab-vanna", id=7) as tab:
        # Help text, inside an accordion.
        with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
            gr.Markdown(IPCC_UI_TEXT)
        # Add examples for common questions
        # The hidden textbox is both input and output of the examples, so a
        # clicked example ends up as its value.
        examples_hidden = gr.Textbox(visible=False, elem_id="ipcc-examples-hidden")
        examples = gr.Examples(
            examples=[
                ["What will the temperature be like in Paris?"],
                ["What will be the total rainfall in the USA in 2030?"],
                ["How will the average temperature evolve in China?"],
                ["What will be the average total precipitation in London ?"]
            ],
            label="Example Questions",
            inputs=[examples_hidden],
            outputs=[examples_hidden],
        )
        # Free-text question entered by the user.
        with gr.Row():
            ipcc_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )
        # Static gallery of example visualizations.
        with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
            gr.Markdown("### Examples of possible visualizations")
            with gr.Row():
                gr.Image("./front/assets/talk_to_ipcc_france_example.png", label="Total Precipitation in 2030 in France", elem_classes=["example-img"])
                gr.Image("./front/assets/talk_to_ipcc_new_york_example.png", label="Yearly Evolution of Mean Temperature in New York (Historical + SSP Scenarios)", elem_classes=["example-img"])
                gr.Image("./front/assets/talk_to_ipcc_china_example.png", label="Mean Temperature in 2050 in China", elem_classes=["example-img"])
        # Read-only text area for the result message.
        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )
        # Radio used to pick one of the returned tables; starts hidden and empty.
        with gr.Row():
            table_names_display = gr.Radio(
                choices=[],
                label="Relevant figures created",
                interactive=True,
                elem_id="table-names",
                visible=False
            )
        # SQL query display, hidden until results are shown.
        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            ipcc_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )
        # Chart area: scenario dropdown, plot description and the plot itself.
        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            with gr.Row():
                scenario_selection = gr.Dropdown(
                    label="Scenario", choices=IPCC_SCENARIO, value=IPCC_SCENARIO[0], interactive=True, visible=False
                )
            with gr.Accordion(label="Informations about the plot", open=False):
                plot_information = gr.Markdown(value = "")
            ipcc_display = gr.Plot(elem_id="vanna-plot")
        # Data table, in a hidden and closed accordion.
        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            ipcc_table = gr.DataFrame([], elem_id="vanna-table")
    # Bundle every component that needs to be referenced later.
    return ipccUIElements(
        tab=tab,
        details_accordion=details_accordion,
        examples_hidden=examples_hidden,
        examples=examples,
        image_examples=image_examples,
        ipcc_direct_question=ipcc_direct_question,
        result_text=result_text,
        table_names_display=table_names_display,
        query_accordion=query_accordion,
        ipcc_sql_query=ipcc_sql_query,
        chart_accordion=chart_accordion,
        plot_information=plot_information,
        scenario_selection=scenario_selection,
        ipcc_display=ipcc_display,
        table_accordion=table_accordion,
        ipcc_table=ipcc_table,
    )
def setup_ipcc_events(ui_elements: ipccUIElements, share_client=None, user_id=None) -> None:
    """Attach every event handler of the IPCC tab to its components.

    Args:
        ui_elements: Components of the tab, looked up by key.
        share_client: Accepted by the signature but not used in this function.
        user_id: Initial value of the ``gr.State`` passed to ``ask_ipcc_query``.
    """
    # State holders shared by the handlers below.
    sql_queries_state = gr.State([])
    dataframes_state = gr.State([])
    plots_state = gr.State([])
    plot_informations_state = gr.State([])
    index_state = gr.State(0)
    table_names_list = gr.State([])
    user_id = gr.State(user_id)

    question_box = ui_elements["ipcc_direct_question"]
    hidden_query = ui_elements["examples_hidden"]
    scenario_dropdown = ui_elements["scenario_selection"]
    table_radio = ui_elements["table_names_display"]

    # Component groups used by several registrations. Kept as tuples and
    # copied into a fresh list for each registration.
    visibility_targets = (
        ui_elements["result_text"],
        ui_elements["query_accordion"],
        ui_elements["table_accordion"],
        ui_elements["chart_accordion"],
        table_radio,
    )
    table_and_plot = (ui_elements["ipcc_table"], ui_elements["ipcc_display"])
    scenario_visibility_inputs = (table_names_list, index_state, dataframes_state)
    scenario_filter_inputs = (
        dataframes_state,
        plots_state,
        table_names_list,
        index_state,
        scenario_dropdown,
    )

    # A submitted direct question is copied into the hidden textbox, so it
    # goes through the same change-driven workflow as a clicked example.
    question_box.submit(
        lambda x: gr.update(value=x),
        inputs=[question_box],
        outputs=[hidden_query],
    )

    # Query workflow, triggered whenever the hidden textbox changes.
    # Step 1: close the help accordion and mirror the question in the textbox.
    workflow = hidden_query.change(
        lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
        inputs=[hidden_query],
        outputs=[ui_elements["details_accordion"], question_box],
    )
    # Step 2: hide the example images.
    workflow = workflow.then(
        lambda: gr.update(visible=False),
        inputs=None,
        outputs=ui_elements["image_examples"],
    )
    # Step 3: reset the visibility of the result widgets.
    workflow = workflow.then(
        hide_outputs,
        inputs=None,
        outputs=list(visibility_targets),
    )
    # Step 4: run the query and store its results.
    workflow = workflow.then(
        ask_ipcc_query,
        inputs=[hidden_query, index_state, user_id],
        outputs=[
            ui_elements["ipcc_sql_query"],
            ui_elements["ipcc_table"],
            ui_elements["ipcc_display"],
            ui_elements["plot_information"],
            sql_queries_state,
            dataframes_state,
            plots_state,
            plot_informations_state,
            index_state,
            table_names_list,
            ui_elements["result_text"],
        ],
    )
    # Step 5: reveal the result widgets, or only the result text.
    workflow = workflow.then(
        show_results,
        inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
        outputs=list(visibility_targets),
    )
    # Step 6: decide whether the scenario dropdown applies.
    workflow = workflow.then(
        show_filter_by_scenario,
        inputs=list(scenario_visibility_inputs),
        outputs=[scenario_dropdown],
    )
    # Step 7: refresh the table and plot for the current scenario.
    workflow.then(
        filter_by_scenario,
        inputs=list(scenario_filter_inputs),
        outputs=list(table_and_plot),
    )

    # Changing the scenario re-filters the table and redraws the plot.
    scenario_dropdown.change(
        filter_by_scenario,
        inputs=list(scenario_filter_inputs),
        outputs=list(table_and_plot),
    )

    # Selecting another table swaps the displayed query, data, plot and
    # description, then re-evaluates the scenario dropdown and filter.
    table_switch = table_radio.change(
        fn=on_table_click,
        inputs=[
            table_radio,
            table_names_list,
            sql_queries_state,
            dataframes_state,
            plot_informations_state,
            plots_state,
        ],
        outputs=[
            ui_elements["ipcc_sql_query"],
            ui_elements["ipcc_table"],
            ui_elements["ipcc_display"],
            ui_elements["plot_information"],
            index_state,
        ],
    )
    table_switch = table_switch.then(
        show_filter_by_scenario,
        inputs=list(scenario_visibility_inputs),
        outputs=[scenario_dropdown],
    )
    table_switch.then(
        filter_by_scenario,
        inputs=list(scenario_filter_inputs),
        outputs=list(table_and_plot),
    )
def create_ipcc_tab(share_client=None, user_id=None):
    """Build the ipcc tab: create its components, then attach their handlers.

    Args:
        share_client: Forwarded to ``setup_ipcc_events``.
        user_id: Forwarded to ``setup_ipcc_events``.
    """
    setup_ipcc_events(create_ipcc_ui(), share_client=share_client, user_id=user_id)
|