import streamlit as st #import cudf.pandas #cudf.pandas.install() import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns from uap_analyzer import UAPParser, UAPAnalyzer, UAPVisualizer # import ChartGen # from ChartGen import ChartGPT from Levenshtein import distance from sklearn.model_selection import train_test_split from sklearn.metrics import confusion_matrix from stqdm import stqdm stqdm.pandas() import streamlit.components.v1 as components from dateutil import parser from sentence_transformers import SentenceTransformer import torch import squarify import matplotlib.colors as mcolors import textwrap import datamapplot st.set_option('deprecation.showPyplotGlobalUse', False) from pandas.api.types import ( is_categorical_dtype, is_datetime64_any_dtype, is_numeric_dtype, is_object_dtype, ) def load_data(file_path, key='df'): return pd.read_hdf(file_path, key=key) def gemini_query(question, selected_data, gemini_key): if question == "": question = "Summarize the following data in relevant bullet points" import pathlib import textwrap import google.generativeai as genai from IPython.display import display from IPython.display import Markdown def to_markdown(text): text = text.replace('•', ' *') return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True)) # selected_data is a list # remove empty filtered = [str(x) for x in selected_data if str(x) != '' and x is not None] # make a string context = '\n'.join(filtered) genai.configure(api_key=gemini_key) query_model = genai.GenerativeModel('models/gemini-1.5-pro-latest') response = query_model.generate_content([f"{question}\n Answer based on this context: {context}\n\n"]) return(response.text) def plot_treemap(df, column, top_n=32): # Get the value counts and the top N labels value_counts = df[column].value_counts() top_labels = value_counts.iloc[:top_n].index # Use np.where to replace all values not in the top N with 'Other' revised_column = f'{column}_revised' df[revised_column] = np.where(df[column].isin(top_labels), df[column], 'Other') # Get the value counts including the 'Other' category sizes = df[revised_column].value_counts().values labels = df[revised_column].value_counts().index # Get a gradient of colors # colors = list(mcolors.TABLEAU_COLORS.values()) n_colors = len(sizes) colors = plt.cm.Oranges(np.linspace(0.3, 0.9, n_colors))[::-1] # Get % of each category percents = sizes / sizes.sum() # Prepare labels with percentages labels = [f'{label}\n {percent:.1%}' for label, percent in zip(labels, percents)] fig, ax = plt.subplots(figsize=(20, 12)) # Plot the treemap squarify.plot(sizes=sizes, label=labels, alpha=0.7, pad=True, color=colors, text_kwargs={'fontsize': 10}) ax = plt.gca() # Iterate over text elements and rectangles (patches) in the axes for color adjustment for text, rect in zip(ax.texts, ax.patches): background_color = rect.get_facecolor() r, g, b, _ = mcolors.to_rgba(background_color) brightness = np.average([r, g, b]) text.set_color('white' if brightness < 0.5 else 'black') # Adjust font size based on rectangle's area and wrap long text coef = 0.8 font_size = np.sqrt(rect.get_width() * rect.get_height()) * coef text.set_fontsize(font_size) wrapped_text = textwrap.fill(text.get_text(), width=20) text.set_text(wrapped_text) plt.axis('off') plt.gca().invert_yaxis() plt.gcf().set_size_inches(20, 12) fig.patch.set_alpha(0) ax.patch.set_alpha(0) return fig def plot_hist(df, column, bins=10, kde=True): fig, ax = plt.subplots(figsize=(12, 6)) sns.histplot(data=df, x=column, kde=True, bins=bins,color='orange') # set the ticks and frame in orange ax.spines['bottom'].set_color('orange') ax.spines['top'].set_color('orange') ax.spines['right'].set_color('orange') ax.spines['left'].set_color('orange') ax.xaxis.label.set_color('orange') ax.yaxis.label.set_color('orange') ax.tick_params(axis='x', colors='orange') ax.tick_params(axis='y', colors='orange') ax.title.set_color('orange') # Set transparent background fig.patch.set_alpha(0) ax.patch.set_alpha(0) return fig def plot_line(df, x_column, y_columns, figsize=(12, 10), color='orange', title=None, rolling_mean_value=2): import matplotlib.cm as cm # Sort the dataframe by the date column df = df.sort_values(by=x_column) # Calculate rolling mean for each y_column if rolling_mean_value: df[y_columns] = df[y_columns].rolling(len(df) // rolling_mean_value).mean() # Create the plot fig, ax = plt.subplots(figsize=figsize) colors = cm.Oranges(np.linspace(0.2, 1, len(y_columns))) # Plot each y_column as a separate line with a different color for i, y_column in enumerate(y_columns): df.plot(x=x_column, y=y_column, ax=ax, color=colors[i], label=y_column, linewidth=.5) # Rotate x-axis labels ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha='right') # Format x_column as date if it is if np.issubdtype(df[x_column].dtype, np.datetime64) or np.issubdtype(df[x_column].dtype, np.timedelta64): df[x_column] = pd.to_datetime(df[x_column]).dt.date # Set title, labels, and legend ax.set_title(title or f'{", ".join(y_columns)} over {x_column}', color=color, fontweight='bold') ax.set_xlabel(x_column, color=color) ax.set_ylabel(', '.join(y_columns), color=color) ax.spines['bottom'].set_color('orange') ax.spines['top'].set_color('orange') ax.spines['right'].set_color('orange') ax.spines['left'].set_color('orange') ax.xaxis.label.set_color('orange') ax.yaxis.label.set_color('orange') ax.tick_params(axis='x', colors='orange') ax.tick_params(axis='y', colors='orange') ax.title.set_color('orange') ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange') # Remove background fig.patch.set_alpha(0) ax.patch.set_alpha(0) return fig def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None, rotation=45): fig, ax = plt.subplots(figsize=figsize) sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax) ax.set_title(title if title else f'{y_column} by {x_column}', color=color, fontweight='bold') ax.set_xlabel(x_column, color=color) ax.set_ylabel(y_column, color=color) ax.tick_params(axis='x', colors=color) ax.tick_params(axis='y', colors=color) plt.xticks(rotation=rotation) # Remove background fig.patch.set_alpha(0) ax.patch.set_alpha(0) ax.spines['bottom'].set_color('orange') ax.spines['top'].set_color('orange') ax.spines['right'].set_color('orange') ax.spines['left'].set_color('orange') ax.xaxis.label.set_color('orange') ax.yaxis.label.set_color('orange') ax.tick_params(axis='x', colors='orange') ax.tick_params(axis='y', colors='orange') ax.title.set_color('orange') ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange') return fig def plot_grouped_bar(df, x_columns, y_column, figsize=(12, 10), colors=None, title=None): fig, ax = plt.subplots(figsize=figsize) width = 0.8 / len(x_columns) # the width of the bars x = np.arange(len(df)) # the label locations for i, x_column in enumerate(x_columns): sns.barplot(data=df, x=x, y=y_column, color=colors[i] if colors else None, ax=ax, width=width, label=x_column) x += width # add the width of the bar to the x position for the next bar ax.set_title(title if title else f'{y_column} by {", ".join(x_columns)}', color='orange', fontweight='bold') ax.set_xlabel('Groups', color='orange') ax.set_ylabel(y_column, color='orange') ax.set_xticks(x - width * len(x_columns) / 2) ax.set_xticklabels(df.index) ax.tick_params(axis='x', colors='orange') ax.tick_params(axis='y', colors='orange') # Remove background fig.patch.set_alpha(0) ax.patch.set_alpha(0) ax.spines['bottom'].set_color('orange') ax.spines['top'].set_color('orange') ax.spines['right'].set_color('orange') ax.spines['left'].set_color('orange') ax.xaxis.label.set_color('orange') ax.yaxis.label.set_color('orange') ax.title.set_color('orange') ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange') return fig def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame: """ Adds a UI on top of a dataframe to let viewers filter columns Args: df (pd.DataFrame): Original dataframe Returns: pd.DataFrame: Filtered dataframe """ title_font = "Arial" body_font = "Arial" title_size = 32 colors = ["red", "green", "blue"] interpretation = False extract_docx = False title = "My Chart" regex = ".*" img_path = 'default_image.png' #try: # modify = st.checkbox("Add filters on raw data") #except: # try: # modify = st.checkbox("Add filters on processed data") # except: # try: # modify = st.checkbox("Add filters on parsed data") # except: # pass #if not modify: # return df df_ = df.copy() # Try to convert datetimes into a standard format (datetime, no timezone) #modification_container = st.container() #with modification_container: try: to_filter_columns = st.multiselect("Filter dataframe on", df_.columns) except: try: to_filter_columns = st.multiselect("Filter dataframe", df_.columns) except: try: to_filter_columns = st.multiselect("Filter the dataframe on", df_.columns) except: pass date_column = None filtered_columns = [] for column in to_filter_columns: left, right = st.columns((1, 20)) # Treat columns with < 200 unique values as categorical if not date or numeric if is_categorical_dtype(df_[column]) or (df_[column].nunique() < 120 and not is_datetime64_any_dtype(df_[column]) and not is_numeric_dtype(df_[column])): user_cat_input = right.multiselect( f"Values for {column}", df_[column].value_counts().index.tolist(), default=list(df_[column].value_counts().index) ) df_ = df_[df_[column].isin(user_cat_input)] filtered_columns.append(column) with st.status(f"Category Distribution: {column}", expanded=False) as stat: st.pyplot(plot_treemap(df_, column)) elif is_numeric_dtype(df_[column]): _min = float(df_[column].min()) _max = float(df_[column].max()) step = (_max - _min) / 100 user_num_input = right.slider( f"Values for {column}", min_value=_min, max_value=_max, value=(_min, _max), step=step, ) df_ = df_[df_[column].between(*user_num_input)] filtered_columns.append(column) # Chart_GPT = ChartGPT(df_, title_font, body_font, title_size, # colors, interpretation, extract_docx, img_path) with st.status(f"Numerical Distribution: {column}", expanded=False) as stat_: st.pyplot(plot_hist(df_, column, bins=int(round(len(df_[column].unique())-1)/2))) elif is_object_dtype(df_[column]): try: df_[column] = pd.to_datetime(df_[column], infer_datetime_format=True, errors='coerce') except Exception: try: df_[column] = df_[column].apply(parser.parse) except Exception: pass if is_datetime64_any_dtype(df_[column]): df_[column] = df_[column].dt.tz_localize(None) min_date = df_[column].min().date() max_date = df_[column].max().date() user_date_input = right.date_input( f"Values for {column}", value=(min_date, max_date), min_value=min_date, max_value=max_date, ) # if len(user_date_input) == 2: # start_date, end_date = user_date_input # df_ = df_.loc[df_[column].dt.date.between(start_date, end_date)] if len(user_date_input) == 2: user_date_input = tuple(map(pd.to_datetime, user_date_input)) start_date, end_date = user_date_input # Determine the most appropriate time unit for plot time_units = { 'year': df_[column].dt.year, 'month': df_[column].dt.to_period('M'), 'day': df_[column].dt.date } unique_counts = {unit: col.nunique() for unit, col in time_units.items()} closest_to_36 = min(unique_counts, key=lambda k: abs(unique_counts[k] - 36)) # Group by the most appropriate time unit and count occurrences grouped = df_.groupby(time_units[closest_to_36]).size().reset_index(name='count') grouped.columns = [column, 'count'] # Create a complete date range if closest_to_36 == 'year': date_range = pd.date_range(start=f"{start_date.year}-01-01", end=f"{end_date.year}-12-31", freq='YS') elif closest_to_36 == 'month': date_range = pd.date_range(start=start_date.replace(day=1), end=end_date + pd.offsets.MonthEnd(0), freq='MS') else: # day date_range = pd.date_range(start=start_date, end=end_date, freq='D') # Create a DataFrame with the complete date range complete_range = pd.DataFrame({column: date_range}) # Convert the date column to the appropriate format based on closest_to_36 if closest_to_36 == 'year': complete_range[column] = complete_range[column].dt.year elif closest_to_36 == 'month': complete_range[column] = complete_range[column].dt.to_period('M') # Merge the complete range with the grouped data final_data = pd.merge(complete_range, grouped, on=column, how='left').fillna(0) with st.status(f"Date Distributions: {column}", expanded=False) as stat: try: st.pyplot(plot_bar(final_data, column, 'count')) except Exception as e: st.error(f"Error plotting bar chart: {e}") df_ = df_.loc[df_[column].between(start_date, end_date)] date_column = column if date_column and filtered_columns: numeric_columns = [col for col in filtered_columns if is_numeric_dtype(df_[col])] if numeric_columns: fig = plot_line(df_, date_column, numeric_columns) #st.pyplot(fig) # now to deal with categorical columns categorical_columns = [col for col in filtered_columns if is_categorical_dtype(df_[col])] if categorical_columns: fig2 = plot_bar(df_, date_column, categorical_columns[0]) #st.pyplot(fig2) with st.status(f"Date Distribution: {column}", expanded=False) as stat: try: st.pyplot(fig) except Exception as e: st.error(f"Error plotting line chart: {e}") pass try: st.pyplot(fig2) except Exception as e: st.error(f"Error plotting bar chart: {e}") else: user_text_input = right.text_input( f"Substring or regex in {column}", ) if user_text_input: df_ = df_[df_[column].astype(str).str.contains(user_text_input)] # write len of df after filtering with % of original st.write(f"{len(df_)} rows ({len(df_) / len(df) * 100:.2f}%)") return df_ def merge_clusters(df, column): cluster_terms_ = df.__dict__['cluster_terms'] cluster_labels_ = df.__dict__['cluster_labels'] label_name_map = {label: cluster_terms_[label] for label in set(cluster_labels_)} merge_map = {} # Iterate over term pairs and decide on merging based on the distance for idx, term1 in enumerate(cluster_terms_): for jdx, term2 in enumerate(cluster_terms_): if idx < jdx and distance(term1, term2) <= 3: # Adjust threshold as needed # Decide to merge labels corresponding to jdx into labels corresponding to idx # Find labels corresponding to jdx and idx labels_to_merge = [label for label, term_index in enumerate(cluster_labels_) if term_index == jdx] for label in labels_to_merge: merge_map[label] = idx # Map the label to use the term index of term1 # Update the analyzer with the merged numeric labels updated_cluster_labels_ = [merge_map[label] if label in merge_map else label for label in cluster_labels_] df.__dict__['cluster_labels'] = updated_cluster_labels_ # Optional: Update string labels to reflect merged labels updated_string_labels = [cluster_terms_[label] for label in updated_cluster_labels_] df.__dict__['string_labels'] = updated_string_labels return updated_string_labels def analyze_and_predict(data, analyzers, col_names, clusters): visualizer = UAPVisualizer() new_data = pd.DataFrame() for i, column in enumerate(col_names): #new_data[f'Analyzer_{column}'] = analyzers[column].__dict__['cluster_labels'] new_data[f'Analyzer_{column}'] = clusters[column] data[f'Analyzer_{column}'] = clusters[column] #data[f'Analyzer_{column}'] = analyzer.__dict__['cluster_labels'] print(f"Cluster terms extracted for {column}") for col in data.columns: if 'Analyzer' in col: data[col] = data[col].astype('category') new_data = new_data.fillna('null').astype('category') data_nums = new_data.apply(lambda x: x.cat.codes) for col in data_nums.columns: try: categories = new_data[col].cat.categories x_train, x_test, y_train, y_test = train_test_split(data_nums.drop(columns=[col]), data_nums[col], test_size=0.2, random_state=42) bst, accuracy, preds = visualizer.train_xgboost(x_train, y_train, x_test, y_test, len(categories)) fig = visualizer.plot_results(new_data, bst, x_test, y_test, preds, categories, accuracy, col) with st.status(f"Charts Analyses: {col}", expanded=True) as status: st.pyplot(fig) status.update(label=f"Chart Processed: {col}", expanded=False) except Exception as e: print(f"Error processing {col}: {e}") continue return new_data, data from config import API_KEY, GEMINI_KEY, FORMAT_LONG with torch.no_grad(): torch.cuda.empty_cache() #st.set_page_config( # page_title="UAP ANALYSIS", # page_icon=":alien:", # layout="wide", # initial_sidebar_state="expanded", #) st.title('UAP Analysis Dashboard') # Initialize session state if 'analyzers' not in st.session_state: st.session_state['analyzers'] = [] if 'col_names' not in st.session_state: st.session_state['col_names'] = [] if 'clusters' not in st.session_state: st.session_state['clusters'] = {} if 'new_data' not in st.session_state: st.session_state['new_data'] = pd.DataFrame() if 'dataset' not in st.session_state: st.session_state['dataset'] = pd.DataFrame() if 'data_processed' not in st.session_state: st.session_state['data_processed'] = False if 'stage' not in st.session_state: st.session_state['stage'] = 0 if 'filtered_data' not in st.session_state: st.session_state['filtered_data'] = None if 'gemini_answer' not in st.session_state: st.session_state['gemini_answer'] = None if 'parsed_responses' not in st.session_state: st.session_state['parsed_responses'] = None # Load dataset data_path = 'parsed_files_distance_embeds.h5' my_dataset = st.file_uploader("Upload Parsed DataFrame", type=["csv", "xlsx"]) if my_dataset is not None: if parsed: # save space by cleaning default dataset parsed = None try: if my_dataset.type == "text/csv": data = pd.read_csv(my_dataset) elif my_dataset.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": data = pd.read_excel(my_dataset) else: st.error("Unsupported file type. Please upload a CSV, Excel or HD5 file.") st.stop() parser = filter_dataframe(data) st.session_state['parsed_responses'] = parser st.dataframe(parser) st.success(f"Successfully loaded and displayed data from {my_dataset.name}") except Exception as e: st.error(f"An error occurred while reading the file: {e}") else: parsed = load_data(data_path).drop(columns=['embeddings']) parsed_responses = filter_dataframe(parsed) st.session_state['parsed_responses'] = parsed_responses st.dataframe(parsed_responses) col1, col2 = st.columns(2) with col1: col_parsed = st.selectbox("Which column do you want to query?", st.session_state['parsed_responses'].columns) with col2: GEMINI_KEY = st.text_input('Gemini API Key', value=GEMINI_KEY, type='password', help="Enter your Gemini API key") if col_parsed and GEMINI_KEY: selected_column_data = st.session_state['parsed_responses'][col_parsed].tolist() question = st.text_input("Ask a question or leave empty for summarization") if st.button("Generate Query") and selected_column_data: st.write(gemini_query(question, selected_column_data, GEMINI_KEY)) st.session_state['stage'] = 1 if st.session_state['stage'] > 0 : with st.form(border=True, key='Select Columns for Analysis'): columns_to_analyze = st.multiselect( label='Select columns to analyze', options=st.session_state['parsed_responses'].columns ) if st.form_submit_button("Process Data"): if columns_to_analyze: analyzers = [] col_names = [] clusters = {} for column in columns_to_analyze: with torch.no_grad(): with st.status(f"Processing {column}", expanded=True) as status: analyzer = UAPAnalyzer(st.session_state['parsed_responses'], column) st.write(f"Processing {column}...") analyzer.preprocess_data(top_n=32) st.write("Reducing dimensionality...") analyzer.reduce_dimensionality(method='UMAP', n_components=2, n_neighbors=15, min_dist=0.1) st.write("Clustering data...") analyzer.cluster_data(method='HDBSCAN', min_cluster_size=15) analyzer.get_tf_idf_clusters(top_n=3) st.write("Naming clusters...") analyzers.append(analyzer) col_names.append(column) clusters[column] = analyzer.merge_similar_clusters(cluster_terms=analyzer.__dict__['cluster_terms'], cluster_labels=analyzer.__dict__['cluster_labels']) # Run the visualization # fig = datamapplot.create_plot( # analyzer.__dict__['reduced_embeddings'], # analyzer.__dict__['cluster_labels'].astype(str), # #label_font_size=11, # label_wrap_width=20, # use_medoids=True, # )#.to_html(full_html=False, include_plotlyjs='cdn') # st.pyplot(fig.savefig()) status.update(label=f"Processing {column} complete", expanded=False) st.session_state['analyzers'] = analyzers st.session_state['col_names'] = col_names st.session_state['clusters'] = clusters # save space parsed = None analyzers = None col_names = None clusters = None if st.session_state['clusters'] is not None: try: new_data, parsed_responses = analyze_and_predict(st.session_state['parsed_responses'], st.session_state['analyzers'], st.session_state['col_names'], st.session_state['clusters']) st.session_state['dataset'] = parsed_responses st.session_state['new_data'] = new_data st.session_state['data_processed'] = True except Exception as e: st.write(f"Error processing data: {e}") if st.session_state['data_processed']: try: visualizer = UAPVisualizer(data=st.session_state['new_data']) #new_data = pd.DataFrame() # Assuming new_data is prepared earlier in the code fig2 = visualizer.plot_cramers_v_heatmap(data=st.session_state['new_data'], significance_level=0.05) with st.status(f"Cramer's V Chart", expanded=True) as statuss: st.pyplot(fig2) statuss.update(label="Cramer's V chart plotted", expanded=False) except Exception as e: st.write(f"Error plotting Cramers V: {e}") for i, column in enumerate(st.session_state['col_names']): #if stateful_button(f"Show {column} clusters {i}", key=f"show_{column}_clusters"): # if st.session_state['data_processed']: # with st.status(f"Show clusters {column}", expanded=True) as stats: # fig3 = st.session_state['analyzers'][i].plot_embeddings4(title=f"{column} clusters", cluster_terms=st.session_state['analyzers'][i].__dict__['cluster_terms'], cluster_labels=st.session_state['analyzers'][i].__dict__['cluster_labels'], reduced_embeddings=st.session_state['analyzers'][i].__dict__['reduced_embeddings'], column=f'Analyzer_{column}', data=st.session_state['new_data']) # stats.update(label=f"Show clusters {column} complete", expanded=False) if st.session_state['data_processed']: with st.status(f"Show clusters {column}", expanded=True) as stats: fig3 = st.session_state['analyzers'][i].plot_embeddings4( title=f"{column} clusters", cluster_terms=st.session_state['analyzers'][i].__dict__['cluster_terms'], cluster_labels=st.session_state['analyzers'][i].__dict__['cluster_labels'], reduced_embeddings=st.session_state['analyzers'][i].__dict__['reduced_embeddings'], column=column, # Use the original column name here data=st.session_state['parsed_responses'] # Use the original dataset here ) stats.update(label=f"Show clusters {column} complete", expanded=False) st.session_state['analysis_complete'] = True # this will check if the dataframe is not empty # if st.session_state['new_data'] is not None: # parsed2 = st.session_state.get('dataset', pd.DataFrame()) # parsed2 = filter_dataframe(parsed2) # col1, col2 = st.columns(2) # st.dataframe(parsed2) # with col1: # col_parsed2 = st.selectbox("Which columns do you want to query?", parsed2.columns) # with col2: # GEMINI_KEY = st.text_input('Gemini APIs Key', GEMINI_KEY, type='password', help="Enter your Gemini API key") # if col_parsed and GEMINI_KEY: # selected_column_data2 = parsed2[col_parsed2].tolist() # question2 = st.text_input("Ask a questions or leave empty for summarization") # if st.button("Generate Query") and selected_column_data2: # with st.status(f"Generating Query", expanded=True) as status: # gemini_answer = gemini_query(question2, selected_column_data2, GEMINI_KEY) # st.write(gemini_answer) # st.session_state['gemini_answer'] = gemini_answer if 'analysis_complete' in st.session_state and st.session_state['analysis_complete']: ticked_analysis = st.checkbox('Query Processed Data') if ticked_analysis: if st.session_state['new_data'] is not None: parsed2 = st.session_state.get('dataset', pd.DataFrame()).copy() parsed2 = filter_dataframe(parsed2) col1, col2 = st.columns(2) st.dataframe(parsed2) with col1: col_parsed2 = st.selectbox("Which columns do you want to query?", parsed2.columns) with col2: GEMINI_KEY = st.text_input('Gemini APIs Key', value=GEMINI_KEY, type='password', help="Enter your Gemini API key") if col_parsed2 and GEMINI_KEY: selected_column_data2 = parsed2[col_parsed2].tolist() question2 = st.text_input("Ask a questions or leave empty for summarization") if st.button("Generate Queries") and selected_column_data2: with st.status(f"Generating Query", expanded=True) as status: gemini_answer = gemini_query(question2, selected_column_data2, GEMINI_KEY) st.write(gemini_answer) st.session_state['gemini_answer'] = gemini_answer