timeki's picture
make ask drias asynchronous
e57556f
import re
from typing import Annotated, TypedDict
import duckdb
from geopy.geocoders import Nominatim
import ast
from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.config import DRIAS_TABLES
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
from langchain_core.prompts import ChatPromptTemplate
async def detect_location_with_openai(sentence):
    """Return the first location the LLM finds in ``sentence``.

    The model is prompted to answer with a Python list literal of place
    names. The reply is unwrapped from an optional Markdown code fence and
    parsed with ``ast.literal_eval`` (never ``eval``).

    Args:
        sentence: Free-text user input to scan for locations.

    Returns:
        The first element of the parsed list, or ``""`` if the list is empty.

    Raises:
        ValueError, SyntaxError: If the reply is not a valid Python literal
            (raised by ``ast.literal_eval``).
    """
    llm = get_llm()
    prompt = f"""
Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
Return the result as a Python list. If no locations are mentioned, return an empty list.
Sentence: "{sentence}"
"""
    response = await llm.ainvoke(prompt)
    reply = response.content.strip()
    # Unwrap a ```<lang> ... ``` fence if present. str.strip("```python\n")
    # removes a *set of characters*, not the fence, so it breaks on other
    # language tags (e.g. ```json) or on whitespace before the fence.
    fenced = re.fullmatch(r"```[\w+-]*\s*(.*?)\s*```", reply, re.DOTALL)
    if fenced:
        reply = fenced.group(1)
    location_list = ast.literal_eval(reply)
    return location_list[0] if location_list else ""
# Structured-output schema used by detect_year_with_openai: the model fills
# ``array`` with the *text* of a Python list, which the caller then parses.
# NOTE(review): this class is handed to ``llm.with_structured_output``, so the
# docstring and the Annotated description below may be serialized into the
# schema the model sees — treat edits to them as prompt changes.
class ArrayOutput(TypedDict):
    """Represents the output of a function that returns an array.
    This class is used to type-hint functions that return arrays,
    ensuring consistent return types across the codebase.
    Attributes:
    array (str): A syntactically valid Python array string
    """
    # A string (not a list): callers must parse it themselves.
    array: Annotated[str, "Syntactically valid python array."]
async def detect_year_with_openai(sentence: str) -> str:
    """Return the first year the LLM finds in ``sentence``.

    Uses structured output (``ArrayOutput``) so the model returns a string
    holding a Python list literal, which is parsed with ``ast.literal_eval``.

    Args:
        sentence: Free-text user input to scan for years.

    Returns:
        The first element of the parsed list, or ``""`` when the list is
        empty. NOTE: the element is whatever literal the model produced, so
        it may be an ``int`` (e.g. ``2050``) despite the ``str`` annotation.

    Raises:
        ValueError, SyntaxError: If ``array`` is not a valid Python literal.
    """
    llm = get_llm()
    prompt = """
Extract all years mentioned in the following sentence.
Return the result as a Python list. If no year are mentioned, return an empty list.
Sentence: "{sentence}"
"""
    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    # The array text is model output influenced by user input, so it must
    # never reach eval(); literal_eval accepts only Python literals.
    years_list = ast.literal_eval(response["array"])
    if len(years_list) > 0:
        return years_list[0]
    return ""
def detectTable(sql_query: str) -> list[str]:
    """Return every table reference that follows a ``FROM`` keyword.

    Matching is case-insensitive. A reference may be a bare identifier or a
    backtick/double/single-quoted name, optionally dotted
    (``schema.table``). Quotes are kept in the returned text. Tables that
    appear only after ``JOIN`` are not reported.

    Args:
        sql_query: SQL text to scan.

    Returns:
        The references in the order they occur; empty if there is none.

    Example:
        >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
        ['temperature_data']
    """
    from_clause = re.compile(
        r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
    )
    return [hit.group(1) for hit in from_clause.finditer(sql_query)]
def loc2coords(location: str) -> tuple[float, float]:
    """Geocode a place name to ``(latitude, longitude)`` with Nominatim.

    Args:
        location: Place name to look up, e.g. a city.

    Returns:
        ``(latitude, longitude)`` of the result Nominatim returns.

    Raises:
        AttributeError: If the lookup yields no result (the geocoder then
            returns ``None``, which has no ``latitude``).
    """
    match = Nominatim(user_agent="city_to_latlong").geocode(location)
    latitude, longitude = match.latitude, match.longitude
    return latitude, longitude
def coords2loc(coords: tuple[float, float]) -> str:
    """Reverse-geocode ``coords`` into an address string with Nominatim.

    Args:
        coords: ``(latitude, longitude)`` pair.

    Returns:
        The ``address`` of the Nominatim result, or ``"Unknown Location"`` if
        anything in the lookup raises. Presumably this includes an empty
        lookup, since a ``None`` result fails on ``.address``. The error is
        printed to stdout rather than propagated.
    """
    reverse_geocoder = Nominatim(user_agent="coords_to_city")
    try:
        address = reverse_geocoder.reverse(coords).address
    except Exception as err:
        print(f"Error: {err}")
        address = "Unknown Location"
    return address
def nearestNeighbourSQL(location: tuple, table: str) -> "tuple[float, float] | tuple[str, str]":
    """Find the grid point in ``table`` closest to ``location``.

    Only points inside a +/-0.3 degree box around the location (rounded to 3
    decimals) are considered. Among those, the point with the smallest
    squared latitude/longitude difference (in degrees) is returned.

    Args:
        location: ``(latitude, longitude)`` pair, e.g. from ``loc2coords``.
        table: Table name. It is lower-cased and read from
            ``hf://datasets/timeki/drias_db/<table>.parquet``.

    Returns:
        ``(latitude, longitude)`` of the nearest grid point, or ``("", "")``
        when no grid point lies inside the search box.
    """
    lat = round(location[0], 3)
    lon = round(location[1], 3)
    # The name is embedded in a SQL string literal. It may originate from LLM
    # output (see detect_relevant_tables), so double any single quotes to
    # keep it from breaking out of the literal.
    parquet_path = f"hf://datasets/timeki/drias_db/{table.lower()}.parquet".replace("'", "''")
    row = duckdb.sql(
        f"SELECT latitude, longitude FROM '{parquet_path}' "
        f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
        f"AND longitude BETWEEN {lon - 0.3} AND {lon + 0.3} "
        # Rank by distance; the spaces around '-' keep a negative coordinate
        # from forming a '--' SQL comment.
        f"ORDER BY (latitude - {lat}) * (latitude - {lat}) "
        f"+ (longitude - {lon}) * (longitude - {lon}) "
        f"LIMIT 1"
    ).fetchone()
    if row is None:
        return "", ""
    return row[0], row[1]
async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
    """Ask ``llm`` which DRIAS tables suit the requested plot.

    The prompt lists ``DRIAS_TABLES``, the plot's ``description`` and the
    user question, and asks for the 3 most relevant tables as a Python list.
    The reply is unwrapped from an optional Markdown code fence and parsed
    with ``ast.literal_eval``.

    Args:
        user_question: The user's question about climate data.
        plot: Plot configuration; only ``plot['description']`` is read.
        llm: Chat model exposing an async ``ainvoke`` whose result has a
            ``.content`` string.

    Returns:
        The literal parsed from the reply, expected to be a list of table
        names.

    Raises:
        ValueError, SyntaxError: If the reply is not a valid Python literal.

    Example (illustrative output):
        >>> await detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm,
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    table_names_list = DRIAS_TABLES
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )
    reply = (await llm.ainvoke(prompt)).content.strip()
    # Unwrap a ```<lang> ... ``` fence if present. str.strip("```python\n")
    # removes a *set of characters*, not the fence, so it breaks on other
    # language tags (e.g. ```json) or on whitespace before the fence.
    fenced = re.fullmatch(r"```[\w+-]*\s*(.*?)\s*```", reply, re.DOTALL)
    if fenced:
        reply = fenced.group(1)
    table_names = ast.literal_eval(reply)
    return table_names
def replace_coordonates(coords, query, coords_tables):
    """Substitute the user's coordinates in ``query`` with per-table values.

    Let ``n`` be the number of times the latitude ``coords[0]`` occurs in
    ``query``. The i-th latitude occurrence (i < n) becomes
    ``coords_tables[i][0]``, and the i-th longitude occurrence (i < n)
    becomes ``coords_tables[i][1]``. Further longitude occurrences are left
    unchanged.

    Coordinates are matched only as whole numbers ("2.35" is not found
    inside "42.35" or "2.355"). Each axis is rewritten in a single pass, so
    inserted values are never themselves re-matched.

    Args:
        coords: ``(latitude, longitude)`` as written into the query.
        query: SQL text containing those coordinates.
        coords_tables: Sequence of ``(latitude, longitude)`` replacements,
            one per latitude occurrence.

    Returns:
        The rewritten query.

    Raises:
        IndexError: If ``coords_tables`` has fewer than ``n`` entries.
    """
    def _whole_number(value):
        # Not preceded by a digit or '.', and not followed by a digit or a
        # '.'+digit, so that only complete numbers match.
        return re.compile(r"(?<![\d.])" + re.escape(str(value)) + r"(?!\.?\d)")

    lat_pattern = _whole_number(coords[0])
    n_pairs = len(lat_pattern.findall(query))
    if n_pairs == 0:
        # Nothing to replace. Also, re.sub(count=0) would mean "unlimited".
        return query

    for axis, pattern in ((0, lat_pattern), (1, _whole_number(coords[1]))):
        positions = iter(range(n_pairs))
        query = pattern.sub(
            lambda _match, axis=axis, positions=positions: str(
                coords_tables[next(positions)][axis]
            ),
            query,
            count=n_pairs,
        )
    return query
async def detect_relevant_plots(user_question: str, llm):
    """Ask ``llm`` which entries of ``PLOTS`` best answer ``user_question``.

    Each plot's ``name`` and ``description`` are listed in the prompt, and
    the model is asked for a Python list of plot names. The reply is
    unwrapped from an optional Markdown code fence and parsed with
    ``ast.literal_eval``.

    Args:
        user_question: The user's natural-language question.
        llm: Chat model exposing an async ``ainvoke`` whose result has a
            ``.content`` string.

    Returns:
        The literal parsed from the reply, expected to be a list of plot
        names.

    Raises:
        ValueError, SyntaxError: If the reply is not a valid Python literal.
    """
    plots_description = "".join(
        f"Name: {plot['name']} - Description: {plot['description']}\n"
        for plot in PLOTS
    )
    # The instruction previously asked for "tables" (copied from
    # detect_relevant_tables); this function selects plots.
    prompt = (
        "You are helping to answer a question with insightful visualizations. "
        "You are given a user question and a list of plots with their name and description. "
        "Based on the descriptions of the plots, which plot is appropriate to answer to this question? "
        "Write the most relevant plots to use. Answer only a python list of plot names.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        "### Name of the plot : "
    )
    reply = (await llm.ainvoke(prompt)).content.strip()
    # Unwrap a ```<lang> ... ``` fence if present. str.strip("```python\n")
    # removes a *set of characters*, not the fence, so it breaks on other
    # language tags (e.g. ```json) or on whitespace before the fence.
    fenced = re.fullmatch(r"```[\w+-]*\s*(.*?)\s*```", reply, re.DOTALL)
    if fenced:
        reply = fenced.group(1)
    plot_names = ast.literal_eval(reply)
    return plot_names
# Next Version
# class QueryOutput(TypedDict):
# """Generated SQL query."""
# query: Annotated[str, ..., "Syntactically valid SQL query."]
# class PlotlyCodeOutput(TypedDict):
# """Generated Plotly code"""
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
# def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
# """Generate SQL query to fetch information."""
# prompt_params = {
# "dialect": db.dialect,
# "table_info": db.get_table_info(),
# "input": user_input,
# "relevant_tables": relevant_tables,
# "model": "ALADIN63_CNRM-CM5",
# }
# prompt = ChatPromptTemplate.from_template(query_prompt_template)
# structured_llm = llm.with_structured_output(QueryOutput)
# chain = prompt | structured_llm
# result = chain.invoke(prompt_params)
# return result["query"]
# def fetch_data_from_sql_query(db: str, sql_query: str):
# conn = sqlite3.connect(db)
# cursor = conn.cursor()
# cursor.execute(sql_query)
# column_names = [desc[0] for desc in cursor.description]
# values = cursor.fetchall()
# return {"column_names": column_names, "data": values}
# def generate_chart_code(user_input: str, sql_query: list[str], llm):
# """ "Generate plotly python code for the chart based on the sql query and the user question"""
# class PlotlyCodeOutput(TypedDict):
# """Generated Plotly code"""
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
# prompt = ChatPromptTemplate.from_template(plot_prompt_template)
# structured_llm = llm.with_structured_output(PlotlyCodeOutput)
# chain = prompt | structured_llm
# result = chain.invoke({"input": user_input, "sql_query": sql_query})
# return result["code"]