# UAP-Data-Analysis-Tool / analyzing.py
import streamlit as st
#import cudf.pandas
#cudf.pandas.install()
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from uap_analyzer import UAPParser, UAPAnalyzer, UAPVisualizer
# import ChartGen
# from ChartGen import ChartGPT
from Levenshtein import distance
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from stqdm import stqdm
stqdm.pandas()
import streamlit.components.v1 as components
from dateutil import parser
from sentence_transformers import SentenceTransformer
import torch
import squarify
import matplotlib.colors as mcolors
import textwrap
import datamapplot
# This deprecation option was removed in recent Streamlit releases; guard it so
# the app still starts on newer versions.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except Exception:
    pass
from pandas.api.types import (
is_categorical_dtype,
is_datetime64_any_dtype,
is_numeric_dtype,
is_object_dtype,
)
def load_data(file_path, key='df'):
return pd.read_hdf(file_path, key=key)
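# Minimal usage sketch (assumes the HDF5 file was written via pd.to_hdf under key 'df'):
#   df = load_data('parsed_files_distance_embeds.h5')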
def gemini_query(question, selected_data, gemini_key):
    # Fall back to a generic summarization prompt when no question is given
    if question == "":
        question = "Summarize the following data in relevant bullet points"
    import google.generativeai as genai
    # selected_data is a list; drop None and empty entries before building the context
    filtered = [str(x) for x in selected_data if x is not None and str(x).strip() != '']
    # Join the remaining rows into a single context string
    context = '\n'.join(filtered)
    genai.configure(api_key=gemini_key)
    query_model = genai.GenerativeModel('models/gemini-1.5-pro-latest')
    response = query_model.generate_content([f"{question}\n Answer based on this context: {context}\n\n"])
    return response.text
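# Usage sketch (hypothetical column name; assumes a loaded DataFrame and a valid key):
#   answer = gemini_query("", df['summary'].tolist(), GEMINI_KEY)  # empty question -> summarization
#   st.write(answer)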
def plot_treemap(df, column, top_n=32):
    # Work on a copy so the revised column does not leak back into the caller's frame
    df = df.copy()
    # Get the value counts and the top N labels
    value_counts = df[column].value_counts()
    top_labels = value_counts.iloc[:top_n].index
    # Use np.where to replace all values not in the top N with 'Other'
    revised_column = f'{column}_revised'
    df[revised_column] = np.where(df[column].isin(top_labels), df[column], 'Other')
# Get the value counts including the 'Other' category
sizes = df[revised_column].value_counts().values
labels = df[revised_column].value_counts().index
# Get a gradient of colors
# colors = list(mcolors.TABLEAU_COLORS.values())
n_colors = len(sizes)
colors = plt.cm.Oranges(np.linspace(0.3, 0.9, n_colors))[::-1]
# Get % of each category
percents = sizes / sizes.sum()
# Prepare labels with percentages
labels = [f'{label}\n {percent:.1%}' for label, percent in zip(labels, percents)]
fig, ax = plt.subplots(figsize=(20, 12))
# Plot the treemap
squarify.plot(sizes=sizes, label=labels, alpha=0.7, pad=True, color=colors, text_kwargs={'fontsize': 10})
ax = plt.gca()
# Iterate over text elements and rectangles (patches) in the axes for color adjustment
for text, rect in zip(ax.texts, ax.patches):
background_color = rect.get_facecolor()
r, g, b, _ = mcolors.to_rgba(background_color)
brightness = np.average([r, g, b])
text.set_color('white' if brightness < 0.5 else 'black')
# Adjust font size based on rectangle's area and wrap long text
coef = 0.8
font_size = np.sqrt(rect.get_width() * rect.get_height()) * coef
text.set_fontsize(font_size)
wrapped_text = textwrap.fill(text.get_text(), width=20)
text.set_text(wrapped_text)
plt.axis('off')
plt.gca().invert_yaxis()
plt.gcf().set_size_inches(20, 12)
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
return fig
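# Usage sketch (hypothetical column name; 'Other' absorbs everything past top_n):
#   st.pyplot(plot_treemap(df, 'shape', top_n=20))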
def plot_hist(df, column, bins=10, kde=True):
    fig, ax = plt.subplots(figsize=(12, 6))
    # Honor the kde argument instead of hardcoding it
    sns.histplot(data=df, x=column, kde=kde, bins=bins, color='orange')
# set the ticks and frame in orange
ax.spines['bottom'].set_color('orange')
ax.spines['top'].set_color('orange')
ax.spines['right'].set_color('orange')
ax.spines['left'].set_color('orange')
ax.xaxis.label.set_color('orange')
ax.yaxis.label.set_color('orange')
ax.tick_params(axis='x', colors='orange')
ax.tick_params(axis='y', colors='orange')
ax.title.set_color('orange')
# Set transparent background
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
return fig
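# Usage sketch (hypothetical column name):
#   st.pyplot(plot_hist(df, 'duration_seconds', bins=30))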
def plot_line(df, x_column, y_columns, figsize=(12, 10), color='orange', title=None, rolling_mean_value=2):
    import matplotlib.cm as cm
    # Sort a copy of the dataframe by the x column so the caller's frame is untouched
    df = df.sort_values(by=x_column).copy()
    # Normalize datetimes to plain dates before plotting
    if np.issubdtype(df[x_column].dtype, np.datetime64) or np.issubdtype(df[x_column].dtype, np.timedelta64):
        df[x_column] = pd.to_datetime(df[x_column]).dt.date
    # Smooth each y_column with a rolling mean whose window is a fraction of the data length
    if rolling_mean_value:
        df[y_columns] = df[y_columns].rolling(len(df) // rolling_mean_value).mean()
    # Create the plot
    fig, ax = plt.subplots(figsize=figsize)
    colors = cm.Oranges(np.linspace(0.2, 1, len(y_columns)))
    # Plot each y_column as a separate line with a different color
    for i, y_column in enumerate(y_columns):
        df.plot(x=x_column, y=y_column, ax=ax, color=colors[i], label=y_column, linewidth=.5)
    # Rotate x-axis labels
    ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha='right')
# Set title, labels, and legend
ax.set_title(title or f'{", ".join(y_columns)} over {x_column}', color=color, fontweight='bold')
ax.set_xlabel(x_column, color=color)
ax.set_ylabel(', '.join(y_columns), color=color)
ax.spines['bottom'].set_color('orange')
ax.spines['top'].set_color('orange')
ax.spines['right'].set_color('orange')
ax.spines['left'].set_color('orange')
ax.xaxis.label.set_color('orange')
ax.yaxis.label.set_color('orange')
ax.tick_params(axis='x', colors='orange')
ax.tick_params(axis='y', colors='orange')
ax.title.set_color('orange')
ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
# Remove background
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
return fig
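# Usage sketch (hypothetical column names; x must be sortable, y columns numeric):
#   st.pyplot(plot_line(df, 'date', ['count'], rolling_mean_value=4))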
def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None, rotation=45):
    fig, ax = plt.subplots(figsize=figsize)
    sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax)
    ax.set_title(title if title else f'{y_column} by {x_column}', color=color, fontweight='bold')
    ax.set_xlabel(x_column, color=color)
    ax.set_ylabel(y_column, color=color)
    plt.xticks(rotation=rotation)
    # Style spines, ticks, labels, and title in orange
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    ax.title.set_color('orange')
    # Remove background; no legend, since a single-series bar plot has no labeled artists
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig
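# Usage sketch (hypothetical column names):
#   st.pyplot(plot_bar(final_counts, 'month', 'count'))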
def plot_grouped_bar(df, x_columns, y_column, figsize=(12, 10), colors=None, title=None):
    fig, ax = plt.subplots(figsize=figsize)
    width = 0.8 / len(x_columns)  # the width of each bar in a group
    x = np.arange(len(df), dtype=float)  # the label locations (float so offsets work)
    # Draw one offset bar series per requested grouping column
    for i, x_column in enumerate(x_columns):
        ax.bar(x + i * width, df[y_column], width=width, color=colors[i] if colors else None, label=x_column)
    ax.set_title(title if title else f'{y_column} by {", ".join(x_columns)}', color='orange', fontweight='bold')
    ax.set_xlabel('Groups', color='orange')
    ax.set_ylabel(y_column, color='orange')
    # Center the tick under each group of bars
    ax.set_xticks(x + width * (len(x_columns) - 1) / 2)
    ax.set_xticklabels(df.index)
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    # Remove background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.title.set_color('orange')
    ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
    return fig
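# Usage sketch (hypothetical frame; each entry in x_columns becomes one offset bar
# series drawn from the same y_column):
#   st.pyplot(plot_grouped_bar(df_counts, ['count'], 'count'))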
def filter_dataframe(df: pd.DataFrame, key: str = "filter") -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns
    Args:
        df (pd.DataFrame): Original dataframe
        key (str): Prefix used to keep widget keys unique across multiple calls
    Returns:
        pd.DataFrame: Filtered dataframe
    """
title_font = "Arial"
body_font = "Arial"
title_size = 32
colors = ["red", "green", "blue"]
interpretation = False
extract_docx = False
title = "My Chart"
regex = ".*"
img_path = 'default_image.png'
#try:
# modify = st.checkbox("Add filters on raw data")
#except:
# try:
# modify = st.checkbox("Add filters on processed data")
# except:
# try:
# modify = st.checkbox("Add filters on parsed data")
# except:
# pass
#if not modify:
# return df
df_ = df.copy()
# Try to convert datetimes into a standard format (datetime, no timezone)
    # A stable key prefix lets this filter UI be instantiated more than once per page
    to_filter_columns = st.multiselect("Filter dataframe on", df_.columns, key=f"{key}_columns")
date_column = None
filtered_columns = []
for column in to_filter_columns:
left, right = st.columns((1, 20))
        # Treat columns with < 120 unique values as categorical if not date or numeric
        if is_categorical_dtype(df_[column]) or (df_[column].nunique() < 120 and not is_datetime64_any_dtype(df_[column]) and not is_numeric_dtype(df_[column])):
            user_cat_input = right.multiselect(
                f"Values for {column}",
                df_[column].value_counts().index.tolist(),
                default=list(df_[column].value_counts().index),
                key=f"{key}_{column}_cat",
            )
df_ = df_[df_[column].isin(user_cat_input)]
filtered_columns.append(column)
with st.status(f"Category Distribution: {column}", expanded=False) as stat:
st.pyplot(plot_treemap(df_, column))
elif is_numeric_dtype(df_[column]):
_min = float(df_[column].min())
_max = float(df_[column].max())
            # Avoid a zero step when the column holds a single value
            step = (_max - _min) / 100 or 1.0
            user_num_input = right.slider(
                f"Values for {column}",
                min_value=_min,
                max_value=_max,
                value=(_min, _max),
                step=step,
                key=f"{key}_{column}_num",
            )
df_ = df_[df_[column].between(*user_num_input)]
filtered_columns.append(column)
            with st.status(f"Numerical Distribution: {column}", expanded=False) as stat_:
                # Roughly half as many bins as unique values, with a floor of 1
                st.pyplot(plot_hist(df_, column, bins=max(int(round((df_[column].nunique() - 1) / 2)), 1)))
        # Object columns may hold date strings; try to parse before deciding how to filter
        elif is_object_dtype(df_[column]) or is_datetime64_any_dtype(df_[column]):
            try:
                df_[column] = pd.to_datetime(df_[column], errors='coerce')
            except Exception:
                try:
                    df_[column] = df_[column].apply(parser.parse)
                except Exception:
                    pass
if is_datetime64_any_dtype(df_[column]):
df_[column] = df_[column].dt.tz_localize(None)
min_date = df_[column].min().date()
max_date = df_[column].max().date()
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(min_date, max_date),
                    min_value=min_date,
                    max_value=max_date,
                    key=f"{key}_{column}_date",
                )
# if len(user_date_input) == 2:
# start_date, end_date = user_date_input
# df_ = df_.loc[df_[column].dt.date.between(start_date, end_date)]
if len(user_date_input) == 2:
user_date_input = tuple(map(pd.to_datetime, user_date_input))
start_date, end_date = user_date_input
# Determine the most appropriate time unit for plot
time_units = {
'year': df_[column].dt.year,
'month': df_[column].dt.to_period('M'),
'day': df_[column].dt.date
}
unique_counts = {unit: col.nunique() for unit, col in time_units.items()}
closest_to_36 = min(unique_counts, key=lambda k: abs(unique_counts[k] - 36))
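                    # e.g. a range spanning 3 years, 38 months, and ~1100 days picks
                    # 'month', since |38 - 36| is the smallest gap to the ~36-bucket target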
# Group by the most appropriate time unit and count occurrences
grouped = df_.groupby(time_units[closest_to_36]).size().reset_index(name='count')
grouped.columns = [column, 'count']
# Create a complete date range
if closest_to_36 == 'year':
date_range = pd.date_range(start=f"{start_date.year}-01-01", end=f"{end_date.year}-12-31", freq='YS')
elif closest_to_36 == 'month':
date_range = pd.date_range(start=start_date.replace(day=1), end=end_date + pd.offsets.MonthEnd(0), freq='MS')
else: # day
date_range = pd.date_range(start=start_date, end=end_date, freq='D')
# Create a DataFrame with the complete date range
complete_range = pd.DataFrame({column: date_range})
# Convert the date column to the appropriate format based on closest_to_36
if closest_to_36 == 'year':
complete_range[column] = complete_range[column].dt.year
elif closest_to_36 == 'month':
complete_range[column] = complete_range[column].dt.to_period('M')
# Merge the complete range with the grouped data
final_data = pd.merge(complete_range, grouped, on=column, how='left').fillna(0)
with st.status(f"Date Distributions: {column}", expanded=False) as stat:
try:
st.pyplot(plot_bar(final_data, column, 'count'))
except Exception as e:
st.error(f"Error plotting bar chart: {e}")
df_ = df_.loc[df_[column].between(start_date, end_date)]
date_column = column
    # After the loop: if a date column was filtered alongside other columns,
    # plot the filtered metrics over time
    if date_column and filtered_columns:
        fig = fig2 = None
        numeric_columns = [col for col in filtered_columns if is_numeric_dtype(df_[col])]
        if numeric_columns:
            fig = plot_line(df_, date_column, numeric_columns)
        # now to deal with categorical columns
        categorical_columns = [col for col in filtered_columns if is_categorical_dtype(df_[col])]
        if categorical_columns:
            fig2 = plot_bar(df_, date_column, categorical_columns[0])
        with st.status(f"Date Distribution: {date_column}", expanded=False) as stat:
            if fig is not None:
                try:
                    st.pyplot(fig)
                except Exception as e:
                    st.error(f"Error plotting line chart: {e}")
            if fig2 is not None:
                try:
                    st.pyplot(fig2)
                except Exception as e:
                    st.error(f"Error plotting bar chart: {e}")
else:
            user_text_input = right.text_input(
                f"Substring or regex in {column}",
                key=f"{key}_{column}_text",
            )
if user_text_input:
df_ = df_[df_[column].astype(str).str.contains(user_text_input)]
# write len of df after filtering with % of original
st.write(f"{len(df_)} rows ({len(df_) / len(df) * 100:.2f}%)")
return df_
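# Usage sketch (the key prefix keeps Streamlit widget IDs unique across calls):
#   filtered = filter_dataframe(df, key="sidebar")
#   st.dataframe(filtered)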
def merge_clusters(analyzer, column):
    # Merge clusters whose TF-IDF terms are within a small edit distance of each other
    cluster_terms_ = analyzer.__dict__['cluster_terms']
    cluster_labels_ = analyzer.__dict__['cluster_labels']
    merge_map = {}
    # Iterate over term pairs and decide on merging based on the Levenshtein distance
    for idx, term1 in enumerate(cluster_terms_):
        for jdx, term2 in enumerate(cluster_terms_):
            if idx < jdx and distance(term1, term2) <= 3:  # Adjust threshold as needed
                # Merge the label corresponding to jdx into the label corresponding to idx
                merge_map[jdx] = idx
    # Update the analyzer with the merged numeric labels
    updated_cluster_labels_ = [merge_map.get(label, label) for label in cluster_labels_]
    analyzer.__dict__['cluster_labels'] = updated_cluster_labels_
    # Optional: Update string labels to reflect merged labels
    updated_string_labels = [cluster_terms_[label] for label in updated_cluster_labels_]
    analyzer.__dict__['string_labels'] = updated_string_labels
    return updated_string_labels
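# Usage sketch (assumes a fitted UAPAnalyzer exposing cluster_terms and cluster_labels):
#   string_labels = merge_clusters(analyzer, 'summary')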
def analyze_and_predict(data, analyzers, col_names, clusters):
    # For each analyzed column, attach its cluster labels, then train one XGBoost
    # model per column to predict that column's clusters from the others
    visualizer = UAPVisualizer()
    new_data = pd.DataFrame()
for i, column in enumerate(col_names):
#new_data[f'Analyzer_{column}'] = analyzers[column].__dict__['cluster_labels']
new_data[f'Analyzer_{column}'] = clusters[column]
data[f'Analyzer_{column}'] = clusters[column]
#data[f'Analyzer_{column}'] = analyzer.__dict__['cluster_labels']
print(f"Cluster terms extracted for {column}")
for col in data.columns:
if 'Analyzer' in col:
data[col] = data[col].astype('category')
new_data = new_data.fillna('null').astype('category')
data_nums = new_data.apply(lambda x: x.cat.codes)
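    # cat.codes maps each category to a stable integer, e.g. ['disk', 'orb', 'disk']
    # becomes [0, 1, 0]; XGBoost consumes these integer matrices below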
for col in data_nums.columns:
try:
categories = new_data[col].cat.categories
x_train, x_test, y_train, y_test = train_test_split(data_nums.drop(columns=[col]), data_nums[col], test_size=0.2, random_state=42)
bst, accuracy, preds = visualizer.train_xgboost(x_train, y_train, x_test, y_test, len(categories))
fig = visualizer.plot_results(new_data, bst, x_test, y_test, preds, categories, accuracy, col)
with st.status(f"Charts Analyses: {col}", expanded=True) as status:
st.pyplot(fig)
status.update(label=f"Chart Processed: {col}", expanded=False)
except Exception as e:
print(f"Error processing {col}: {e}")
continue
return new_data, data
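# Usage sketch (mirrors the call made below once clustering has filled session state):
#   new_data, data = analyze_and_predict(df, analyzers, col_names, clusters)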
from config import API_KEY, GEMINI_KEY, FORMAT_LONG
# Free any cached GPU memory left over from embedding models (no-op without CUDA)
if torch.cuda.is_available():
    with torch.no_grad():
        torch.cuda.empty_cache()
#st.set_page_config(
# page_title="UAP ANALYSIS",
# page_icon=":alien:",
# layout="wide",
# initial_sidebar_state="expanded",
#)
st.title('UAP Analysis Dashboard')
# Initialize session state
if 'analyzers' not in st.session_state:
st.session_state['analyzers'] = []
if 'col_names' not in st.session_state:
st.session_state['col_names'] = []
if 'clusters' not in st.session_state:
st.session_state['clusters'] = {}
if 'new_data' not in st.session_state:
st.session_state['new_data'] = pd.DataFrame()
if 'dataset' not in st.session_state:
st.session_state['dataset'] = pd.DataFrame()
if 'data_processed' not in st.session_state:
st.session_state['data_processed'] = False
if 'stage' not in st.session_state:
st.session_state['stage'] = 0
if 'filtered_data' not in st.session_state:
st.session_state['filtered_data'] = None
if 'gemini_answer' not in st.session_state:
st.session_state['gemini_answer'] = None
if 'parsed_responses' not in st.session_state:
st.session_state['parsed_responses'] = None
# Load dataset
data_path = 'parsed_files_distance_embeds.h5'
my_dataset = st.file_uploader("Upload Parsed DataFrame", type=["csv", "xlsx"])
if my_dataset is not None:
    # A user upload replaces the default dataset, which is never loaded in this branch
try:
if my_dataset.type == "text/csv":
data = pd.read_csv(my_dataset)
elif my_dataset.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
data = pd.read_excel(my_dataset)
else:
st.error("Unsupported file type. Please upload a CSV, Excel or HD5 file.")
st.stop()
        # Use a distinct name so we don't shadow dateutil's parser module
        parsed_responses_upload = filter_dataframe(data, key="uploaded")
        st.session_state['parsed_responses'] = parsed_responses_upload
        st.dataframe(parsed_responses_upload)
st.success(f"Successfully loaded and displayed data from {my_dataset.name}")
except Exception as e:
st.error(f"An error occurred while reading the file: {e}")
else:
    parsed = load_data(data_path).drop(columns=['embeddings'])
    parsed_responses = filter_dataframe(parsed, key="default")
st.session_state['parsed_responses'] = parsed_responses
st.dataframe(parsed_responses)
col1, col2 = st.columns(2)
with col1:
col_parsed = st.selectbox("Which column do you want to query?", st.session_state['parsed_responses'].columns)
with col2:
GEMINI_KEY = st.text_input('Gemini API Key', value=GEMINI_KEY, type='password', help="Enter your Gemini API key")
if col_parsed and GEMINI_KEY:
selected_column_data = st.session_state['parsed_responses'][col_parsed].tolist()
question = st.text_input("Ask a question or leave empty for summarization")
if st.button("Generate Query") and selected_column_data:
st.write(gemini_query(question, selected_column_data, GEMINI_KEY))
st.session_state['stage'] = 1
if st.session_state['stage'] > 0:
with st.form(border=True, key='Select Columns for Analysis'):
columns_to_analyze = st.multiselect(
label='Select columns to analyze',
options=st.session_state['parsed_responses'].columns
)
if st.form_submit_button("Process Data"):
if columns_to_analyze:
analyzers = []
col_names = []
clusters = {}
for column in columns_to_analyze:
with torch.no_grad():
with st.status(f"Processing {column}", expanded=True) as status:
analyzer = UAPAnalyzer(st.session_state['parsed_responses'], column)
st.write(f"Processing {column}...")
analyzer.preprocess_data(top_n=32)
st.write("Reducing dimensionality...")
analyzer.reduce_dimensionality(method='UMAP', n_components=2, n_neighbors=15, min_dist=0.1)
st.write("Clustering data...")
analyzer.cluster_data(method='HDBSCAN', min_cluster_size=15)
analyzer.get_tf_idf_clusters(top_n=3)
st.write("Naming clusters...")
analyzers.append(analyzer)
col_names.append(column)
clusters[column] = analyzer.merge_similar_clusters(cluster_terms=analyzer.__dict__['cluster_terms'], cluster_labels=analyzer.__dict__['cluster_labels'])
# Run the visualization
# fig = datamapplot.create_plot(
# analyzer.__dict__['reduced_embeddings'],
# analyzer.__dict__['cluster_labels'].astype(str),
# #label_font_size=11,
# label_wrap_width=20,
# use_medoids=True,
# )#.to_html(full_html=False, include_plotlyjs='cdn')
# st.pyplot(fig.savefig())
status.update(label=f"Processing {column} complete", expanded=False)
st.session_state['analyzers'] = analyzers
st.session_state['col_names'] = col_names
st.session_state['clusters'] = clusters
# save space
parsed = None
analyzers = None
col_names = None
clusters = None
    if st.session_state['clusters']:
try:
new_data, parsed_responses = analyze_and_predict(st.session_state['parsed_responses'], st.session_state['analyzers'], st.session_state['col_names'], st.session_state['clusters'])
st.session_state['dataset'] = parsed_responses
st.session_state['new_data'] = new_data
st.session_state['data_processed'] = True
except Exception as e:
st.write(f"Error processing data: {e}")
if st.session_state['data_processed']:
try:
visualizer = UAPVisualizer(data=st.session_state['new_data'])
#new_data = pd.DataFrame() # Assuming new_data is prepared earlier in the code
fig2 = visualizer.plot_cramers_v_heatmap(data=st.session_state['new_data'], significance_level=0.05)
with st.status(f"Cramer's V Chart", expanded=True) as statuss:
st.pyplot(fig2)
statuss.update(label="Cramer's V chart plotted", expanded=False)
except Exception as e:
st.write(f"Error plotting Cramers V: {e}")
for i, column in enumerate(st.session_state['col_names']):
#if stateful_button(f"Show {column} clusters {i}", key=f"show_{column}_clusters"):
# if st.session_state['data_processed']:
# with st.status(f"Show clusters {column}", expanded=True) as stats:
# fig3 = st.session_state['analyzers'][i].plot_embeddings4(title=f"{column} clusters", cluster_terms=st.session_state['analyzers'][i].__dict__['cluster_terms'], cluster_labels=st.session_state['analyzers'][i].__dict__['cluster_labels'], reduced_embeddings=st.session_state['analyzers'][i].__dict__['reduced_embeddings'], column=f'Analyzer_{column}', data=st.session_state['new_data'])
# stats.update(label=f"Show clusters {column} complete", expanded=False)
if st.session_state['data_processed']:
with st.status(f"Show clusters {column}", expanded=True) as stats:
fig3 = st.session_state['analyzers'][i].plot_embeddings4(
title=f"{column} clusters",
cluster_terms=st.session_state['analyzers'][i].__dict__['cluster_terms'],
cluster_labels=st.session_state['analyzers'][i].__dict__['cluster_labels'],
reduced_embeddings=st.session_state['analyzers'][i].__dict__['reduced_embeddings'],
column=column, # Use the original column name here
data=st.session_state['parsed_responses'] # Use the original dataset here
)
stats.update(label=f"Show clusters {column} complete", expanded=False)
st.session_state['analysis_complete'] = True
if 'analysis_complete' in st.session_state and st.session_state['analysis_complete']:
ticked_analysis = st.checkbox('Query Processed Data')
if ticked_analysis:
if st.session_state['new_data'] is not None:
parsed2 = st.session_state.get('dataset', pd.DataFrame()).copy()
            parsed2 = filter_dataframe(parsed2, key="processed")
col1, col2 = st.columns(2)
st.dataframe(parsed2)
with col1:
                col_parsed2 = st.selectbox("Which column do you want to query?", parsed2.columns, key="query_col_2")
with col2:
                GEMINI_KEY = st.text_input('Gemini API Key', value=GEMINI_KEY, type='password', help="Enter your Gemini API key", key="gemini_key_2")
if col_parsed2 and GEMINI_KEY:
selected_column_data2 = parsed2[col_parsed2].tolist()
                question2 = st.text_input("Ask a question or leave empty for summarization", key="question_2")
                if st.button("Generate Query", key="generate_query_2") and selected_column_data2:
                    with st.status("Generating Query", expanded=True) as status:
gemini_answer = gemini_query(question2, selected_column_data2, GEMINI_KEY)
st.write(gemini_answer)
st.session_state['gemini_answer'] = gemini_answer