Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import io | |
import base64 | |
from scipy import stats | |
import warnings | |
import google.generativeai as genai | |
import os | |
from dotenv import load_dotenv | |
import logging | |
from datetime import datetime | |
import tempfile | |
import json | |
# Suppress every Python warning for the whole process.
# NOTE(review): this also hides deprecation/runtime warnings raised by the
# libraries used below — confirm that blanket silencing is intended.
warnings.filterwarnings('ignore')

# Configure root logging at INFO level with timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Gemini API key handling: the key is NOT read from the environment or a .env
# file. Users type it into the Overview tab, and it is passed per request to
# generate_insights_with_gemini(), which calls genai.configure() itself.
def analyze_dataset_overview(file_obj, api_key) -> tuple:
    """
    Analyze an uploaded CSV with Gemini and build the Overview tab outputs.

    Args:
        file_obj: Value from the Gradio ``gr.File`` component. Either an object
            exposing the uploaded file's path as ``.name`` or the path string
            itself.
        api_key: Gemini API key typed by the user; surrounding whitespace is
            ignored.

    Returns:
        tuple: ``(story_text, basic_info_text, data_quality_score)``.
        On a validation or loading error, ``story_text`` holds a "❌ ..."
        message, ``basic_info_text`` is ``""`` and the score is ``0``.
    """
    if file_obj is None:
        return "❌ Please upload a CSV file first.", "", 0

    # Strip once so a key pasted with stray whitespace is rejected when blank
    # and sent cleanly to the API when present.
    api_key = (api_key or "").strip()
    if not api_key:
        return "❌ Please enter your Gemini API key first.", "", 0

    # gr.File may hand over a plain filepath string instead of an object with
    # a ``.name`` attribute; accept both.
    csv_path = getattr(file_obj, "name", file_obj)

    try:
        df = pd.read_csv(csv_path)
        metadata = extract_dataset_metadata(df)
        gemini_prompt = create_insights_prompt(metadata)
        story = generate_insights_with_gemini(gemini_prompt, api_key)
        basic_info = create_basic_info_summary(metadata)
        return story, basic_info, metadata['data_quality']
    except Exception as e:
        # Top-level UI boundary: report the failure in the story panel.
        return f"❌ Error loading data: {str(e)}", "", 0
def extract_dataset_metadata(
    df: pd.DataFrame,
    *,
    corr_threshold: float = 0.7,
    max_correlations: int = 5,
    max_categorical: int = 5,
) -> dict:
    """
    Extract summary metadata from a DataFrame.

    Args:
        df (pd.DataFrame): DataFrame to analyze.
        corr_threshold (float): Minimum absolute Pearson correlation for a
            pair of numeric columns to be reported.
        max_correlations (int): Maximum number of correlated pairs returned,
            strongest first.
        max_categorical (int): Number of leading object-dtype columns for
            which unique counts and top values are collected.

    Returns:
        dict: Keys ``shape``, ``columns``, ``numeric_cols``,
        ``categorical_cols``, ``datetime_cols``, ``missing_data``,
        ``missing_percentage``, ``numeric_stats``, ``categorical_info``,
        ``correlations`` (always a list of ``{'var1', 'var2', 'correlation'}``
        dicts sorted by descending absolute correlation) and ``data_quality``
        (percentage of non-null cells; 0.0 for an empty frame).
    """
    rows, cols = df.shape
    columns = df.columns.tolist()

    # Data types
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    datetime_cols = df.select_dtypes(include=['datetime64']).columns.tolist()

    # Missing values; with zero rows report 0% rather than the NaN of 0/0.
    missing_data = df.isnull().sum()
    if rows:
        missing_percentage = (missing_data / rows * 100).round(2)
    else:
        missing_percentage = missing_data.astype(float)

    # Basic statistics
    numeric_stats = df[numeric_cols].describe().to_dict() if numeric_cols else {}

    # Categorical variable information for the first few object columns
    categorical_info = {
        col: {
            'unique_count': df[col].nunique(),
            'top_values': df[col].value_counts().head(3).to_dict(),
        }
        for col in categorical_cols[:max_categorical]
    }

    # Strongly correlated numeric pairs (upper triangle only, no self-pairs).
    correlations = []
    if len(numeric_cols) > 1:
        corr_matrix = df[numeric_cols].corr()
        corr_values = corr_matrix.to_numpy()
        names = corr_matrix.columns
        upper_i, upper_j = np.triu_indices(len(names), k=1)
        high_corr = []
        for i, j in zip(upper_i, upper_j):
            corr_val = abs(corr_values[i, j])
            # NaN (e.g. a constant column) compares False and is skipped.
            if corr_val > corr_threshold:
                high_corr.append({
                    'var1': names[i],
                    'var2': names[j],
                    'correlation': round(corr_val, 3),
                })
        # Keep the strongest pairs, not merely the first ones encountered.
        high_corr.sort(key=lambda pair: pair['correlation'], reverse=True)
        correlations = high_corr[:max_correlations]

    total_cells = rows * cols
    if total_cells:
        data_quality = round(float(df.notna().sum().sum()) / total_cells * 100, 1)
    else:
        data_quality = 0.0

    return {
        'shape': (rows, cols),
        'columns': columns,
        'numeric_cols': numeric_cols,
        'categorical_cols': categorical_cols,
        'datetime_cols': datetime_cols,
        'missing_data': missing_data.to_dict(),
        'missing_percentage': missing_percentage.to_dict(),
        'numeric_stats': numeric_stats,
        'categorical_info': categorical_info,
        'correlations': correlations,
        'data_quality': data_quality,
    }
def create_insights_prompt(metadata: dict) -> str:
    """
    Build the storytelling prompt sent to Gemini from dataset metadata.

    Args:
        metadata (dict): Dataset metadata as produced by
            extract_dataset_metadata; reads ``shape``, ``columns``,
            ``numeric_cols``, ``categorical_cols``, ``data_quality``,
            ``categorical_info`` and ``correlations``.

    Returns:
        str: Prompt asking the model for a markdown "Dataset Overview" story.
    """
    def _join_names(names) -> str:
        # str() each label: non-string column labels (e.g. integers) would
        # make str.join raise TypeError. An empty list is spelled "None" so
        # the model does not see a dangling, empty bullet.
        return ', '.join(str(name) for name in names) or 'None'

    prompt = f"""
You are an expert data analyst and storyteller. Using the following dataset information,
predict what this dataset is about and tell a story about it.
DATASET INFORMATION:
- Size: {metadata['shape'][0]:,} rows, {metadata['shape'][1]} columns
- Columns: {_join_names(metadata['columns'])}
- Numeric columns: {_join_names(metadata['numeric_cols'])}
- Categorical columns: {_join_names(metadata['categorical_cols'])}
- Data quality: {metadata['data_quality']}%
CATEGORICAL VARIABLE DETAILS:
{metadata['categorical_info']}
HIGH CORRELATIONS:
{metadata['correlations']}
Please create a story in the following format:
# Dataset Overview
## What is this dataset about?
[Your prediction about the dataset]
## Which sector/domain does it belong to?
[Your sector analysis]
## Potential Use Cases
- [Use case 1]
- [Use case 2]
- [Use case 3]
## Interesting Findings
- [Finding 1]
- [Finding 2]
- [Finding 3]
## What Can We Do With This Data?
- [Potential analysis 1]
- [Potential analysis 2]
- [Potential analysis 3]
Make your story visual and engaging using emojis!
Keep it in English and make it professional yet accessible.
Use proper markdown formatting for headers and lists.
"""
    return prompt
def generate_insights_with_gemini(prompt: str, api_key: str,
                                  model_name: str = 'gemini-1.5-flash') -> str:
    """
    Generate the data story for *prompt* with Gemini.

    Args:
        prompt (str): Prompt built from the dataset metadata.
        api_key (str): Gemini API key supplied by the user.
        model_name (str): Gemini model to query. Defaults to the model this
            app was written against; override it if that model is unavailable.

    Returns:
        str: The model's text response, or a canned fallback story embedding
        the error message when any step of the call (including reading
        ``response.text``) raises.
    """
    try:
        # NOTE(review): genai.configure() sets the key module-wide, so
        # concurrent requests from different users could interleave keys —
        # confirm whether the app serves requests concurrently.
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(model_name)
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        # Best-effort by design: keep the UI usable with a fallback story,
        # but record why the request failed instead of swallowing it.
        logging.getLogger(__name__).warning(
            "Gemini request failed; returning fallback story", exc_info=True
        )
        return f"""
🔍 **DATA DISCOVERY STORY**
⚠️ Gemini API Error: {str(e)}
📊 **Fallback Analysis**:
This dataset appears to be a fascinating collection of information!
🎯 **Prediction**: Based on the structure, this could be business, e-commerce, or customer behavior data.
🏢 **Sector**: Likely used in retail, digital marketing, or analytics domain.
✨ **Potential Stories**:
• 🛒 Customer journey analysis
• 📈 Seasonal trends and patterns
• 👥 Customer segmentation
• 💡 Recommendation systems
• 🎯 Marketing campaign optimization
🔮 **What We Can Do**:
• Customer lifetime value prediction
• Churn prediction modeling
• Pricing strategy optimization
• Market basket analysis
• A/B testing insights
📊 The data quality looks promising for analysis!
"""
def create_basic_info_summary(metadata: dict) -> str:
    """
    Render a short markdown summary of dataset metadata for the Overview tab.

    Args:
        metadata (dict): Metadata dict; reads ``shape``, ``numeric_cols``,
            ``categorical_cols``, ``datetime_cols``, ``data_quality``,
            ``missing_data`` (column -> missing count) and ``correlations``.

    Returns:
        str: Newline-separated summary lines, wrapped in a leading and a
        trailing newline.
    """
    row_count = metadata['shape'][0]
    col_count = metadata['shape'][1]
    total_missing = sum(metadata['missing_data'].values())

    # Empty first/last entries produce the leading and trailing newline.
    lines = [
        "",
        "📋 **Dataset Overview**",
        f"📊 **Size**: {row_count:,} rows × {col_count} columns",
        "🔢 **Data Types**:",
        f"• Numeric variables: {len(metadata['numeric_cols'])}",
        f"• Categorical variables: {len(metadata['categorical_cols'])}",
        f"• DateTime variables: {len(metadata['datetime_cols'])}",
        f"🎯 **Data Quality**: {metadata['data_quality']}%",
        f"📈 **Missing Data**: {total_missing} total missing values",
        f"🔗 **High Correlations Found**: {len(metadata['correlations'])} pairs",
        "",
    ]
    return "\n".join(lines)
def generate_data_profiling(file_obj) -> tuple:
    """
    Build the three tables shown on the "Data Profiling" tab.

    Args:
        file_obj: Value from the Gradio ``gr.File`` component. Either an object
            exposing the CSV path as ``.name`` or the path string itself.

    Returns:
        tuple: ``(missing_data_df, numeric_stats_df, categorical_stats_df)``.
        - missing_data_df: per-column missing count and percentage, most
          missing first (ties keep column order).
        - numeric_stats_df: ``describe()`` of numeric columns rounded to 3
          decimals, or None when there are no numeric columns.
        - categorical_stats_df: unique count, most frequent value and its
          frequency per object column, or None when there are none.
        All three are None when no file was uploaded. On error, the first
        slot is a one-column ``Error`` DataFrame and the others are None.
    """
    if file_obj is None:
        return None, None, None

    # gr.File may hand over a plain filepath string instead of an object with
    # a ``.name`` attribute; accept both.
    csv_path = getattr(file_obj, "name", file_obj)

    try:
        df = pd.read_csv(csv_path)

        # Missing data analysis; a header-only CSV reports 0% instead of NaN.
        missing_data = df.isnull().sum()
        if len(df):
            missing_pct = (missing_data / len(df) * 100).round(2)
        else:
            missing_pct = missing_data.astype(float)
        missing_df = pd.DataFrame({
            'Column': missing_data.index,
            'Missing Count': missing_data.values,
            'Missing Percentage': missing_pct.values,
        }).sort_values('Missing Count', ascending=False, kind='stable')

        # Numeric statistics
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        numeric_stats_df = None
        if len(numeric_cols) > 0:
            numeric_stats_df = df[numeric_cols].describe().round(3).reset_index()

        # Categorical statistics: compute mode/value_counts once per column.
        categorical_stats = []
        for col in df.select_dtypes(include=['object']).columns:
            series = df[col]
            modes = series.mode()
            counts = series.value_counts()
            categorical_stats.append({
                'Column': col,
                'Unique Values': series.nunique(),
                'Most Frequent': modes.iloc[0] if len(modes) > 0 else 'N/A',
                'Frequency': counts.iloc[0] if len(counts) > 0 else 0,
            })
        categorical_stats_df = pd.DataFrame(categorical_stats) if categorical_stats else None

        return missing_df, numeric_stats_df, categorical_stats_df
    except Exception as e:
        # UI boundary: surface the error in the first table.
        error_df = pd.DataFrame({'Error': [f"Error in profiling: {str(e)}"]})
        return error_df, None, None
def create_smart_visualizations(file_obj) -> tuple:
    """
    Build the four Plotly figures shown on the "Smart Visualizations" tab.

    Args:
        file_obj: Value from the Gradio ``gr.File`` component. Either an object
            exposing the CSV path as ``.name`` or the path string itself.

    Returns:
        tuple: ``(dtype_fig, missing_fig, correlation_fig, distribution_fig)``.
        ``correlation_fig`` is None with fewer than two numeric columns and
        ``distribution_fig`` is None with no numeric columns. All four are
        None when no file was uploaded. On error, the first slot holds an
        empty figure titled with the error message and the rest are None.
    """
    if file_obj is None:
        return None, None, None, None

    # gr.File may hand over a plain filepath string instead of an object with
    # a ``.name`` attribute; accept both.
    csv_path = getattr(file_obj, "name", file_obj)

    try:
        df = pd.read_csv(csv_path)

        # 1. Data type distribution (dtype objects converted to str labels).
        dtype_counts = df.dtypes.value_counts()
        dtype_fig = px.pie(
            values=dtype_counts.values,
            names=[str(dtype) for dtype in dtype_counts.index],
            title="🔍 Data Type Distribution",
        )
        dtype_fig.update_traces(textposition='inside', textinfo='percent+label')

        # 2. Missing values per column.
        missing_data = df.isnull().sum()
        missing_fig = px.bar(
            x=missing_data.index,
            y=missing_data.values,
            title="🔴 Missing Data by Column",
            labels={'x': 'Columns', 'y': 'Missing Count'},
        )
        missing_fig.update_xaxes(tickangle=45)

        # 3. Correlation heatmap. The colour range is pinned to [-1, 1] so the
        #    diverging RdBu scale is centred on zero correlation rather than
        #    stretched to the data's own min/max; cell text uses 2 decimals.
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        correlation_fig = None
        if len(numeric_cols) > 1:
            corr_matrix = df[numeric_cols].corr()
            correlation_fig = px.imshow(
                corr_matrix,
                text_auto='.2f',
                aspect="auto",
                title="🔗 Correlation Matrix",
                color_continuous_scale='RdBu',
                zmin=-1,
                zmax=1,
            )

        # 4. Histograms for (at most) the first four numeric columns.
        distribution_fig = None
        cols_to_plot = numeric_cols[:4]
        if len(cols_to_plot) == 1:
            distribution_fig = px.histogram(
                df, x=cols_to_plot[0],
                title=f"📊 Distribution of {cols_to_plot[0]}",
            )
        elif len(cols_to_plot) > 1:
            # Two panels per row, and only as many rows as needed, so 2 or 3
            # columns don't leave empty subplot panels.
            grid_rows = (len(cols_to_plot) + 1) // 2
            fig = make_subplots(
                rows=grid_rows, cols=2,
                subplot_titles=[f"{col} Distribution" for col in cols_to_plot],
            )
            for i, col in enumerate(cols_to_plot):
                fig.add_trace(
                    go.Histogram(x=df[col].values, name=str(col), showlegend=False),
                    row=i // 2 + 1, col=i % 2 + 1,
                )
            fig.update_layout(title="📊 Numeric Variable Distributions")
            distribution_fig = fig

        return dtype_fig, missing_fig, correlation_fig, distribution_fig
    except Exception as e:
        # UI boundary: show the error as the title of an empty figure.
        error_fig = px.scatter(title=f"❌ Visualization Error: {str(e)}")
        return error_fig, None, None, None
# Create Gradio interface
def create_gradio_interface():
    """
    Build the AutoEDA Gradio Blocks app (without launching it).

    Layout: a CSV upload shared by all tabs, then three tabs:
    - Overview: Gemini API key box and a button wired to
      analyze_dataset_overview(file, key) -> story, basic info, quality score.
    - Data Profiling: a button wired to generate_data_profiling(file) -> three
      read-only tables.
    - Smart Visualizations: a button wired to create_smart_visualizations(file)
      -> four plots.
    A markdown footer follows the tabs.

    Components appear in the order they are created inside the context
    managers below, so statement order here is the page layout.

    Returns:
        gr.Blocks: The assembled app.
    """
    with gr.Blocks(title="🚀 AI Data Explorer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🚀 AutoEDA")
        gr.Markdown("Upload your CSV file and get AI-powered analysis reports!")
        # Single upload component; every tab's click handler reads from it.
        with gr.Row():
            file_input = gr.File(
                label="📁 Upload CSV File",
                file_types=[".csv"]
            )
        with gr.Tabs():
            # Overview tab: AI story, basic info summary and quality score.
            with gr.Tab("🔍 Overview"):
                gr.Markdown("### AI-Powered Data Insights")
                with gr.Row():
                    # type="password" masks the key in the browser input.
                    api_key_input = gr.Textbox(
                        label="🔑 Gemini API Key",
                        placeholder="Enter your Gemini API key here...",
                        type="password"
                    )
                with gr.Row():
                    overview_btn = gr.Button("🎯 Generate Story", variant="primary")
                with gr.Row():
                    with gr.Column():
                        story_output = gr.Markdown(
                            label="📖 Data Insights",
                            value=""
                        )
                    with gr.Column():
                        basic_info_output = gr.Markdown(
                            label="📋 Basic Information",
                            value=""
                        )
                with gr.Row():
                    quality_score = gr.Number(
                        label="🎯 Data Quality Score (%)",
                        precision=1
                    )
                # Output order must match the tuple returned by the handler.
                overview_btn.click(
                    fn=analyze_dataset_overview,
                    inputs=[file_input, api_key_input],
                    outputs=[story_output, basic_info_output, quality_score]
                )
            # Profiling tab: missing-data, numeric and categorical tables.
            with gr.Tab("📊 Data Profiling"):
                gr.Markdown("### Automated Data Profiling")
                with gr.Row():
                    profiling_btn = gr.Button("🔍 Generate Profiling", variant="secondary")
                with gr.Row():
                    with gr.Column():
                        missing_data_table = gr.Dataframe(
                            label="🔴 Missing Data Analysis",
                            interactive=False
                        )
                    with gr.Column():
                        numeric_stats_table = gr.Dataframe(
                            label="🔢 Numeric Statistics",
                            interactive=False
                        )
                with gr.Row():
                    categorical_stats_table = gr.Dataframe(
                        label="📝 Categorical Statistics",
                        interactive=False
                    )
                profiling_btn.click(
                    fn=generate_data_profiling,
                    inputs=[file_input],
                    outputs=[missing_data_table, numeric_stats_table, categorical_stats_table]
                )
            # Visualization tab: 2x2 arrangement of plots in two columns.
            with gr.Tab("📈 Smart Visualizations"):
                gr.Markdown("### Automated Data Visualizations")
                with gr.Row():
                    viz_btn = gr.Button("🎨 Create Visualizations", variant="secondary")
                with gr.Row():
                    with gr.Column():
                        dtype_plot = gr.Plot(label="🔍 Data Types")
                        missing_plot = gr.Plot(label="🔴 Missing Data")
                    with gr.Column():
                        correlation_plot = gr.Plot(label="🔗 Correlations")
                        distribution_plot = gr.Plot(label="📊 Distributions")
                viz_btn.click(
                    fn=create_smart_visualizations,
                    inputs=[file_input],
                    outputs=[dtype_plot, missing_plot, correlation_plot, distribution_plot]
                )
        # Footer
        gr.Markdown("---")
        gr.Markdown("💡 **Tip**: Get your free Gemini API key from [Google AI Studio](https://aistudio.google.com/)")
    return demo
# Script entry point: only runs when this file is executed directly.
if __name__ == "__main__":
    demo = create_gradio_interface()
    # mcp_server=True asks Gradio to also serve an MCP endpoint alongside the
    # web UI (presumably needs the gradio "mcp" extra installed — confirm).
    demo.launch(mcp_server=True)