|
import os |
|
import io |
|
import re |
|
import gradio as gr |
|
import pandas as pd |
|
import openai |
|
import matplotlib.pyplot as plt |
|
from dotenv import load_dotenv |
|
from PIL import Image |
|
import traceback |
|
|
|
|
|
# Pull any variables defined in a local .env file into the process
# environment, then hand the configured key (None when unset) to the
# OpenAI client.
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
|
|
def _read_table(file_name, source):
    """Parse *source* with the pandas reader matching *file_name*'s extension.

    Returns None when the extension is not .csv, .xlsx or .xls. Parsing
    errors are not handled here; they propagate to the caller.
    """
    if file_name.endswith('.csv'):
        return pd.read_csv(source)
    if file_name.endswith('.xlsx'):
        return pd.read_excel(source, engine='openpyxl')
    if file_name.endswith('.xls'):
        return pd.read_excel(source, engine='xlrd')
    return None


def load_file(file):
    """Load a CSV or Excel file into a pandas DataFrame.

    Args:
        file: One of
            * None,
            * a dict with a "name" entry (used to pick the parser) and a
              "data" entry (what is handed to pandas),
            * a filesystem path (str or os.PathLike), e.g. what gr.File
              delivers with type="filepath",
            * any other object exposing a ``name`` attribute that pandas
              can read from (such as a temporary-file wrapper).

    Returns:
        The parsed DataFrame, or None when *file* is None, has no usable
        name/data, has an unsupported extension, or fails to parse (the
        parse error is printed).
    """
    if file is None:
        return None

    if isinstance(file, dict):
        origin = "dict"
        file_name = str(file.get("name") or "").lower()
        source = file.get("data")
    elif isinstance(file, (str, os.PathLike)):
        # A bare path has no .name attribute; the path itself is both the
        # name used for extension sniffing and the source handed to pandas.
        origin = "path"
        file_name = str(file).lower()
        source = file
    else:
        origin = "file-like object"
        # Objects without a name cannot be classified and fall through to
        # the "unsupported extension" result instead of raising.
        file_name = str(getattr(file, "name", "") or "").lower()
        source = file

    if source is None:
        return None

    try:
        return _read_table(file_name, source)
    except Exception as e:
        # Best-effort loader: callers treat None as "could not load".
        print(f"Error loading file from {origin}:", e)
        return None
|
|
|
def preview_file(file):
    """Return the uploaded file as a DataFrame for the preview widget.

    When load_file() yields None, a one-row DataFrame carrying an error
    message is returned instead, so there is always a table to display.
    """
    loaded = load_file(file)
    if loaded is not None:
        return loaded
    return pd.DataFrame({"Error": ["Error loading file or unsupported file type."]})
|
|
|
def generate_basic_understanding_code(df_preview):
    """Ask the chat model for Python code that explores the DataFrame.

    The prompt instructs the model to write code that assumes a DataFrame
    named 'df' and assigns a dictionary named 'basic_info' containing:
      - the data type of each column,
      - summary statistics for numeric columns,
      - counts, unique values, mode and frequency distributions for
        non-numeric columns.
    It also asks for plt.show() after each chart and for pd.to_datetime()
    without a fixed format (or with dayfirst=True).

    Args:
        df_preview: DataFrame whose column names and first three rows are
            embedded in the prompt.

    Returns:
        The model's reply with surrounding whitespace stripped, or an
        empty string when the reply carries no text content.
    """
    # The built-ins note mirrors the names exposed by safe_exec_code so the
    # model does not reach for functions that raise NameError at run time.
    prompt = f"""
You are a data analysis expert. Write Python code that performs an exploratory analysis of the DataFrame.
Assume a pandas DataFrame named 'df' is already loaded.
Output only raw Python code without any markdown formatting or code fences.
Assign the exploratory summary to a variable named 'basic_info' as a dictionary.
For each column in df, include its data type.
- For numeric columns (use pd.api.types.is_numeric_dtype), include summary statistics (mean, median, std, etc.).
- For non-numeric columns, treat them as categorical variables and include counts, unique values, mode, and frequency distributions.
When converting date strings to datetime, use pd.to_datetime() without a fixed format or with dayfirst=True.
If your analysis includes charts, call plt.show() after each chart so they can be captured.
Only reference columns that are present in df.columns.

Note: Only the following built-ins are available: abs, min, max, sum, len, range, print, list, dict, set, tuple, sorted, zip, enumerate, str, float, int, bool, complex, round, __import__. The names pd and plt are already defined; import any other module you need explicitly.

DataFrame preview:
Columns: {list(df_preview.columns)}
Sample Data (first 3 rows):
{df_preview.head(3).to_dict(orient='records')}
"""
    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are an expert data analysis assistant who outputs only raw Python code."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.3,
        max_tokens=3500,
    )
    # content may be None (e.g. a refusal); avoid AttributeError on .strip().
    content = response.choices[0].message.content
    return (content or "").strip()
|
|
|
def generate_problem_solving_code(nl_query, df_preview, basic_info):
    """Ask the chat model for Python code that answers the user's query.

    The prompt instructs the model to write code that assumes a DataFrame
    named 'df' and a variable 'basic_info', and that assigns a dictionary
    named 'result' with the keys 'summary', 'detailed_stats', 'insights'
    and 'chart_descriptions'. It also asks for plt.show() after each chart
    and for pd.to_datetime() without a fixed format (or dayfirst=True).

    Args:
        nl_query: The user's natural-language analysis request.
        df_preview: DataFrame whose column names and first three rows are
            embedded in the prompt.
        basic_info: Output of the exploratory step. Only its top-level keys
            are shown to the model (when it is a dict); otherwise the model
            is told not to rely on it.

    Returns:
        The model's reply with surrounding whitespace stripped, or an
        empty string when the reply carries no text content.
    """
    # Tell the model what 'basic_info' actually looks like; without this the
    # prompt refers to a variable whose structure the model has never seen.
    if isinstance(basic_info, dict):
        basic_info_note = f"'basic_info' is a dictionary with these top-level keys: {list(basic_info)}"
    else:
        basic_info_note = "'basic_info' is not a dictionary for this run; do not rely on its contents."

    # The built-ins note mirrors the names exposed by safe_exec_code so the
    # model does not reach for functions that raise NameError at run time.
    prompt = f"""
You are a data analysis expert. Write Python code that performs the analysis as described below.
Assume a pandas DataFrame named 'df' is already loaded and that you have already generated an exploratory summary stored in 'basic_info'.
Output only raw Python code without any markdown formatting or code fences.
Ensure that the final output is assigned to a variable named 'result' as a dictionary with the following keys: 'summary', 'detailed_stats', 'insights', and 'chart_descriptions'. The analysis should be verbose and include all relevant statistics, interpretations, and intermediate steps.
When processing the DataFrame, first inspect each column’s data type:
- For numeric columns (use pd.api.types.is_numeric_dtype), compute numeric statistics (mean, median, standard deviation, etc.).
- For non-numeric columns, treat them as categorical variables and compute appropriate descriptive statistics (counts, unique values, mode, and frequency distributions).
- Only generate charts and tables that are relevant to the problem at hand. Exclude fields that are not relevant to the problem from the charts and tables.
Incorporate insights from 'basic_info' if relevant.
{basic_info_note}
When converting date strings to datetime, use pd.to_datetime() without a fixed format or with dayfirst=True.
If your analysis includes charts, call plt.show() after each chart so they can be captured.
Only reference columns that are present in df.columns.

Note: Only the following built-ins are available: abs, min, max, sum, len, range, print, list, dict, set, tuple, sorted, zip, enumerate, str, float, int, bool, complex, round, __import__. The names pd and plt are already defined; import any other module you need explicitly.

DataFrame preview:
Columns: {list(df_preview.columns)}
Sample Data (first 3 rows):
{df_preview.head(3).to_dict(orient='records')}

User Query: "{nl_query}"
"""
    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are an expert data analysis assistant who outputs only raw Python code."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.3,
        max_tokens=3500,
    )
    # content may be None (e.g. a refusal); avoid AttributeError on .strip().
    content = response.choices[0].message.content
    return (content or "").strip()
|
|
|
# Matches df['col'] / df["col"] where "df" is a standalone name (not the tail
# of another identifier such as num_df, nor an attribute such as self.df).
# Group 2 is the column name; group 3 is present when the access is the
# target of a plain assignment (a single '=' that is not part of '==').
_DF_COLUMN_ACCESS = re.compile(
    r"(?<![\w.])df\[(['\"])((?:(?!\1).)+)\1\](\s*=(?!=))?"
)


def validate_generated_code(code, df):
    """Check that generated code only reads columns that df can provide.

    The code is scanned for string-keyed subscripts of the name ``df``
    (``df['column']`` or ``df["column"]``). A referenced column is accepted
    when it already exists in ``df.columns`` or when the code itself creates
    it with a plain assignment (``df['column'] = ...``).

    Args:
        code: Source text of the generated script.
        df: DataFrame the script will run against.

    Returns:
        ``(True, [])`` when every referenced column is available, otherwise
        ``(False, missing)`` where *missing* lists each unavailable column
        once, in order of first appearance.
    """
    created_by_code = set()
    referenced = []
    for match in _DF_COLUMN_ACCESS.finditer(code):
        column = match.group(2)
        if match.group(3):
            created_by_code.add(column)
        else:
            referenced.append(column)

    columns = df.columns
    missing_cols = [
        column
        for column in dict.fromkeys(referenced)  # de-duplicate, keep order
        if column not in columns and column not in created_by_code
    ]
    if missing_cols:
        return False, missing_cols
    return True, []
|
|
|
def _figure_to_image(fig):
    """Render a matplotlib figure to an in-memory PNG and return it as an RGB PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf).convert("RGB")


def safe_exec_code(code, df, capture_charts=True, interactive=False, extra_globals=None):
    """Execute generated code in a namespace with a reduced set of built-ins.

    NOTE(security): this is NOT a security boundary. ``__import__`` is
    exposed (generated code needs ``import`` statements to work), so the
    executed code can reach anything the interpreter can. Only run code
    from a source you are prepared to trust.

    Args:
        code: Source text to run. Lines that start with a markdown code
            fence are dropped first.
        df: DataFrame made available to the code as ``df``.
        capture_charts: When True, ``plt.show`` is replaced for the duration
            of the run so each shown figure is stored as a PIL image; if
            nothing was shown, every figure still open afterwards is
            captured instead.
        interactive: When True, ``show()`` is called on every captured
            image after a successful run.
        extra_globals: Optional mapping of additional names to expose.

    Returns:
        ``(output, charts)``. *output* is the value the code assigned to
        ``result`` (or, failing that, ``basic_info``); a message string when
        neither was assigned, when validation fails, or when execution
        raises (in which case it contains the traceback plus hints).
        *charts* is the list of captured PIL images.
    """
    clean_code = "\n".join(
        line for line in (code or "").splitlines()
        if not line.strip().startswith("```")
    ).strip()

    valid, missing_cols = validate_generated_code(clean_code, df)
    if not valid:
        return (f"Generated code references missing columns: {missing_cols}\nPlease adjust your prompt or data.",
                [])

    safe_builtins = {
        "abs": abs,
        "min": min,
        "max": max,
        "sum": sum,
        "len": len,
        "range": range,
        "print": print,
        "list": list,
        "dict": dict,
        "set": set,
        "tuple": tuple,
        "sorted": sorted,
        "zip": zip,
        "enumerate": enumerate,
        "pd": pd,
        "plt": plt,
        "str": str,
        "float": float,
        "int": int,
        "bool": bool,
        "complex": complex,
        "round": round,
        "__import__": __import__,
    }

    charts = []
    # A single dict serves as both globals and locals. With two separate
    # dicts, exec() runs the code as if it were a class body, so functions
    # defined by the generated script cannot see the script's own top-level
    # variables or imports.
    namespace = {"__builtins__": safe_builtins, "df": df, "plt": plt, "charts": charts}

    try:
        import seaborn as sns
        namespace["sns"] = sns
    except ImportError:
        pass

    injected = dict(extra_globals) if extra_globals is not None else {}
    namespace.update(injected)

    original_show = plt.show
    if capture_charts:
        def custom_show(*args, **kwargs):
            # Capture the current figure instead of displaying it.
            charts.append(_figure_to_image(plt.gcf()))
            plt.close()
        plt.show = custom_show

    try:
        exec(clean_code, namespace)
    except Exception:
        error_details = traceback.format_exc()
        if "ValueError: time data" in error_details:
            error_details += "\nHint: The generated code might be using a fixed datetime format. Consider using pd.to_datetime() without a fixed format or with dayfirst=True."
        if "KeyError" in error_details:
            error_details += "\nHint: The generated code might be referencing columns that do not exist in your DataFrame."
        if "NameError" in error_details:
            error_details += "\nHint: Ensure that all required built-in types and libraries (like float, int, etc.) are included in the safe built-ins."
        return f"An error occurred during code execution:\n{error_details}", charts
    else:
        # Nothing was shown explicitly: fall back to whatever is still open.
        if capture_charts and not charts:
            for num in plt.get_fignums():
                charts.append(_figure_to_image(plt.figure(num)))
    finally:
        # plt is the shared pyplot module: undo the patch and drop every
        # figure so one run cannot leak state (or charts) into the next.
        plt.show = original_show
        plt.close("all")

    # Only count values the code itself assigned; a name that still holds
    # the object injected through extra_globals was not produced by this run.
    output = None
    for name in ("result", "basic_info"):
        value = namespace.get(name)
        if value is not None and value is not injected.get(name):
            output = value
            break

    if interactive:
        for img in charts:
            img.show()

    if output is None:
        output = "No output variable ('result' or 'basic_info') was set by the code."
    return output, charts
|
|
|
def generate_interpretation(analysis_result, nl_query):
    """Ask the chat model to explain an analysis result in markdown.

    Args:
        analysis_result: The value to interpret; it is embedded in the
            prompt via its string representation.
        nl_query: The user's original query, included for context.

    Returns:
        The model's reply with surrounding whitespace stripped, or an
        empty string when the reply carries no text content.
    """
    prompt = f"""
You are a knowledgeable data analyst. Based on the following analysis result and the user's query, provide a detailed interpretation and descriptive analysis of the results. Explain what the results mean, any insights that can be drawn, and any potential limitations.
Please format your output in markdown (including headers, bullet points, and other markdown formatting as appropriate).

User Query: "{nl_query}"

Analysis Result:
{analysis_result}

Provide a clear and detailed explanation in plain language.
"""
    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are an expert data analysis assistant who explains analysis results clearly."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.5,
        max_tokens=5000,
    )
    # content may be None (e.g. a refusal); avoid AttributeError on .strip().
    content = response.choices[0].message.content
    return (content or "").strip()
|
|
|
def generate_and_run(nl_query, file, interactive_mode=False):
    """Run the two-step analysis pipeline for an uploaded file and a query.

    Steps:
      1. Load the file; bail out with placeholder values if that fails.
      2. Generate and execute the exploratory code (charts not captured).
      3. Generate and execute the problem-solving code, with the first
         step's output exposed to it as 'basic_info'.
      4. Ask the model to interpret the final result.

    Returns:
        A 5-tuple: (analysis result, both generated scripts joined under
        headings, DataFrame preview, list of charts, interpretation).
    """
    df = load_file(file)
    if df is None:
        placeholder = pd.DataFrame({"Error": ["No data available."]})
        return "Error loading file.", "", placeholder, [], ""

    df_preview = df.copy()

    # Step 1: exploratory pass.
    # NOTE(review): safe_exec_code returns a message string on failure, so
    # 'basic_info' may be that string and not a dict — confirm downstream
    # code tolerates this.
    basic_code = generate_basic_understanding_code(df_preview)
    basic_info, basic_charts = safe_exec_code(
        basic_code,
        df,
        capture_charts=False,
        interactive=interactive_mode,
    )

    # Step 2: query-specific pass, with the exploratory output injected.
    problem_code = generate_problem_solving_code(nl_query, df_preview, basic_info)
    result, problem_charts = safe_exec_code(
        problem_code,
        df,
        capture_charts=True,
        interactive=interactive_mode,
        extra_globals={"basic_info": basic_info},
    )

    interpretation = generate_interpretation(result, nl_query)

    combined_code = (
        f"### Basic Understanding Code:\n{basic_code}"
        f"\n\n### Problem Solving Code:\n{problem_code}"
    )
    all_charts = basic_charts + problem_charts
    return result, combined_code, df_preview, all_charts, interpretation
|
|
|
|
|
def _run_analysis_from_ui(query, file):
    """Button handler: run the full pipeline with interactive_mode=True.

    NOTE(review): interactive_mode=True makes safe_exec_code call show() on
    every captured PIL image on the machine running the app — confirm this
    is intended for the Gradio tab.
    """
    return generate_and_run(query, file, interactive_mode=True)


# UI definition. Components are laid out in creation order, so the order of
# the statements below is significant.
with gr.Blocks() as demo:
    gr.Markdown("## Dynamic Data Analysis with Two-Step Code Generation and Interpretation")

    # Tab 1: upload a file and preview its contents.
    with gr.Tab("Data Upload & Preview"):
        file_input = gr.File(label="Upload CSV or Excel file (.csv, .xls, .xlsx)")
        data_preview = gr.Dataframe(label="Data Preview")
        file_input.change(preview_file, inputs=[file_input], outputs=[data_preview])

    # Tab 2: enter a query, run the pipeline, show every output.
    with gr.Tab("Generate & Execute Analysis (Gradio Mode)"):
        nl_query = gr.Textbox(
            label="Enter your query",
            placeholder="e.g., Generate summary statistics and charts for Gender and Age distributions",
        )
        generate_btn = gr.Button("Generate & Execute Code")
        analysis_output = gr.Textbox(label="Analysis Result", lines=10)
        code_output = gr.Code(label="Generated Code", language="python")
        preview_output = gr.Dataframe(label="Data Preview")
        charts_output = gr.Gallery(label="Charts", show_label=True)
        interpretation_output = gr.Markdown(label="Interpretation")

        generate_btn.click(
            _run_analysis_from_ui,
            inputs=[nl_query, file_input],
            outputs=[
                analysis_output,
                code_output,
                preview_output,
                charts_output,
                interpretation_output,
            ],
        )


if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|