timeki's picture
switch_vectorstore_to_azure_ai_search (#30)
ac49be7 verified
from typing import Any, Literal, Optional, cast
import ast
from langchain_core.prompts import ChatPromptTemplate
from geopy.geocoders import Nominatim
from climateqa.engine.llm import get_llm
import duckdb
import os
from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
from climateqa.engine.talk_to_data.objects.location import Location
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State
import calendar
async def detect_location_with_openai(sentence: str) -> str:
    """Detect the first location mentioned in a sentence using an LLM.

    The model returned by ``get_llm()`` is asked to answer with a Python
    list of locations; the first element of that list is returned.

    Args:
        sentence (str): The sentence to analyse.

    Returns:
        str: The first location of the parsed list, or an empty string when
            the list is empty or the answer cannot be parsed as a list.
    """
    llm = get_llm()

    prompt = f"""
Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
Return the result as a Python list. If no locations are mentioned, return an empty list.
Sentence: "{sentence}"
"""

    response = await llm.ainvoke(prompt)
    raw = response.content

    # Keep only the outermost [...] span. This tolerates any Markdown fence
    # (```python, ```json, bare ```) and explanatory text around the list,
    # which str.strip("```python\n") did not (it strips a character set).
    start, end = raw.find("["), raw.rfind("]")
    if start == -1 or end < start:
        return ""

    try:
        location_list = ast.literal_eval(raw[start:end + 1])
    except (ValueError, SyntaxError):
        # The model answered with something that is not a Python literal.
        return ""

    if isinstance(location_list, (list, tuple)) and location_list:
        return location_list[0]
    return ""
def loc_to_coords(location: str) -> tuple[float, float]:
    """Look up the latitude and longitude of a named place.

    The name is sent to the Nominatim geocoder (5 second timeout) and the
    coordinates of the match it returns are handed back.

    Args:
        location (str): Name of the place to geocode (e.g. a city name)

    Returns:
        tuple[float, float]: (latitude, longitude) of the geocoding result

    Raises:
        AttributeError: If the geocoder returns no result for the name
    """
    match = Nominatim(user_agent="city_to_latlong", timeout=5).geocode(location)
    latitude, longitude = match.latitude, match.longitude
    return latitude, longitude
def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
    """Converts geographic coordinates to a country code and country name.

    This function uses the Nominatim reverse geocoding service to look up
    the address of a (latitude, longitude) pair and extracts the country
    information from it.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str, str]: A tuple containing (country_code, country_name),
            where country_code is upper-cased

    Raises:
        AttributeError: If the reverse geocoder returns no result
        KeyError: If the returned address has no 'country_code' or
            'country' entry
    """
    # Same explicit timeout as loc_to_coords, so both geocoding calls behave
    # consistently instead of one relying on the library default.
    geolocator = Nominatim(user_agent="latlong_to_country", timeout=5)
    location = geolocator.reverse(coords)
    address = location.raw['address']
    return address['country_code'].upper(), address['country']
def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Find the dataset grid point closest to a geographic position.

    The coordinates are rounded to 3 decimals and the table matching ``mode``
    is searched inside a square box around them (+/-0.3 for DRIAS, +/-0.5 for
    IPCC, in the units of the latitude/longitude columns). Among the rows in
    the box, the one with the smallest squared coordinate distance to the
    requested position is returned.

    Args:
        location (tuple): (latitude, longitude) of the position to match
        mode (Literal['DRIAS', 'IPCC']): Which dataset to search; any value
            other than 'DRIAS' uses the IPCC coordinates table

    Returns:
        tuple: (latitude, longitude, admin1) of the matched row. admin1 is
            None when the queried table has no admin1 column (DRIAS).
            ("", "", "") is returned when no row lies inside the box.
    """
    lon = round(location[1], 3)
    lat = round(location[0], 3)

    if mode == 'DRIAS':
        table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
        columns = "latitude, longitude"
        half_width = 0.3
    else:
        table_path = f"'{IPCC_COORDINATES_PATH}'"
        columns = "latitude, longitude, admin1"
        half_width = 0.5

    # The ORDER BY makes the result the actual nearest point (and therefore
    # deterministic) instead of whichever row the scan happens to yield first.
    # A squared difference in coordinate space is enough to rank candidates
    # inside such a small box.
    query = (
        f"SELECT {columns} FROM {table_path} "
        f"WHERE latitude BETWEEN {lat - half_width} AND {lat + half_width} "
        f"AND longitude BETWEEN {lon - half_width} AND {lon + half_width} "
        f"ORDER BY (latitude - {lat}) * (latitude - {lat}) "
        f"+ (longitude - {lon}) * (longitude - {lon}) "
        "LIMIT 1"
    )

    conn = duckdb.connect()
    try:
        results = conn.sql(query).fetchdf()
    finally:
        # Release the connection even when the query fails.
        conn.close()

    if len(results) == 0:
        return "", "", ""

    if 'admin1' in results.columns:
        admin1 = results['admin1'].iloc[0]
    else:
        admin1 = None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
async def detect_year_with_openai(sentence: str) -> str:
    """Detect the first year mentioned in a sentence using an LLM.

    The model returned by ``get_llm()`` is queried through a LangChain prompt
    with ``ArrayOutput`` structured output. The ``'array'`` field of the
    answer is parsed with ``ast.literal_eval`` and its first element is
    returned.

    Args:
        sentence (str): The sentence to analyse.

    Returns:
        str: The first year found, always as a string (e.g. '2050'), or an
            empty string when no year is reported or the answer cannot be
            parsed as a list.
    """
    llm = get_llm()

    prompt = """
Extract all years mentioned in the following sentence.
Return the result as a Python list. If no year are mentioned, return an empty list.
Sentence: "{sentence}"
"""

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})

    try:
        years_list = ast.literal_eval(response['array'])
    except (ValueError, SyntaxError):
        # The 'array' field did not contain a Python literal.
        return ""

    if isinstance(years_list, (list, tuple)) and years_list:
        # The model may list years as ints or as strings; normalise so the
        # declared ``str`` return type always holds.
        return str(years_list[0])
    return ""
async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    This function asks an LLM to pick, from the given table names, the ones
    that best fit the plot description and the user's question.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object; its 'description' entry
            is inserted into the prompt
        llm: The language model instance to use for analysis (must provide
            an async ``ainvoke`` returning an object with ``.content``)
        table_names_list (list[str]): The candidate table names

    Returns:
        list[str]: The table names listed by the model, or an empty list
            when its answer cannot be parsed as a Python list

    Example:
        >>> await detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm,
        ...     table_names,
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    raw = (await llm.ainvoke(prompt)).content

    # Keep only the outermost [...] span. This tolerates any Markdown fence
    # (```python, ```json, bare ```) and explanatory text around the list,
    # which str.strip("```python\n") did not (it strips a character set).
    start, end = raw.find("["), raw.rfind("]")
    if start == -1 or end < start:
        return []

    try:
        table_names = ast.literal_eval(raw[start:end + 1])
    except (ValueError, SyntaxError):
        # The model answered with something that is not a Python literal.
        return []

    if not isinstance(table_names, (list, tuple)):
        return []
    return list(table_names)
async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Selects the plots that could help answer the user's question.

    The name and description of every plot are listed in a prompt, and the
    LLM is asked to return the names of all relevant plots, sorted from most
    to least relevant.

    Args:
        user_question (str): The user's question
        llm: The language model instance to use (must provide an async
            ``ainvoke`` returning an object with ``.content``)
        plot_list (list[Plot]): The available plots; the 'name' and
            'description' entries of each are used

    Returns:
        list[str]: The plot names listed by the model, or an empty list when
            its answer cannot be parsed as a Python list
    """
    plots_description = "".join(
        f"Name: {plot['name']} - Description: {plot['description']}\n"
        for plot in plot_list
    )

    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )

    raw = (await llm.ainvoke(prompt)).content

    # Keep only the outermost [...] span. This tolerates any Markdown fence
    # (```python, ```json, bare ```) and explanatory text around the list,
    # which str.strip("```python\n") did not (it strips a character set).
    start, end = raw.find("["), raw.rfind("]")
    if start == -1 or end < start:
        return []

    try:
        plot_names = ast.literal_eval(raw[start:end + 1])
    except (ValueError, SyntaxError):
        # The model answered with something that is not a Python literal.
        return []

    if not isinstance(plot_names, (list, tuple)):
        return []
    return list(plot_names)
async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Extract a location from the user input and resolve it against the data.

    The location name is detected with the LLM. When one is found it is
    geocoded, reverse-geocoded to a country, and matched to a point of the
    dataset selected by ``mode``. Exceptions raised by those helpers are not
    caught here.

    Args:
        user_input (str): The user's query text
        mode (Literal['DRIAS', 'IPCC']): Dataset used for the coordinate match

    Returns:
        Location: Dict with keys 'location', 'longitude', 'latitude',
            'country_code', 'country_name' and 'admin1'. When no location is
            detected, every field except 'location' is None.
    """
    print("---- Find location in user input ----")
    place = await detect_location_with_openai(user_input)

    latitude = longitude = admin1 = None
    country_code = country_name = None

    if place:
        coords = loc_to_coords(place)
        country_code, country_name = coords_to_country(coords)
        # The stored coordinates are those of the matched dataset point,
        # not the raw geocoding result.
        matched = nearest_neighbour_sql(coords, mode)
        latitude, longitude, admin1 = matched[0], matched[1], matched[2]

    return cast(Location, {
        'location': place,
        'longitude': longitude,
        'latitude': latitude,
        'country_code': country_code,
        'country_name': country_name,
        'admin1': admin1,
    })
async def find_year(user_input: str) -> str | None:
    """Extract the year mentioned in the user input, if any.

    Delegates detection to ``detect_year_with_openai`` and converts its
    "nothing found" marker (an empty string) into None.

    Args:
        user_input (str): The user's query text

    Returns:
        str | None: The detected year, or None when no year was found
    """
    print("---- Find year ---")
    detected = await detect_year_with_openai(user_input)
    return None if detected == "" else detected
async def find_month(user_input: str) -> dict[str, str | None]:
    """
    Extracts month information from user input using an LLM.

    The model returned by ``get_llm()`` is asked for the month numbers
    mentioned in the query; the first one is returned together with its name
    taken from ``calendar.month_name`` (English names under the default
    locale). If no valid month (1 to 12) is found, both values are None.

    Args:
        user_input (str): The user's query text.

    Returns:
        dict[str, str|None]: A dictionary with keys:
            - "month_number": the month number as a string (e.g. '7'), or None if not found
            - "month_name": the full month name (e.g. 'July'), or None if not found

    Example:
        >>> await find_month("Show me the temperature in Paris in July")
        {'month_number': '7', 'month_name': 'July'}
        >>> await find_month("Show me the temperature in Paris")
        {'month_number': None, 'month_name': None}
    """
    llm = get_llm()

    prompt = """
Extract the month (as a number from 1 to 12) mentioned in the following sentence.
Return the result as a Python list of integers. If no month is mentioned, return an empty list.
Sentence: "{sentence}"
"""

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": user_input})

    no_month: dict[str, str | None] = {
        "month_number": None,
        "month_name": None,
    }

    try:
        months_list = ast.literal_eval(response['array'])
    except (ValueError, SyntaxError):
        # The 'array' field did not contain a Python literal.
        return no_month

    if not isinstance(months_list, (list, tuple)) or not months_list:
        return no_month

    try:
        month_number = int(months_list[0])
    except (TypeError, ValueError):
        return no_month

    # calendar.month_name[0] is '' and negative indexes wrap around
    # (-1 -> 'December'), so anything outside 1..12 must be rejected
    # explicitly rather than looked up.
    if not 1 <= month_number <= 12:
        return no_month

    return {
        "month_number": str(month_number),
        "month_name": calendar.month_name[month_number],
    }
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Return the names of the plots relevant to the question in ``state``.

    Thin wrapper that logs a trace line and forwards ``state['user_input']``,
    the LLM and the candidate plots to ``detect_relevant_plots``.
    """
    print("---- Find relevant plots ----")
    question = state['user_input']
    return await detect_relevant_plots(question, llm, plots)
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Return the table names relevant to ``plot`` for the question in ``state``.

    Thin wrapper that logs a trace line naming the plot and forwards
    ``state['user_input']``, the plot, the LLM and the candidate tables to
    ``detect_relevant_tables``.
    """
    plot_name = plot['name']
    print(f"---- Find relevant tables for {plot_name} ----")
    question = state['user_input']
    return await detect_relevant_tables(question, plot, llm, tables)
async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
"""
Retrieves a specific parameter (location, year, month, etc.) from the user's input using the appropriate extraction method.
Args:
state (State): The current state containing at least the user's input under 'user_input'.
param_name (str): The name of the parameter to extract. Supported: 'location', 'year', 'month'.
mode (Literal['DRIAS', 'IPCC']): The data mode to use for location extraction.
Returns:
- For 'location': a Location object (dict with keys like 'location', 'latitude', etc.), or None if not found.
- For 'year': a dict {'year': year or None}.
- For 'month': a dict {'month_number': str or None, 'month_name': str or None}.
- None if the parameter is not recognized or not found.
Example:
>>> await find_param(state, 'location')
{'location': 'Paris', 'latitude': ..., ...}
>>> await find_param(state, 'year')
{'year': '2050'}
>>> await find_param(state, 'month')
{'month_number': '7', 'month_name': 'July'}
"""
if param_name == 'location':
location = await find_location(state['user_input'], mode)
return location
if param_name == 'year':
year = await find_year(state['user_input'])
return {'year': year}
if param_name == 'month':
month = await find_month(state['user_input'])
return month
return None