# Provenance (Hugging Face file-viewer header): uploaded by timeki,
# commit "talk_to_ipcc (#29)", revision 711bc31 (verified).
import os
import geojson
from math import cos, radians
from typing import Callable
import pandas as pd
from plotly.graph_objects import Figure
import plotly.graph_objects as go
from climateqa.engine.talk_to_data.drias.plot_informations import distribution_of_indicator_for_given_year_informations, indicator_evolution_informations, indicator_number_of_days_per_year_informations, map_of_france_of_indicator_for_given_year_informations
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.drias.queries import (
indicator_for_given_year_query,
indicator_per_year_at_location_query,
)
from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_COLORSCALE, DRIAS_INDICATOR_TO_UNIT
def generate_geojson_polygons(
    latitudes: list[float],
    longitudes: list[float],
    indicators: list[float],
    *,
    side_km: float = 8,
) -> geojson.FeatureCollection:
    """Build one square polygon per grid point, centred on (lat, lon).

    Each point becomes an approximately ``side_km`` x ``side_km`` square.
    Degrees are converted with the ~111 km/degree approximation. The longitude
    extent is widened by ``1 / cos(lat)`` because meridians converge towards
    the poles.

    Args:
        latitudes: Latitude of each cell centre, in degrees.
        longitudes: Longitude of each cell centre, in degrees.
        indicators: Value stored in each feature's ``"value"`` property.
        side_km: Side length of each square cell in kilometres (default 8).

    Returns:
        A FeatureCollection with one Polygon feature per input point. Feature
        ids are the stringified positional index ("0", "1", ...), so callers
        can match them with ``featureidkey="id"``. Inputs of unequal length
        are truncated to the shortest, as with ``zip``.
    """
    km_per_degree = 111  # approximate length of one degree of latitude
    # The latitude half-extent does not depend on the point, so compute it once.
    half_lat = side_km / km_per_degree / 2
    features = []
    for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
        # A degree of longitude shrinks with cos(latitude).
        half_lon = side_km / (km_per_degree * cos(radians(lat))) / 2
        features.append(geojson.Feature(
            geometry=geojson.Polygon([[
                [lon - half_lon, lat - half_lat],
                [lon + half_lon, lat - half_lat],
                [lon + half_lon, lat + half_lat],
                [lon - half_lon, lat + half_lat],
                [lon - half_lon, lat - half_lat],  # close the ring
            ]]),
            properties={"value": val},
            id=str(idx),
        ))
    return geojson.FeatureCollection(features)
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
    """Generates a function to plot indicator evolution over time at a location.

    The returned callback draws the yearly values of a climate indicator as a
    line. When at least one full window is available, it also draws a 10-year
    rolling average as a dashed line. With several models in the data, values
    are first averaged across models year by year.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location to plot (used in the title)
            - model (str): The climate model to use. Not read here;
              presumably consumed by the SQL query. TODO confirm.

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame with
        ``model``, ``year`` and ``<indicator_column>`` columns and returns a
        plotly Figure.

    Example:
        >>> plot_func = plot_indicator_evolution_at_location({
        ...     'indicator_column': 'mean_temperature',
        ...     'location': 'Paris',
        ...     'model': 'ALL'
        ... })
        >>> fig = plot_func(df)  # doctest: +SKIP
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
    rolling_window = 10

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generates the actual plot from the data.

        Args:
            df (pd.DataFrame): DataFrame containing the data to plot

        Returns:
            Figure: A plotly Figure object showing the indicator evolution
        """
        fig = go.Figure()
        if df['model'].nunique() != 1:
            # Several models: average them per year (groupby sorts by year).
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            # Single model: the rolling window is positional, so rows must be
            # in chronological order regardless of how the query returned them.
            df_plot = df.sort_values("year", kind="stable")
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()
        years = df_plot["year"].astype(int).tolist()

        # Compute the 10-year rolling average (NaN until a full window exists)
        sliding_averages = (
            df_plot[indicator]
            .rolling(window=rolling_window, min_periods=rolling_window)
            .mean()
            .astype(float)
            .tolist()
        )

        # Only add rolling average if we have enough data points
        if any(pd.notna(x) for x in sliding_averages):
            # Sliding average dashed line
            fig.add_scatter(
                x=years,
                y=sliding_averages,
                mode="lines",
                name="10 years rolling average",
                line=dict(dash="dash"),
                marker=dict(color="#d62728"),
                hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
            )

        # Indicator per year plot
        fig.add_scatter(
            x=years,
            y=indicators,
            name=f"Yearly {indicator_label}",
            mode="lines",
            marker=dict(color="#1f77b4"),
            hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
        )
        fig.update_layout(
            title=f"Evolution of {indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            template="plotly_white",
            height=900,
        )
        return fig

    return plot_data
# Plot spec: yearly evolution (with rolling average) of an indicator at one location.
indicator_evolution_at_location: Plot = {
    "name": "Indicator evolution at location",
    "description": "Plot an evolution of the indicator at a certain location",
    "params": ["indicator_column", "location", "model"],
    "plot_function": plot_indicator_evolution_at_location,
    "sql_query": indicator_per_year_at_location_query,
    "plot_information": indicator_evolution_informations,
    "short_name": "Evolution",
}
def plot_indicator_number_of_days_per_year_at_location(
    params: dict,
) -> Callable[..., Figure]:
    """Generates a function to plot the number of days per year for an indicator.

    This function creates a bar chart showing the yearly value of a frequency
    indicator (e.g. days above a threshold) at a specific location. With
    several models in the data, values are averaged across models per year.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location to plot (used in the title)
            - model (str): The climate model to use. Not read here;
              presumably consumed by the SQL query. TODO confirm.

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame with
        ``model``, ``year`` and ``<indicator_column>`` columns and returns a
        plotly Figure.
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the figure thanks to the dataframe

        Args:
            df (pd.DataFrame): pandas dataframe with the required data

        Returns:
            Figure: Plotly figure
        """
        fig = go.Figure()
        if df['model'].nunique() != 1:
            # Several models: average them year by year.
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            df_plot = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()
        years = df_plot["year"].astype(int).tolist()

        # max() raises on an empty list and gives an order-dependent result
        # when NaNs are present, so only consider real values. Without a
        # positive maximum (no data / all zeros), let plotly autorange from 0
        # instead of forcing a degenerate [0, 0] range.
        finite_values = [value for value in indicators if pd.notna(value)]
        y_max = max(finite_values, default=0.0)
        yaxis = dict(range=[0, y_max]) if y_max > 0 else dict(rangemode="tozero")

        # Bar plot
        fig.add_trace(
            go.Bar(
                x=years,
                y=indicators,
                width=0.5,
                marker=dict(color="#1f77b4"),
                hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
            )
        )
        fig.update_layout(
            title=f"{indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            yaxis=yaxis,
            bargap=0.5,
            height=900,
            template="plotly_white",
        )
        return fig

    return plot_data
# Plot spec: bar chart of a frequency indicator (days per year) at one location.
indicator_number_of_days_per_year_at_location: Plot = {
    "name": "Indicator number of days per year at location",
    "description": (
        "Plot a barchart of the number of days per year of a certain indicator "
        "at a certain location. It is appropriate for frequency indicator."
    ),
    "params": ["indicator_column", "location", "model"],
    "plot_function": plot_indicator_number_of_days_per_year_at_location,
    "sql_query": indicator_per_year_at_location_query,
    "plot_information": indicator_number_of_days_per_year_informations,
    "short_name": "Yearly Frequency",
}
def plot_distribution_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Return a callback that draws a histogram of an indicator for one year.

    The histogram shows how the indicator's values are spread across grid
    points (percent of points per bin). When the data holds several models,
    values are first averaged per (latitude, longitude) point.

    Args:
        params (dict): Must contain:
            - indicator_column (str): Column holding the indicator values.
            - year: Year shown in the title; ``None`` falls back to 2030.
            - model (str): Climate model selection. Not read here;
              presumably consumed by the SQL query. TODO confirm.

    Returns:
        Callable[..., Figure]: Takes a DataFrame with ``model``, ``latitude``,
        ``longitude`` and ``<indicator_column>`` columns; returns a plotly Figure.
    """
    indicator = params["indicator_column"]
    year = params["year"]
    year = 2030 if year is None else year
    indicator_label = " ".join(part.capitalize() for part in indicator.split("_"))
    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Draw the histogram for ``df``.

        Args:
            df (pd.DataFrame): Query result for the requested year.

        Returns:
            Figure: Plotly histogram figure.
        """
        fig = go.Figure()
        if df["model"].nunique() == 1:
            source = df
            model_label = f"Model : {df['model'].unique()[0]}"
        else:
            # Collapse the models into one mean value per grid point.
            source = df.groupby(["latitude", "longitude"], as_index=False)[indicator].mean()
            model_label = "Model Average"

        # Plain Python floats, so plotly does not have to encode pandas objects.
        values = source[indicator].astype(float).tolist()

        histogram = go.Histogram(
            x=values,
            opacity=0.8,
            histnorm="percent",
            marker=dict(color="#1f77b4"),
            hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
        )
        fig.add_trace(histogram)
        fig.update_layout(
            title=f"Distribution of {indicator_label} in {year} ({model_label})",
            xaxis_title=f"{indicator_label} ({unit})",
            yaxis_title="Frequency (%)",
            plot_bgcolor="rgba(0, 0, 0, 0)",
            showlegend=False,
            height=900,
        )
        return fig

    return plot_data
# Plot spec: histogram of an indicator's values across grid points for one year.
distribution_of_indicator_for_given_year: Plot = {
    "name": "Distribution of an indicator for a given year",
    "description": "Plot an histogram of the distribution for a given year of the values of an indicator",
    "params": ["indicator_column", "model", "year"],
    "plot_function": plot_distribution_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
    "plot_information": distribution_of_indicator_for_given_year_informations,
    "short_name": "Distribution",
}
def plot_map_of_france_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Generates a function to plot a map of France for an indicator.

    This function creates a choropleth map of France showing the spatial
    distribution of a climate indicator for a specific year. Each grid point
    is drawn as a square polygon (see ``generate_geojson_polygons``). With
    several models in the data, values are averaged per (latitude, longitude).

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year: The year shown in the title; ``None`` falls back to 2030
            - model (str): The climate model to use. Not read here;
              presumably consumed by the SQL query. TODO confirm.

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame with
        ``model``, ``latitude``, ``longitude`` and ``<indicator_column>``
        columns and returns a plotly Figure.
    """
    indicator = params["indicator_column"]
    year = params["year"]
    if year is None:
        year = 2030
    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
    # Fall back to plotly's default colorscale (None = unset) for indicators
    # without a configured one, consistent with the unit lookup above.
    colorscale = DRIAS_INDICATOR_TO_COLORSCALE.get(indicator)

    def plot_data(df: pd.DataFrame) -> Figure:
        """Build the choropleth map from the query result ``df``."""
        if df['model'].nunique() != 1:
            # Several models: one mean value per grid point.
            df_plot = df.groupby(["latitude", "longitude"], as_index=False)[
                indicator
            ].mean()
            model_label = "Model Average"
        else:
            df_plot = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()
        latitudes = df_plot["latitude"].astype(float).tolist()
        longitudes = df_plot["longitude"].astype(float).tolist()

        # min()/max() raise on empty input and can return NaN when NaNs are
        # present; use real values only and let plotly autoscale otherwise.
        finite_values = [value for value in indicators if pd.notna(value)]
        zmin = min(finite_values) if finite_values else None
        zmax = max(finite_values) if finite_values else None

        geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
        fig = go.Figure(go.Choroplethmapbox(
            geojson=geojson_data,
            # Feature ids are the positional indices used by generate_geojson_polygons.
            locations=[str(i) for i in range(len(indicators))],
            featureidkey="id",
            z=indicators,
            colorscale=colorscale,
            zmin=zmin,
            zmax=zmax,
            marker_opacity=0.7,
            marker_line_width=0,
            colorbar_title=f"{indicator_label} ({unit})",
            text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # Add hover text showing the indicator value
            hoverinfo="text"
        ))
        fig.update_layout(
            mapbox_style="open-street-map",  # Use OpenStreetMap
            mapbox_zoom=5,
            height=900,
            mapbox_center={"lat": 46.6, "lon": 2.0},
            coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),  # Add legend
            title=f"{indicator_label} in {year} in France ({model_label}) "  # Title
        )
        return fig

    return plot_data
# Plot spec: choropleth map of France for an indicator in a given year.
map_of_france_of_indicator_for_given_year: Plot = {
    "name": "Map of France of an indicator for a given year",
    "description": "Heatmap on the map of France of the values of an indicator for a given year",
    "params": ["indicator_column", "year", "model"],
    "plot_function": plot_map_of_france_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
    "plot_information": map_of_france_of_indicator_for_given_year_informations,
    "short_name": "Map of France",
}

# Registry of every DRIAS plot spec defined in this module, in display order.
DRIAS_PLOTS = [
    indicator_evolution_at_location,
    indicator_number_of_days_per_year_at_location,
    distribution_of_indicator_for_given_year,
    map_of_france_of_indicator_for_given_year,
]