|
import re |
|
from typing import Annotated, TypedDict |
|
import duckdb |
|
from geopy.geocoders import Nominatim |
|
import ast |
|
from climateqa.engine.llm import get_llm |
|
from climateqa.engine.talk_to_data.config import DRIAS_TABLES |
|
from climateqa.engine.talk_to_data.plot import PLOTS, Plot |
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
|
|
|
async def detect_location_with_openai(sentence):
    """Return the first location a language model finds in ``sentence``.

    The model returned by ``get_llm()`` is asked to list every location
    (cities, countries, states or geographical areas) mentioned in the
    sentence as a Python list; only the first entry of that list is returned.

    Args:
        sentence: The text to scan for location names.

    Returns:
        The first element of the parsed list, or ``""`` when the parsed
        value is empty.

    Raises:
        ValueError: If the reply cannot be parsed as a Python literal.
        SyntaxError: If the reply cannot be parsed as a Python literal.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    raw_answer = response.content

    try:
        # str.strip() removes a *set of characters* from both ends, not the
        # literal fence; this is enough for a bare or ```python-fenced list.
        location_list = ast.literal_eval(raw_answer.strip("```python\n").strip())
    except (ValueError, SyntaxError):
        # Fallback for replies fenced with another tag or wrapped in prose:
        # parse only the outermost bracketed span.
        bracketed = re.search(r"\[.*\]", raw_answer, re.DOTALL)
        if bracketed is None:
            raise
        location_list = ast.literal_eval(bracketed.group(0))

    if not location_list:
        return ""
    return location_list[0]
|
|
|
class ArrayOutput(TypedDict):
    """Typed dictionary carrying a Python array written out as text.

    Used as the schema passed to ``with_structured_output`` so that the
    model's reply arrives as a dictionary with a single ``array`` key.

    Attributes:
        array (str): Source text of a syntactically valid Python array.
    """

    array: Annotated[str, "Syntactically valid python array."]
|
|
|
async def detect_year_with_openai(sentence: str) -> str:
    """Return the first year a language model finds in ``sentence``.

    The prompt is rendered through ``ChatPromptTemplate`` and piped into the
    model returned by ``get_llm()``, configured with
    ``with_structured_output(ArrayOutput)`` so the reply exposes the years as
    the text of a Python list under the ``array`` key.

    Args:
        sentence: The text to scan for years.

    Returns:
        The first element of the parsed list, or ``""`` when the list is
        empty.
        NOTE(review): the element is returned exactly as parsed, so it may be
        an ``int`` rather than a ``str`` - confirm against callers.

    Raises:
        ValueError: If ``array`` is not a plain Python literal.
        SyntaxError: If ``array`` is not syntactically valid.
    """
    llm = get_llm()

    template = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(template)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})

    # SECURITY: the array text is model output derived from user input, so it
    # is parsed as a literal instead of being executed with eval().
    years_list = ast.literal_eval(response["array"])

    if not years_list:
        return ""
    return years_list[0]
|
|
|
|
|
def detectTable(sql_query: str) -> list[str]:
    """Collect the table references that follow ``FROM`` in a SQL query.

    The keyword is matched case-insensitively. A reference may be a bare
    word or a back-tick, double-quote or single-quote delimited name, and
    may be dot-qualified (``schema.table``). Quoted names are returned with
    their quotes. Tables introduced by other keywords (for example ``JOIN``)
    are not reported.

    Args:
        sql_query (str): The SQL query to analyze.

    Returns:
        list[str]: Table references in order of appearance; empty when the
        query has no matching ``FROM`` clause.

    Example:
        >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
        ['temperature_data']
    """
    from_clause = re.compile(
        r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
    )
    # Group 1 is the full (possibly dot-qualified) table reference.
    return [found.group(1) for found in from_clause.finditer(sql_query)]
|
|
|
|
|
def loc2coords(location: str) -> tuple[float, float]:
    """Geocode a place name into a ``(latitude, longitude)`` pair.

    A ``Nominatim`` geocoder (user agent ``"city_to_latlong"``) is queried
    with the given name and the coordinates of its result are returned.

    Args:
        location (str): The name of the place to geocode (e.g. a city).

    Returns:
        tuple[float, float]: ``(latitude, longitude)`` of the geocoding
        result.

    Raises:
        AttributeError: If the geocoder yields no result, since the
            coordinates are then read from ``None``.
    """
    geocoder = Nominatim(user_agent="city_to_latlong")
    match = geocoder.geocode(location)
    latitude = match.latitude
    longitude = match.longitude
    return (latitude, longitude)
|
|
|
|
|
def coords2loc(coords: tuple[float, float]) -> str:
    """Reverse-geocode a coordinate pair into an address string.

    A ``Nominatim`` geocoder (user agent ``"coords_to_city"``) is asked to
    reverse the coordinates. Any failure during the lookup, including a
    lookup that yields no result, is printed and turned into a placeholder
    instead of being raised.

    Args:
        coords (tuple[float, float]): ``(latitude, longitude)`` to look up.

    Returns:
        str: The ``address`` of the reverse-geocoding result, or
        ``"Unknown Location"`` when the lookup fails.

    Example (illustrative; the exact text depends on the geocoding service):
        >>> coords2loc((48.8566, 2.3522))
        'Paris, France'
    """
    reverse_geocoder = Nominatim(user_agent="coords_to_city")
    try:
        # Both the service call and the attribute access stay inside the
        # try so that a None result also falls back to the placeholder.
        return reverse_geocoder.reverse(coords).address
    except Exception as error:
        print(f"Error: {error}")
        return "Unknown Location"
|
|
|
|
|
def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
    """Find the row of a parquet table whose coordinates are closest to ``location``.

    The table is read through DuckDB from
    ``hf://datasets/timeki/drias_db/<table>.parquet`` (table name lower-cased).
    Candidates are limited to rows whose ``latitude`` and ``longitude`` lie
    within 0.3 of the requested point (rounded to 3 decimals). Among those,
    the row with the smallest squared coordinate difference is returned, so
    the result really is the nearest neighbour rather than whichever row in
    the box the scan happens to yield first.

    Args:
        location (tuple): ``(latitude, longitude)`` of the point of interest.
        table (str): Name of the table (parquet file stem) to search.
            NOTE(review): the name is interpolated into the SQL text as-is;
            confirm callers only pass trusted table names.

    Returns:
        tuple: ``(latitude, longitude)`` of the closest row as stored in the
        table, or ``("", "")`` when no row lies inside the search box.
        NOTE(review): the signature advertises ``tuple[str, str]`` but the
        values come straight from the dataframe - presumably floats.
    """
    lat = round(location[0], 3)
    long = round(location[1], 3)

    table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"

    # The distance is computed on raw coordinate differences (no correction
    # for longitude convergence), which is adequate inside a 0.3-wide box.
    # The literals are parenthesised so a negative value never forms "--",
    # which SQL would read as a comment.
    results = duckdb.sql(
        f"SELECT latitude, longitude FROM {table} "
        f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
        f"AND longitude BETWEEN {long - 0.3} AND {long + 0.3} "
        f"ORDER BY (latitude - ({lat})) * (latitude - ({lat})) "
        f"+ (longitude - ({long})) * (longitude - ({long})) "
        f"LIMIT 1"
    ).fetchdf()

    if len(results) == 0:
        return "", ""

    return results['latitude'].iloc[0], results['longitude'].iloc[0]
|
|
|
|
|
async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
    """Ask a language model which tables suit a plot and a user question.

    The prompt contains the plot's ``description``, the full ``DRIAS_TABLES``
    list and the user question, and asks for the 3 most relevant tables as a
    Python list. The reply is parsed with ``ast.literal_eval``.

    Args:
        user_question (str): The user's question.
        plot (Plot): The plot configuration; only ``plot['description']`` is
            read here.
        llm: Object exposing an async ``ainvoke(prompt)`` whose result has a
            ``content`` string.

    Returns:
        list[str]: The value parsed from the model's reply - expected to be a
        list of table names.

    Raises:
        ValueError: If the reply cannot be parsed as a Python literal.
        SyntaxError: If the reply cannot be parsed as a Python literal.

    Example (illustrative output):
        >>> await detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    table_names_list = DRIAS_TABLES

    # Each fragment ends with a newline: implicit string concatenation adds
    # no separator, which previously fused sentences and "###" sections.
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}.\n"
        "You are given a list of tables and a user question.\n"
        "Based on the description of the plot, which table are appropriate for that kind of plot.\n"
        "Write the 3 most relevant tables to use. Answer only a python list of table name.\n"
        f"### List of tables : {table_names_list}\n"
        f"### User question : {user_question}\n"
        "### List of table name : "
    )

    raw_answer = (await llm.ainvoke(prompt)).content

    try:
        # str.strip() removes a *set of characters* from both ends, not the
        # literal fence; this is enough for a bare or ```python-fenced list.
        table_names = ast.literal_eval(raw_answer.strip("```python\n").strip())
    except (ValueError, SyntaxError):
        # Fallback for replies fenced with another tag or wrapped in prose:
        # parse only the outermost bracketed span.
        bracketed = re.search(r"\[.*\]", raw_answer, re.DOTALL)
        if bracketed is None:
            raise
        table_names = ast.literal_eval(bracketed.group(0))

    return table_names
|
|
|
|
|
def replace_coordonates(coords, query, coords_tables):
    """Substitute the coordinates written in ``query`` with per-table ones.

    The text of ``coords[0]`` is counted in ``query``. For each occurrence
    ``i``, the first remaining occurrence of ``coords[0]`` is replaced with
    ``coords_tables[i][0]`` and the first remaining occurrence of
    ``coords[1]`` with ``coords_tables[i][1]``. Matching is purely textual,
    on the ``str()`` form of each value.

    Args:
        coords: Pair ``(first, second)`` of values currently present in the
            query.
        query: The query text to rewrite.
        coords_tables: Sequence of replacement pairs, indexed in the order
            the occurrences are processed.

    Returns:
        The query with the substitutions applied; unchanged when
        ``coords[0]`` does not appear in it.

    Raises:
        IndexError: If ``coords_tables`` has fewer entries than there are
            occurrences of ``coords[0]`` in ``query``.
    """
    first_token = str(coords[0])
    occurrences = query.count(first_token)

    for position in range(occurrences):
        replacement = coords_tables[position]
        query = query.replace(first_token, str(replacement[0]), 1)
        # coords[1] is only read when there is something to replace.
        second_token = str(coords[1])
        query = query.replace(second_token, str(replacement[1]), 1)

    return query
|
|
|
|
|
async def detect_relevant_plots(user_question: str, llm):
    """Ask a language model which plots can answer a user question.

    Every entry of ``PLOTS`` is listed in the prompt by its ``name`` and
    ``description``; the model is asked to answer with a Python list of plot
    names, which is parsed with ``ast.literal_eval``.

    Args:
        user_question (str): The user's question.
        llm: Object exposing an async ``ainvoke(prompt)`` whose result has a
            ``content`` string.

    Returns:
        The value parsed from the model's reply - expected to be a list of
        plot names.

    Raises:
        ValueError: If the reply cannot be parsed as a Python literal.
        SyntaxError: If the reply cannot be parsed as a Python literal.
    """
    plots_description = "".join(
        f"Name: {plot['name']} - Description: {plot['description']}\n"
        for plot in PLOTS
    )

    # Each fragment ends with a newline: implicit string concatenation adds
    # no separator, which previously fused sentences and "###" sections.
    # The instruction now asks for plots; it used to ask for tables.
    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given an user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, which plot is appropriate to answer to this question.\n"
        "Write the most relevant plots to use. Answer only a python list of plot name.\n"
        f"### Descriptions of the plots : {plots_description}\n"
        f"### User question : {user_question}\n"
        "### Name of the plot : "
    )

    raw_answer = (await llm.ainvoke(prompt)).content

    try:
        # str.strip() removes a *set of characters* from both ends, not the
        # literal fence; this is enough for a bare or ```python-fenced list.
        plot_names = ast.literal_eval(raw_answer.strip("```python\n").strip())
    except (ValueError, SyntaxError):
        # Fallback for replies fenced with another tag or wrapped in prose:
        # parse only the outermost bracketed span.
        bracketed = re.search(r"\[.*\]", raw_answer, re.DOTALL)
        if bracketed is None:
            raise
        plot_names = ast.literal_eval(bracketed.group(0))

    return plot_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|