File size: 7,179 Bytes
5f946b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import gradio as gr
import os
import sys
import pandas as pd
import sqlite3
from pathlib import Path
import matplotlib.pyplot as plt
import re

# Module-level setup: everything below runs once, at import time.
#
# PROJECT_ROOT is the resolved directory containing this file; it is appended
# to sys.path so the project-local imports below can be resolved from it.
PROJECT_ROOT = Path(__file__).parent.resolve()
sys.path.append(str(PROJECT_ROOT))

# Project-local model loaders and helpers (SQL generator + intent classifier).
# NOTE(review): `code` is also the name of a standard-library module, and the
# project root is appended (not prepended) to sys.path -- confirm the local
# package is the one that gets imported in every launch mode.
from code.train_sqlgen_t5_local import load_model as load_sql_model, generate_sql, get_schema_from_csv
from code.train_intent_classifier_local import load_model as load_intent_model, classify_intent

# Load both models eagerly. `device` is bound by both calls, so the value
# returned by load_intent_model() is the one that remains in scope and is later
# passed to both generate_sql and classify_intent.
sql_model, sql_tokenizer, device = load_sql_model()
intent_model, intent_tokenizer, device, label_mapping = load_intent_model()

# Path (as a string) to the CSV that every query is executed against.
DATA_FILE = str(PROJECT_ROOT / "data" / "testing_sql_data.csv")

# Fail fast at import time if the dataset is missing, rather than on the
# first user query.
if not os.path.exists(DATA_FILE):
    raise FileNotFoundError(f"Data file not found at {DATA_FILE}. Please ensure testing_sql_data.csv exists in the data folder.")

def _normalize_generated_sql(sql_query):
    """Rewrite model-generated SQL so it targets the in-memory ``data`` table.

    Any table reference after ``FROM``/``JOIN`` (bare, double-quoted or
    single-quoted) is replaced with ``data``, and a fixed set of column names
    the model tends to emit is mapped onto the column names used by the CSV.
    """
    # One pass over all three table-name spellings. The single-quoted branch
    # stops at the next single quote, so it cannot swallow the rest of the
    # statement up to a later string literal.
    sql_query = re.sub(
        r"""(FROM|JOIN)\s+(?:"[^"]+"|'[^']+'|\w+)""",
        r'\1 data',
        sql_query,
        flags=re.IGNORECASE,
    )
    # Plain substring replacements, applied in this order.
    for generated_name, csv_name in (
        ('product_price', 'total_price'),
        ('store_name', 'store_id'),
        ('sales_method', 'date'),
    ):
        sql_query = sql_query.replace(generated_name, csv_name)
    # Whole-word match only, so identifiers that merely contain "sales" survive.
    return re.sub(r'\bsales\b', 'total_price', sql_query)


def _render_chart(result_df, chart_type, title, chart_path):
    """Plot the first two columns of ``result_df`` and save the image to ``chart_path``."""
    x_col, y_col = result_df.columns[0], result_df.columns[1]
    # Draw on an explicit Axes: DataFrame.plot opens its own figure when no
    # `ax` is given, which would ignore this figure's size and leave it open.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if chart_type == "bar":
            result_df.plot(kind="bar", x=x_col, y=y_col, ax=ax)
        elif chart_type == "line":
            result_df.plot(kind="line", x=x_col, y=y_col, marker='o', ax=ax)
        elif chart_type == "pie":
            result_df.plot(kind="pie", y=y_col, labels=result_df[x_col], ax=ax)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(chart_path)
    finally:
        # Always release the figure, including when plotting raises.
        plt.close(fig)


def process_query(question, chart_type="auto"):
    """Answer a natural-language question against the bundled CSV.

    Generates SQL for ``question``, normalizes it, classifies the question's
    intent, runs the SQL against an in-memory SQLite copy of the CSV, and then
    renders a chart and textual insights from the result.

    Args:
        question: The user's question in natural language.
        chart_type: ``"auto"``, ``"bar"``, ``"line"`` or ``"pie"``. ``"auto"``
            picks ``"line"`` when the intent is ``"trend"`` and ``"bar"``
            otherwise.

    Returns:
        A 5-tuple ``(result_df, intent, sql_query, chart_path, insights)``.
        ``chart_path`` is ``None`` when the result is empty or has fewer than
        two columns. On any failure the tuple is
        ``(None, "Error", <message>, None, "Error: <message>")`` so the UI
        shows the error instead of raising.
    """
    try:
        schema = get_schema_from_csv(DATA_FILE)
        sql_query = _normalize_generated_sql(
            generate_sql(question, schema, sql_model, sql_tokenizer, device)
        )
        intent = classify_intent(question, intent_model, intent_tokenizer, device, label_mapping)

        # The generated SQL is executed as-is, but only against a throwaway
        # in-memory database that holds a copy of the CSV.
        df = pd.read_csv(DATA_FILE)
        conn = sqlite3.connect(":memory:")
        try:
            df.to_sql("data", conn, index=False, if_exists="replace")
            result_df = pd.read_sql_query(sql_query, conn)
        finally:
            # Close the connection even when the generated SQL is invalid.
            conn.close()

        # Charts and insights need a label column and a value column.
        if result_df.empty or len(result_df.columns) < 2:
            insights = "No results or not enough columns to display chart/insights."
            return result_df, intent, sql_query, None, insights

        if chart_type == "auto":
            chart_type = "line" if intent == "trend" else "bar"
        chart_path = os.path.join(PROJECT_ROOT, "chart.png")
        _render_chart(result_df, chart_type, question, chart_path)

        insights = generate_insights(result_df, intent, question)
        return result_df, intent, sql_query, chart_path, insights
    except Exception as e:
        # UI boundary: report the failure through the output widgets.
        return None, "Error", str(e), None, f"Error: {str(e)}"

def generate_insights(result_df, intent, question):
    """Build a bulleted, human-readable summary of a query result.

    The first column is treated as the label and the second as the value;
    the value column is coerced to numbers and non-numeric entries are ignored.

    Args:
        result_df: Query result. ``None``, an empty frame, or a frame with
            fewer than two columns yields a fixed "no data" message.
        intent: Predicted intent. ``"summary"`` adds a total, ``"comparison"``
            compares the largest value with the smallest, ``"trend"`` reports
            the change from the first row to the last. Other values add no
            intent-specific line.
        question: The user's question. Unused; kept so callers do not change.

    Returns:
        A string with one ``• ``-prefixed insight per line.
    """
    if result_df is None or result_df.empty or len(result_df.columns) < 2:
        return "No data available for insights."

    value_name = result_df.columns[1]
    # Positional access with a fresh 0..n-1 index, so idxmax/idxmin below
    # return row positions regardless of the frame's own index or column names.
    labels = result_df.iloc[:, 0].reset_index(drop=True)
    values = pd.to_numeric(result_df.iloc[:, 1], errors="coerce").reset_index(drop=True)
    has_numbers = bool(values.notna().any())

    insights = []
    if intent == "summary":
        if has_numbers:
            insights.append(f"Total {value_name}: {float(values.sum()):,.2f}")
    elif intent == "comparison":
        if len(result_df) >= 2 and has_numbers:
            # Use the real extremes instead of assuming the rows are sorted.
            top, bottom = values.idxmax(), values.idxmin()
            highest = float(values.iloc[top])
            lowest = float(values.iloc[bottom])
            # A percentage is meaningless against a zero baseline or when all
            # values are equal (top and bottom are then the same row).
            if top != bottom and lowest != 0:
                diff = (highest - lowest) / abs(lowest) * 100
                insights.append(f"{labels.iloc[top]} is {diff:.1f}% higher than {labels.iloc[bottom]}")
    elif intent == "trend":
        if len(result_df) >= 2:
            first, last = values.iloc[0], values.iloc[-1]
            # Skip missing endpoints and a zero baseline rather than print nan/inf.
            if pd.notna(first) and pd.notna(last) and first != 0:
                change = (float(last) - float(first)) / abs(float(first)) * 100
                insights.append(f"Overall change: {change:+.1f}%")

    insights.append(f"Analysis covers {len(result_df)} records")
    if "category" in result_df.columns:
        insights.append(f"Number of categories: {result_df['category'].nunique()}")
    return "\n".join(f"• {insight}" for insight in insights)

# Example questions shown as the choices of the FAQ radio group in the UI;
# selecting one copies its text into the question box (see fill_question).
faqs = [
    "What are the top 5 products by quantity sold?",
    "What is the total sales amount for each category?",
    "Which store had the highest total sales?",
    "What are the most popular payment methods?",
    "What is the sales trend over time?",
    "What is the average transaction value?"
]

def fill_question(faq):
    """Return a Gradio update that sets the output component's value to ``faq``."""
    autofill = gr.update(value=faq)
    return autofill

# Build the Gradio UI. Components are created in layout order inside the
# nested context managers, so the statement order below defines the page.
with gr.Blocks(title="RetailGenie - Natural Language to SQL") as demo:
    gr.Markdown("""
    # RetailGenie - Natural Language to SQL
    Ask questions in natural language to generate SQL queries and visualizations. Using retail dataset with product sales information.
    """)
    with gr.Row():
        # Left column: inputs (question, FAQ shortcuts, chart type, submit).
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Enter your question",
                placeholder="What is the total sales amount for each product category?"
            )
            faq_radio = gr.Radio(faqs, label="FAQs (click to autofill)", interactive=True)
            # Selecting an FAQ copies its text into the question box.
            faq_radio.change(fn=fill_question, inputs=faq_radio, outputs=question)
            chart_type = gr.Radio(
                ["auto", "bar", "line", "pie"],
                label="Chart Type",
                value="auto"
            )
            submit_btn = gr.Button("Generate", variant="primary")
        # Right column: outputs (intent/SQL details, result table, chart, insights).
        with gr.Column(scale=2):
            with gr.Accordion("SQL and Intent Details", open=False):
                intent_output = gr.Textbox(label="Predicted Intent")
                sql_output = gr.Textbox(label="Generated SQL", lines=3)
            results_df = gr.DataFrame(label="Query Results")
            chart_output = gr.Image(label="Chart")
            insights_output = gr.Textbox(label="Insights", lines=5)
    # The outputs list must stay in the same order as the 5-tuple returned by
    # process_query: (result_df, intent, sql_query, chart_path, insights).
    submit_btn.click(
        fn=process_query,
        inputs=[question, chart_type],
        outputs=[results_df, intent_output, sql_output, chart_output, insights_output]
    )

# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()