|
import os
|
|
import pandas as pd
|
|
from pandasai import Agent, SmartDataframe
|
|
from typing import Tuple
|
|
from PIL import Image
|
|
from pandasai.llm import HuggingFaceTextGen
|
|
from dotenv import load_dotenv
|
|
from langchain_groq import ChatGroq
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
# Load variables from a local .env file, letting it override the process env.
load_dotenv(override=True)

# API credentials, read once at import time. The loader functions below use
# these module-level values; ask_question() re-reads the environment itself.
Groq_Token = os.getenv("GROQ_API_KEY")
hf_token = os.getenv("HF_TOKEN")
gemini_token = os.getenv("GEMINI_TOKEN")

# Startup diagnostics: report only whether each token is set. Never print any
# part of a secret -- stdout typically ends up in logs.
print(f"Debug - Groq Token: {'Present' if Groq_Token else 'Missing'}")
print(f"Debug - Gemini Token: {'Present' if gemini_token else 'Missing'}")
|
|
|
|
# Short model aliases -> provider model identifiers.
# The loader functions in this module send "gemini-pro" to Google Gemini and
# every other alias to Groq.
models = dict(
    [
        ("mistral", "mistral-saba-24b"),
        ("llama3.3", "llama-3.3-70b-versatile"),
        ("llama3.1", "llama-3.1-8b-instant"),
        ("gemma2", "gemma2-9b-it"),
        ("gemini-pro", "gemini-1.5-pro"),
    ]
)
|
|
|
|
def preprocess_and_load_df(path: str) -> pd.DataFrame:
    """Load the CSV at *path* and parse its ``Timestamp`` column as datetimes.

    Args:
        path: location of the CSV file (anything ``pandas.read_csv`` accepts).

    Returns:
        The loaded dataframe with ``Timestamp`` converted by ``pd.to_datetime``.

    Raises:
        Exception: ``"Error loading dataframe: ..."`` for any failure (unreadable
            file, missing ``Timestamp`` column, unparseable timestamps). The
            underlying error is chained as ``__cause__``.
    """
    try:
        df = pd.read_csv(path)
        # Fail with a readable message instead of a bare KeyError repr ("'Timestamp'").
        if "Timestamp" not in df.columns:
            raise ValueError(f"required column 'Timestamp' not found in {path!r}")
        df["Timestamp"] = pd.to_datetime(df["Timestamp"])
        return df
    except Exception as e:
        raise Exception(f"Error loading dataframe: {e}") from e
|
|
|
|
def load_agent(df: pd.DataFrame, context: str, inference_server: str, name="mistral") -> Agent:
    """Build a pandasai ``Agent`` over *df* backed by the chat model aliased *name*.

    Args:
        df: dataframe the agent answers questions about.
        context: optional message added to the agent via ``add_message``;
            skipped when empty.
        inference_server: not used by this function; kept so existing callers
            keep working.
        name: key into ``models``. ``"gemini-pro"`` builds a Google Gemini chat
            model from ``gemini_token``; every other key builds a Groq chat
            model from ``Groq_Token``.

    Returns:
        The configured ``Agent`` (pandasai cache disabled).

    Raises:
        Exception: ``"Error loading agent: ..."`` wrapping any failure (unknown
            model alias, missing/blank token, client or agent construction
            error). The original error is chained as ``__cause__``.
    """
    try:
        # Validate up front: models[name] would otherwise fail with a bare KeyError repr.
        if name not in models:
            raise ValueError(f"Unknown model {name!r}; expected one of {sorted(models)}")

        if name == "gemini-pro":
            if not gemini_token or not gemini_token.strip():
                raise ValueError("Gemini API token not available or empty")
            llm = ChatGoogleGenerativeAI(
                model=models[name],
                google_api_key=gemini_token,
                temperature=0.1,
            )
        else:
            if not Groq_Token or not Groq_Token.strip():
                raise ValueError("Groq API token not available or empty")
            llm = ChatGroq(
                model=models[name],
                api_key=Groq_Token,
                temperature=0.1,
            )

        agent = Agent(df, config={"llm": llm, "enable_cache": False, "options": {"wait_for_model": True}})
        if context:
            agent.add_message(context)
        return agent
    except Exception as e:
        raise Exception(f"Error loading agent: {e}") from e
|
|
|
|
def load_smart_df(df: pd.DataFrame, inference_server: str, name="mistral") -> SmartDataframe:
    """Wrap *df* in a pandasai ``SmartDataframe`` backed by the chat model aliased *name*.

    Args:
        df: dataframe to wrap.
        inference_server: not used by this function; kept so existing callers
            keep working.
        name: key into ``models``. ``"gemini-pro"`` builds a Google Gemini chat
            model from ``gemini_token``; every other key builds a Groq chat
            model from ``Groq_Token``.

    Returns:
        The ``SmartDataframe`` (pandasai cache disabled, ``max_retries`` = 5).

    Raises:
        Exception: ``"Error loading smart dataframe: ..."`` wrapping any failure
            (unknown model alias, missing/blank token, construction error). The
            original error is chained as ``__cause__``.
    """
    try:
        # Validate up front: models[name] would otherwise fail with a bare KeyError repr.
        if name not in models:
            raise ValueError(f"Unknown model {name!r}; expected one of {sorted(models)}")

        if name == "gemini-pro":
            if not gemini_token or not gemini_token.strip():
                raise ValueError("Gemini API token not available or empty")
            llm = ChatGoogleGenerativeAI(
                model=models[name],
                google_api_key=gemini_token,
                temperature=0.1,
            )
        else:
            if not Groq_Token or not Groq_Token.strip():
                raise ValueError("Groq API token not available or empty")
            llm = ChatGroq(
                model=models[name],
                api_key=Groq_Token,
                temperature=0.1,
            )

        return SmartDataframe(df, config={"llm": llm, "max_retries": 5, "enable_cache": False})
    except Exception as e:
        raise Exception(f"Error loading smart dataframe: {e}") from e
|
|
|
|
def get_from_user(prompt):
    """Wrap *prompt* as a chat message dict authored by the user.

    Returns:
        ``{"role": "user", "content": prompt}``.
    """
    message = dict(role="user")
    message["content"] = prompt
    return message
|
|
|
|
def ask_agent(agent: Agent, prompt: str) -> dict:
    """Send *prompt* to a pandasai agent and package the outcome as an assistant message.

    The agent's ``last_code_generated``, ``last_code_executed`` and
    ``last_prompt`` attributes are read with ``getattr`` defaults ("" / ""
    / *prompt*), so agents lacking them still work.

    Returns:
        A dict with keys ``role`` ("assistant"), ``content``, ``gen_code``,
        ``ex_code``, ``last_prompt`` and ``error``. On failure ``content`` is
        ``"Error: <message>"``, both code fields are empty, ``last_prompt`` is
        *prompt* and ``error`` holds the message. Otherwise ``error`` is ``None``.
        Exceptions raised by the chat call or the attribute reads are not
        propagated.
    """
    try:
        reply_text = agent.chat(prompt)
        code_generated, code_executed, prompt_used = (
            getattr(agent, "last_code_generated", ""),
            getattr(agent, "last_code_executed", ""),
            getattr(agent, "last_prompt", prompt),
        )
        failure = None
    except Exception as exc:
        failure = str(exc)
        reply_text = f"Error: {failure}"
        code_generated = code_executed = ""
        prompt_used = prompt

    return {
        "role": "assistant",
        "content": reply_text,
        "gen_code": code_generated,
        "ex_code": code_executed,
        "last_prompt": prompt_used,
        "error": failure,
    }
|
|
|
|
def decorate_with_code(response: dict) -> str:
    """Render collapsible HTML/markdown sections showing a response's code and prompt.

    Uses ``response["gen_code"]`` (default "No code generated") and
    ``response["last_prompt"]`` (default "No prompt").

    Note: two ``<details>`` elements are opened but only the first is closed. The
    caller is expected to append the closing ``</details>`` for the "Prompt"
    section.
    """
    code = response.get("gen_code", "No code generated")
    prompt_text = response.get("last_prompt", "No prompt")

    sections = [
        "<details>",
        "<summary>Generated Code</summary>",
        "",
        "```python",
        f"{code}",
        "```",
        "</details>",
        "",
        "<details>",
        "<summary>Prompt</summary>",
        "",
        f"{prompt_text}",
        "",  # keeps the trailing newline after the prompt
    ]
    return "\n".join(sections)
|
|
|
|
def show_response(st, response):
    """Render one chat *response* dict inside a chat-message container.

    If ``response["content"]`` can be opened by ``PIL.Image.open``, it is shown
    as an image. Otherwise it is rendered as markdown. When ``gen_code`` is
    non-empty, the collapsible code/prompt sections from ``decorate_with_code``
    are rendered as well.

    Args:
        st: presumably the ``streamlit`` module; this function uses its
            ``chat_message``, ``markdown``, ``image`` and ``error``.
        response: dict with a ``role`` key and optional ``content``,
            ``gen_code`` and ``last_prompt`` keys.

    Returns:
        ``{"is_image": True}`` if the content was shown as an image, otherwise
        ``{"is_image": False}`` (also after a rendering error, which is reported
        with ``st.error``).
    """
    try:
        with st.chat_message(response["role"]):
            content = response.get("content", "No content")

            # Only the "is this an image?" probe may fail quietly. Rendering
            # errors must not be swallowed and re-rendered as text.
            try:
                image = Image.open(content)
            except Exception:
                image = None

            if image is not None:
                if response.get("gen_code"):
                    # decorate_with_code leaves its "Prompt" <details> open; close it.
                    st.markdown(decorate_with_code(response) + "</details>", unsafe_allow_html=True)
                st.image(image)
                return {"is_image": True}

            if response.get("gen_code"):
                display_content = decorate_with_code(response) + f"</details>\n\n{content}"
            else:
                display_content = content
            st.markdown(display_content, unsafe_allow_html=True)
            return {"is_image": False}
    except Exception as e:
        st.error(f"Error displaying response: {e}")
        return {"is_image": False}
|
|
|
|
def _assistant_message(question, content, error=None, code=""):
    """Build the assistant reply dict returned by every ``ask_question`` path."""
    # gen_code and ex_code are identical: the generated code is exactly what ran.
    return {
        "role": "assistant",
        "content": content,
        "gen_code": code,
        "ex_code": code,
        "last_prompt": question,
        "error": error,
    }


def _unexpected_error_reply(question, exc):
    """Map an otherwise-unhandled exception to a user-facing assistant reply."""
    error_msg = str(exc)
    lowered = error_msg.lower()
    if "organization_restricted" in lowered:
        return _assistant_message(
            question,
            "β API Organization Restricted: Your API key access has been restricted. Please check your Groq API key or try generating a new one.",
            error="API access restricted",
        )
    if "rate_limit" in lowered:
        return _assistant_message(
            question,
            "β Rate limit exceeded. Please wait a moment and try again.",
            error="Rate limit exceeded",
        )
    return _assistant_message(question, f"β Error: {error_msg}", error=error_msg)


def _groq_error_reply(question, api_error):
    """Map a failure while creating or calling the Groq client to an assistant reply."""
    error_msg = str(api_error)
    lowered = error_msg.lower()
    if any(marker in lowered for marker in ("organization_restricted", "unauthorized", "invalid_api_key")):
        return _assistant_message(
            question,
            "β API Key Error: Your Groq API key appears to be invalid, expired, or restricted. Please check your API key in the .env file.",
            error=f"API key validation failed: {error_msg}",
        )
    if "rate_limit" in lowered:
        return _unexpected_error_reply(question, api_error)
    return _assistant_message(question, f"β API Connection Error: {error_msg}", error=error_msg)


def _build_prompt(question) -> Tuple[str, str]:
    """Return ``(query, template_code)`` for *question* about ``Data.csv``.

    ``template_code`` is the Python preamble (imports, plot settings, data
    loading, dtype listing, question comment) that is prepended to the model's
    code before execution. ``query`` is the full prompt sent to the model.
    """
    df_check = pd.read_csv("Data.csv")
    df_check["Timestamp"] = pd.to_datetime(df_check["Timestamp"])
    df_check = df_check.head(5)
    dtype_comments = "\n".join("# " + line for line in str(df_check.dtypes).split("\n"))

    # Comment out *every* line of the question: a raw newline would otherwise
    # turn the rest of the user's text into executed template code.
    question_comment = "# Question: " + "\n# ".join(question.strip().splitlines())

    parameters = {"font.size": 12, "figure.dpi": 600}
    template_code = "\n".join(
        [
            "",
            "import pandas as pd",
            "import matplotlib.pyplot as plt",
            "import uuid",
            "",
            f"plt.rcParams.update({parameters})",
            "",
            'df = pd.read_csv("Data.csv")',
            'df["Timestamp"] = pd.to_datetime(df["Timestamp"])',
            "",
            "# Available columns and data types:",
            dtype_comments,
            "",
            question_comment,
            "# Generate code to answer the question and save result in 'answer' variable",
            "# If creating a plot, save it with a unique filename and store the filename in 'answer'",
            "# If returning text/numbers, store the result directly in 'answer'",
            "",
        ]
    )
    template = f"```python{template_code}```"

    system_prompt = (
        "You are a helpful assistant that generates Python code for data analysis.\n"
        "\n"
        "Rules:\n"
        "1. Always save your final result in a variable called 'answer'\n"
        "2. If creating a plot, save it with plt.savefig() and store the filename in 'answer'\n"
        "3. If returning text/numbers, store the result directly in 'answer'\n"
        "4. Use descriptive variable names and add comments\n"
        "5. Handle potential errors gracefully\n"
        "6. For plots, use unique filenames to avoid conflicts\n"
    )
    query = (
        f"{system_prompt}\n\n"
        "Complete the following code to answer the user's question:\n\n"
        f"{template}\n"
    )
    return query, template_code


def _extract_code(answer):
    """Return the code inside the first fenced block of an LLM *answer*.

    A python/py fence (any letter case) is preferred. Otherwise the first fence
    of any kind is used, with its tag line dropped. A missing closing fence
    means "up to the end of the answer". An unfenced answer is returned
    unchanged.
    """
    import re

    for pattern in (
        r"```(?:python|py)\b(.*?)(?:```|\Z)",
        r"```[^\n]*\n(.*?)(?:```|\Z)",
    ):
        match = re.search(pattern, answer, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1)
    return answer


def _run_generated_code(code):
    """Execute *code* and return the value it binds to ``answer``.

    SECURITY: *code* is LLM output steered by the user's question. It runs with
    full interpreter privileges (file system, network, subprocesses) and no
    sandboxing. Only expose this in a trusted deployment.
    """
    import uuid

    # A single dict serves as globals *and* locals. With separate dicts, exec()
    # runs the code like a class body, and nested scopes (functions, lambdas,
    # comprehension bodies) then cannot see top-level names such as ``df``.
    namespace = {"pd": pd, "plt": plt, "os": os, "uuid": uuid}
    figures_before = set(plt.get_fignums())
    try:
        exec(code, namespace)
    finally:
        # pyplot keeps figures alive until closed; release the ones this run opened.
        for number in set(plt.get_fignums()) - figures_before:
            plt.close(number)
    return namespace.get("answer", "No answer variable found in generated code")


def ask_question(model_name, question):
    """Answer *question* about ``Data.csv`` by having an LLM write pandas code and running it.

    Steps:
        1. Reload credentials from the environment/.env.
        2. Build the chat model selected by *model_name*.
        3. Send a prompt containing the column dtypes and the question.
        4. Extract the code from the reply and execute it.
        5. Return what the code stored in ``answer``.

    Args:
        model_name: key into ``models``. ``"gemini-pro"`` uses Google Gemini;
            every other key uses Groq.
        question: the user's natural-language question.

    Returns:
        An assistant message dict with these keys:

        - ``role``: always "assistant".
        - ``content``: the answer (for plots, whatever the code stored, e.g. a
          filename) or a human-readable error text.
        - ``gen_code`` / ``ex_code``: the executed code; empty when no code ran.
        - ``last_prompt``: the question.
        - ``error``: ``None`` on success.

        Failures raised as ``Exception`` are reported through ``content`` and
        ``error`` rather than propagated.
    """
    try:
        # Re-read .env on every call so a key fixed while the app runs is picked up.
        load_dotenv(override=True)
        fresh_groq_token = os.getenv("GROQ_API_KEY")
        fresh_gemini_token = os.getenv("GEMINI_TOKEN")
        print(f"ask_question - Fresh Groq Token: {'Present' if fresh_groq_token else 'Missing'}")

        if model_name not in models:
            return _assistant_message(
                question,
                f"β Unknown model '{model_name}'. Choose one of: {', '.join(models)}.",
                error="Unknown model",
            )

        is_gemini = model_name == "gemini-pro"
        if is_gemini:
            if not fresh_gemini_token or not fresh_gemini_token.strip():
                return _assistant_message(
                    question,
                    "β Gemini API token not available or empty. Please set GEMINI_TOKEN in your environment variables.",
                    error="Missing or empty API token",
                )
            llm = ChatGoogleGenerativeAI(
                model=models[model_name],
                google_api_key=fresh_gemini_token,
                temperature=0,
            )
        else:
            if not fresh_groq_token or not fresh_groq_token.strip():
                return _assistant_message(
                    question,
                    "β Groq API token not available or empty. Please set GROQ_API_KEY in your environment variables and restart the application.",
                    error="Missing or empty API token",
                )
            try:
                llm = ChatGroq(
                    model=models[model_name],
                    api_key=fresh_groq_token,
                    temperature=0.1,
                )
            except Exception as api_error:
                return _groq_error_reply(question, api_error)

        if not os.path.exists("Data.csv"):
            return _assistant_message(
                question,
                "β Data.csv file not found. Please ensure the data file is in the correct location.",
                error="Data file not found",
            )

        query, template_code = _build_prompt(question)

        # Exactly one model request per question. Groq failures are classified
        # here rather than via a separate probe request, which would double API
        # usage and rate-limit consumption.
        try:
            answer = llm.invoke(query).content
        except Exception as api_error:
            if is_gemini:
                raise  # reported by the generic handler below
            return _groq_error_reply(question, api_error)

        full_code = ""
        try:
            full_code = f"\n{template_code}\n{_extract_code(answer)}\n"
            answer_result = _run_generated_code(full_code)
        except Exception as code_error:
            return _assistant_message(
                question,
                f"β Error executing generated code: {code_error}",
                error=str(code_error),
                code=full_code,
            )
        return _assistant_message(question, answer_result, code=full_code)

    except Exception as e:
        return _unexpected_error_reply(question, e)