ShaswatSingh committed on
Commit
da92c86
·
verified ·
1 Parent(s): 9532f6c

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. CSV_rag_.py +449 -0
  3. Readme.md +14 -0
  4. hotel_bookings.csv +3 -0
  5. requirements.txt +9 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ hotel_bookings.csv filter=lfs diff=lfs merge=lfs -text
CSV_rag_.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import os
4
+ import matplotlib.pyplot as plt
5
+ import io
6
+ from PIL import Image
7
+ import base64
8
+ import re
9
+ import numpy as np
10
+ from llama_index.llms.groq import Groq
11
+ from llama_index.core.query_pipeline import (
12
+ QueryPipeline as QP,
13
+ Link,
14
+ InputComponent,
15
+ )
16
+ from llama_index.experimental.query_engine.pandas import (
17
+ PandasInstructionParser,
18
+ )
19
+ from llama_index.core import PromptTemplate
20
+
21
+ # Example datasets
22
# Maps the label offered in the example dropdown to the CSV path that is
# handed to load_dataframe().
EXAMPLE_DATASETS = {"Hotel Bookings": "hotel_bookings.csv"}
25
+
26
def load_dataframe(file_path):
    """Load a CSV into a DataFrame and describe the outcome.

    Args:
        file_path: A path/URL string, an ``os.PathLike`` object, or an
            uploaded-file wrapper that exposes its on-disk location through
            a ``name`` attribute.

    Returns:
        A ``(df, message)`` tuple. ``df`` is ``None`` when loading failed;
        ``message`` is a human-readable status string in both cases.
    """
    if file_path is None or (isinstance(file_path, str) and not file_path.strip()):
        return None, "Error loading dataset: no file was provided."

    if isinstance(file_path, (str, os.PathLike)):
        # Strings and path objects are passed through untouched. Path objects
        # must not take the ``.name`` route below: for them ``.name`` is only
        # the final path component, not a usable location.
        source = file_path
    else:
        # Uploaded-file wrappers carry the stored file's path in ``.name``.
        source = getattr(file_path, "name", file_path)

    try:
        df = pd.read_csv(source)
    except Exception as e:
        # UI boundary: any read/parse failure is reported as a status message
        # instead of propagating into the event handler.
        return None, f"Error loading dataset: {str(e)}"

    return df, f"Successfully loaded dataset with {df.shape[0]} rows and {df.shape[1]} columns."
37
+
38
def create_query_pipeline(df, api_key, model="llama-3.3-70b-versatile"):
    """Build the two-stage query pipeline (pandas code generation, then answer synthesis).

    The pipeline wires: input -> pandas prompt -> LLM -> PandasInstructionParser,
    and then feeds the original query, the generated pandas expression and the
    parser output into a synthesis prompt answered by the same LLM.

    Args:
        df: The DataFrame the generated pandas expression is run against.
        api_key: Groq API key used to construct the LLM client.
        model: Groq model name.

    Returns:
        A ``(pipeline, message)`` tuple. ``pipeline`` is ``None`` when the Groq
        client could not be constructed; ``message`` is a status string.
    """
    try:
        llm = Groq(model=model, api_key=api_key)
    except Exception as e:
        return None, f"Error initializing Groq LLM: {str(e)}"

    instruction_str = (
        "1. Convert the query to executable Python code using Pandas.\n"
        "2. The final line of code should be a Python expression that can be called with the `eval()` function.\n"
        "3. The code should represent a solution to the query.\n"
        "4. PRINT ONLY THE EXPRESSION.\n"
        "5. Do not quote the expression.\n"
    )

    pandas_prompt_str = (
        "You are working with a pandas dataframe in Python.\n"
        "The name of the dataframe is `df`.\n"
        "This is the result of `print(df.head())`:\n"
        "{df_str}\n\n"
        "Follow these instructions:\n"
        "{instruction_str}\n"
        "Query: {query_str}\n\n"
        "Expression:"
    )

    response_synthesis_prompt_str = (
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n\n"
        "Pandas Instructions (optional):\n{pandas_instructions}\n\n"
        "Pandas Output: {pandas_output}\n\n"
        "Response: "
    )

    # to_string() renders every column. The default DataFrame repr may elide
    # the middle columns of a wide frame, which would hide column names from
    # the model that has to write code against them.
    pandas_prompt = PromptTemplate(pandas_prompt_str).partial_format(
        instruction_str=instruction_str,
        df_str=df.head(5).to_string(),
    )
    pandas_output_parser = PandasInstructionParser(df)
    response_synthesis_prompt = PromptTemplate(response_synthesis_prompt_str)

    qp = QP(
        modules={
            "input": InputComponent(),
            "pandas_prompt": pandas_prompt,
            "llm1": llm,
            "pandas_output_parser": pandas_output_parser,
            "response_synthesis_prompt": response_synthesis_prompt,
            "llm2": llm,
        },
        verbose=True,
    )
    # Stage 1: turn the question into a pandas expression and evaluate it.
    qp.add_chain(["input", "pandas_prompt", "llm1", "pandas_output_parser"])
    # Stage 2: give the synthesis prompt the question, the generated
    # expression and its evaluated output.
    qp.add_links(
        [
            Link("input", "response_synthesis_prompt", dest_key="query_str"),
            Link("llm1", "response_synthesis_prompt", dest_key="pandas_instructions"),
            Link(
                "pandas_output_parser",
                "response_synthesis_prompt",
                dest_key="pandas_output",
            ),
        ]
    )
    qp.add_link("response_synthesis_prompt", "llm2")

    return qp, "Query pipeline created successfully!"
106
+
107
_MONTH_ORDER = {
    'January': 1, 'February': 2, 'March': 3, 'April': 4, 'May': 5, 'June': 6,
    'July': 7, 'August': 8, 'September': 9, 'October': 10, 'November': 11, 'December': 12,
}
_TIME_TERMS = ('trend', 'time', 'year', 'month', 'booking', 'reservation')
_DATE_COLUMN_TERMS = ('date', 'year', 'month', 'time', 'arrival', 'reservation')
_DISTRIBUTION_TERMS = ('distribution', 'histogram', 'spread')
_COMPARISON_TERMS = ('compare', 'comparison', 'versus', 'vs', 'most', 'least')


def _mentions(text, terms):
    """Return True when any of *terms* occurs as a substring of *text*."""
    return any(term in text for term in terms)


def _pick_column(candidates, query_lc):
    """Return the first candidate named in the query, else the first candidate, else None."""
    for col in candidates:
        if str(col).lower() in query_lc:
            return col
    return candidates[0] if candidates else None


def _label_axes(title, xlabel, ylabel):
    """Apply the shared title/axis-label/grid styling to the current axes."""
    plt.title(title, fontsize=16)
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.grid(axis='y', linestyle='--', alpha=0.7)


def _attempt(label, draw, *args):
    """Run one drawing helper; log and report False if it raises."""
    try:
        return draw(*args)
    except Exception as e:
        print(f"Error in {label} visualization: {str(e)}")
        return False


def _draw_bookings_by_month(df):
    """Draw a grouped bar chart of row counts per arrival year and month."""
    counts = (
        df.groupby(['arrival_date_year', 'arrival_date_month'])
        .size()
        .reset_index(name='count')
    )
    # Calendar order for the legend/bars; unknown month labels sort last.
    months = sorted(
        counts['arrival_date_month'].unique(),
        key=lambda month: _MONTH_ORDER.get(month, 13),
    )
    if not months:
        return False

    pivot = counts.pivot(
        index='arrival_date_year', columns='arrival_date_month', values='count'
    )[months]
    ax = pivot.plot(kind='bar', figsize=(14, 8), width=0.8)

    _label_axes('Bookings by Month and Year', 'Year', 'Number of Bookings')
    plt.legend(title='Month', fontsize=12)
    plt.tight_layout()
    for container in ax.containers:
        ax.bar_label(container, fontsize=9, fmt='%d')
    return True


def _draw_counts_by_column(df, column):
    """Draw a bar chart of row counts per distinct value of *column*."""
    counts = df.groupby(column).size().reset_index(name='count')
    plt.figure(figsize=(12, 8), dpi=100)
    plt.bar(counts[column], counts['count'], color='steelblue')
    _label_axes(f'Distribution by {column}', column, 'Count')
    plt.xticks(rotation=45)
    plt.tight_layout()
    return True


def _draw_time_chart(df):
    """Draw the time-oriented chart that fits the available columns, if any."""
    if 'arrival_date_year' in df.columns and 'arrival_date_month' in df.columns:
        return _attempt('time', _draw_bookings_by_month, df)

    date_cols = [
        col for col in df.columns
        if _mentions(str(col).lower(), _DATE_COLUMN_TERMS)
    ]
    if not date_cols:
        return False
    return _attempt('date column', _draw_counts_by_column, df, date_cols[0])


def _draw_histogram(df, query_lc):
    """Draw a histogram of the numeric column named in the query (or the first one)."""
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    target_col = _pick_column(numeric_cols, query_lc)
    if target_col is None:
        return False
    plt.figure(figsize=(12, 8), dpi=100)
    plt.hist(df[target_col].dropna(), bins=30, color='steelblue', edgecolor='black', alpha=0.7)
    _label_axes(f'Distribution of {target_col}', target_col, 'Frequency')
    plt.tight_layout()
    return True


def _draw_top_categories(df, query_lc):
    """Draw the ten most frequent values of the object column named in the query (or the first one)."""
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    target_col = _pick_column(categorical_cols, query_lc)
    if target_col is None:
        return False
    top_categories = df[target_col].value_counts().nlargest(10)
    plt.figure(figsize=(12, 8), dpi=100)
    plt.bar(top_categories.index, top_categories.values, color='steelblue')
    _label_axes(f'Top Categories by {target_col}', target_col, 'Count')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    return True


def enhance_visualization(df, query):
    """Create a keyword-driven chart for *query*, or return None.

    The lower-cased query is matched by substring against three keyword
    groups, checked in this order: time, distribution, comparison. The first
    group that matches decides which chart is attempted.

    NOTE(review): 'booking' belongs to the time group, so a question such as
    "Which country has the most bookings?" is routed to the time chart rather
    than the comparison chart. Routing is left unchanged here; confirm whether
    that priority is intended.

    Args:
        df: DataFrame to plot from.
        query: The user's natural-language question.

    Returns:
        A PIL image of the rendered chart, or ``None`` when no keyword group
        matched, no suitable column exists, or drawing failed.

    All matplotlib figures are closed before returning on every path, so a
    ``None`` result never leaves an empty figure open for the caller to find.
    """
    plt.close('all')  # start from a clean slate
    try:
        query_lc = query.lower()
        if _mentions(query_lc, _TIME_TERMS):
            drawn = _draw_time_chart(df)
        elif _mentions(query_lc, _DISTRIBUTION_TERMS):
            drawn = _attempt('distribution', _draw_histogram, df, query_lc)
        elif _mentions(query_lc, _COMPARISON_TERMS):
            drawn = _attempt('comparison', _draw_top_categories, df, query_lc)
        else:
            drawn = False

        if not drawn:
            return None

        # Render the current figure into an in-memory PNG.
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        print(f"Error in enhance_visualization: {str(e)}")
        return None
    finally:
        # Runs on success, on "nothing drawn" and on error alike.
        plt.close('all')
274
+
275
def _render_pipeline_figure(fig):
    """Tidy up a figure produced while running the query and return it as a PIL image."""
    if fig.axes:
        ax = fig.axes[0]
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        # Re-apply existing texts with larger fonts; empty ones are left alone.
        if ax.get_title():
            ax.set_title(ax.get_title(), fontsize=16)
        if ax.get_xlabel():
            ax.set_xlabel(ax.get_xlabel(), fontsize=14)
        if ax.get_ylabel():
            ax.set_ylabel(ax.get_ylabel(), fontsize=14)
        legend = ax.get_legend()
        if legend is not None:
            # Resize in place so the legend keeps its title and position
            # (calling ax.legend() again would rebuild it from scratch).
            for text in legend.get_texts():
                text.set_fontsize(12)
        fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    return Image.open(buf)


def process_query(query, api_key, df, model_choice):
    """Answer *query* about *df* and return an optional chart.

    Args:
        query: The user's natural-language question.
        api_key: Groq API key.
        df: The loaded DataFrame, or ``None`` if nothing has been loaded.
        model_choice: Groq model name passed to create_query_pipeline().

    Returns:
        A ``(text, image)`` tuple. ``text`` is the model's answer or a
        user-facing error/validation message; ``image`` is a PIL image or
        ``None`` when no visualization is available.
    """
    if df is None:
        return "Please load a dataset first.", None

    if not api_key:
        return "Please provide your Groq API key.", None

    if not query or not query.strip():
        # Avoid spending an LLM call on an empty question.
        return "Please enter a question about your data.", None

    try:
        # First, try to create a keyword-driven visualization for the query.
        enhanced_img = enhance_visualization(df, query)
        # Discard anything still open so that, after the run below, every open
        # figure is known to have been created by the generated pandas code.
        plt.close('all')

        pipeline, message = create_query_pipeline(df, api_key, model_choice)
        if pipeline is None:
            return message, None

        response = pipeline.run(query_str=query)
        answer = response.message.content

        if enhanced_img is not None:
            return answer, enhanced_img

        figures = plt.get_fignums()
        if not figures:
            return answer, None

        try:
            return answer, _render_pipeline_figure(plt.figure(figures[0]))
        except Exception as e:
            # A rendering problem must not cost the user the text answer.
            print(f"Visualization error: {str(e)}")
            return answer, None
    except Exception as e:
        return f"Error processing query: {str(e)}", None
    finally:
        plt.close('all')  # never leave figures behind between requests
345
+
346
+ def handle_example_selection(example_name):
347
+ if example_name in EXAMPLE_DATASETS:
348
+ file_path = EXAMPLE_DATASETS[example_name]
349
+ df, message = load_dataframe(file_path)
350
+ return df, message, gr.update(value=f"Dataset preview:\n{df.head().to_string()}")
351
+ return None, "Please select a valid example dataset.", gr.update(value="")
352
+
353
def handle_file_upload(file):
    """Load an uploaded CSV for the UI.

    Args:
        file: The value of the upload component, or ``None`` if nothing was uploaded.

    Returns:
        ``(df, status_message, preview_update)``. ``df`` is ``None`` and the
        preview is cleared when no file was given or it fails to load.
    """
    if file is None:
        return None, "No file uploaded.", gr.update(value="")

    df, message = load_dataframe(file)
    if df is None:
        # load_dataframe reports failures through ``message``; there is no
        # frame to preview, so surface the error instead of calling df.head().
        return None, message, gr.update(value="")

    return df, message, gr.update(value=f"Dataset preview:\n{df.head().to_string()}")
358
+
359
+ # Create Gradio interface
360
# NOTE(review): the nesting below is reconstructed from a diff view that had
# lost its indentation; confirm the layout against the original file.
with gr.Blocks(title="Pandas Data Analysis with Groq LLM") as app:
    gr.Markdown("# Pandas Data Analysis with Groq LLM")
    gr.Markdown("Upload your CSV data or choose an example dataset, then ask questions about it.")

    # Holds the currently loaded DataFrame between events.
    df_state = gr.State(value=None)

    with gr.Row():
        # Left column: data source and API settings.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Data Selection")
                with gr.Tab("Upload Data"):
                    file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
                    upload_button = gr.Button("Load Uploaded Data")

                with gr.Tab("Example Datasets"):
                    example_dropdown = gr.Dropdown(
                        choices=list(EXAMPLE_DATASETS.keys()),
                        label="Select Example Dataset",
                    )
                    example_button = gr.Button("Load Example Dataset")

                data_status = gr.Textbox(label="Data Loading Status", interactive=False)

            with gr.Group():
                gr.Markdown("### Groq API Configuration")
                api_key = gr.Textbox(
                    label="Enter your Groq API Key",
                    placeholder="gsk_...",
                    type="password",
                )
                model_choice = gr.Dropdown(
                    choices=["llama-3.3-70b-versatile", "mixtral-8x7b-32768", "gemma-7b-it"],
                    value="llama-3.3-70b-versatile",
                    label="Select Groq Model",
                )

        # Right column: preview, question box and results.
        with gr.Column(scale=1):
            data_preview = gr.Textbox(label="Dataset Preview", interactive=False, lines=10)
            query_input = gr.Textbox(
                label="Ask a question about your data",
                placeholder="e.g., What is the trend of monthly bookings over time?",
                lines=2,
            )
            query_button = gr.Button("Submit Query")

            # Text answer and chart live in separate tabs.
            with gr.Tabs():
                with gr.TabItem("Text Response"):
                    response_output = gr.Textbox(label="Response", interactive=False, lines=10)
                with gr.TabItem("Visualization"):
                    image_output = gr.Image(label="Data Visualization", interactive=False)

    # Both loaders write the same three outputs: state, status line, preview.
    upload_button.click(
        handle_file_upload,
        inputs=[file_input],
        outputs=[df_state, data_status, data_preview],
    )

    example_button.click(
        handle_example_selection,
        inputs=[example_dropdown],
        outputs=[df_state, data_status, data_preview],
    )

    query_button.click(
        process_query,
        inputs=[query_input, api_key, df_state, model_choice],
        outputs=[response_output, image_output],
    )

    gr.Markdown("""
    ### Instructions
    1. Upload your CSV file or select an example dataset
    2. Enter your Groq API key (get one at [https://console.groq.com](https://console.groq.com))
    3. Ask questions about your data in natural language
    4. Get AI-powered insights and visualizations based on your data

    ### Example Questions
    - What is the trend of monthly bookings over time?
    - What's the distribution of stay duration?
    - Which country has the most bookings?
    - Is there a correlation between lead time and cancellations?
    - Show me bookings by month and year
    """)

# Start the server only when this file is executed directly.
if __name__ == "__main__":
    app.launch()
Readme.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: CSVision
3
+ emoji: 🚀
4
+ colorFrom: gray
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.23.3
8
+ app_file: CSV_rag_.py
9
+ pinned: false
10
+ ---
11
+
12
+ # My Hugging Face Space
13
+
14
+ Welcome to my Hugging Face Space! 🎉
hotel_bookings.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c2ae42a7353905ea136e5c2287f17c92c5435826598bfbb8491c6f0c7b1fc06
3
+ size 16855599
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ pandas
3
+ numpy
4
+ matplotlib
5
+ pillow
6
+ # base64 is part of the Python standard library and needs no pip requirement
7
+ llama-index-llms-groq
8
+ llama-index-experimental
9
+ llama-index