import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from uap_analyzer import UAPParser, UAPAnalyzer, UAPVisualizer
# import ChartGen
# from ChartGen import ChartGPT
from Levenshtein import distance
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from stqdm import stqdm

stqdm.pandas()

import streamlit.components.v1 as components
from dateutil import parser
from sentence_transformers import SentenceTransformer
import torch
import squarify
import matplotlib.colors as mcolors
import textwrap
import datamapplot
import openai
from openai import OpenAI
import os
import json
import plotly.graph_objects as go

st.set_option('deprecation.showPyplotGlobalUse', False)

from pandas.api.types import (
    is_categorical_dtype,
    is_datetime64_any_dtype,
    is_numeric_dtype,
    is_object_dtype,
)


def load_data(file_path, key='df'):
    return pd.read_hdf(file_path, key=key)


def gemini_query(question, selected_data, gemini_key):
    """Ask Gemini a question about the selected rows; defaults to a summary."""
    if question == "":
        question = "Summarize the following data in relevant bullet points"
    import google.generativeai as genai
    from IPython.display import Markdown

    def to_markdown(text):
        text = text.replace('•', '  *')
        return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

    # selected_data is a list; drop empty or missing entries
    filtered = [str(x) for x in selected_data if x is not None and str(x) != '']
    # Join the remaining rows into a single context string
    context = '\n'.join(filtered)

    genai.configure(api_key=gemini_key)
    query_model = genai.GenerativeModel('models/gemini-1.5-pro-latest')
    response = query_model.generate_content(
        [f"{question}\n Answer based on this context: {context}\n\n"]
    )
    return response.text


def plot_treemap(df, column, top_n=32):
    # Get the value counts and the top N labels
    value_counts = df[column].value_counts()
    top_labels = value_counts.iloc[:top_n].index

    # Replace every value outside the top N with 'Other'
    revised_column = f'{column}_revised'
    df[revised_column] = np.where(df[column].isin(top_labels), df[column], 'Other')

    # Value counts including the 'Other' category
    sizes = df[revised_column].value_counts().values
    labels = df[revised_column].value_counts().index

    # Build an orange gradient, one shade per category
    n_colors = len(sizes)
    colors = plt.cm.Oranges(np.linspace(0.3, 0.9, n_colors))[::-1]

    # Fold each category's share into its label
    percents = sizes / sizes.sum()
    labels = [f'{label}\n {percent:.1%}' for label, percent in zip(labels, percents)]

    fig, ax = plt.subplots(figsize=(20, 12))
    squarify.plot(sizes=sizes, label=labels, alpha=0.7, pad=True,
                  color=colors, text_kwargs={'fontsize': 10})
    ax = plt.gca()

    # Pick white or black text depending on each rectangle's brightness
    for text, rect in zip(ax.texts, ax.patches):
        r, g, b, _ = mcolors.to_rgba(rect.get_facecolor())
        brightness = np.average([r, g, b])
        text.set_color('white' if brightness < 0.5 else 'black')
    # (Adjusting font size to each rectangle's area and wrapping long labels
    # is left unimplemented.)
    return fig
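# Usage sketch for plot_treemap (illustrative only; 'shape' is a hypothetical
# column name). The function returns its figure, so it composes directly with
# st.pyplot:
#
#   demo = pd.DataFrame({'shape': ['orb'] * 5 + ['disk'] * 3 + ['triangle'] * 2})
#   st.pyplot(plot_treemap(demo, 'shape'))  # treemap of value counts with % labels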
class CachedUAPParser(UAPParser):
    """UAPParser that mirrors its parsed responses into st.session_state."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if 'parsed_responses' not in st.session_state:
            st.session_state['parsed_responses'] = {}

    def parse_responses(self):
        parsed_responses = {}
        not_parsed = 0
        try:
            for k, v in self.responses.items():
                try:
                    parsed_responses[k] = json.loads(v)
                except Exception:
                    try:
                        # Fallback: the model sometimes emits single-quoted JSON
                        parsed_responses[k] = json.loads(v.replace("'", '"'))
                    except Exception:
                        not_parsed += 1
            # Update the cached responses
            st.session_state['parsed_responses'] = parsed_responses
        except Exception as e:
            st.error(f"Error parsing responses: {e}")
        st.write(f"Number of unparsed responses: {not_parsed}")
        st.write(f"Number of parsed responses: {len(parsed_responses)}")
        return st.session_state['parsed_responses']

    def responses_to_df(self, col, parsed_responses):
        try:
            parsed_df = pd.DataFrame(parsed_responses).T
            if col is not None:
                parsed_df2 = pd.json_normalize(parsed_df[col])
            else:
                parsed_df2 = pd.json_normalize(parsed_df)
            parsed_df2.index = parsed_df.index
            # Convert problematic columns to string
            for column in parsed_df2.columns:
                if parsed_df2[column].dtype == 'object':
                    parsed_df2[column] = parsed_df2[column].astype(str)
            return parsed_df2
        except Exception as e:
            st.error(f"Error converting responses to DataFrame: {e}")
            return pd.DataFrame()  # Return an empty DataFrame if conversion fails


def plot_hist(df, column, bins=10, kde=True):
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.histplot(data=df, x=column, kde=kde, bins=bins, color='orange')
    # Style the frame, labels and ticks in orange
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    ax.title.set_color('orange')
    # Transparent background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig


def is_api_key_valid(api_key, model='gpt-4o-mini'):
    try:
        os.environ['OPENAI_API_KEY'] = api_key
        client = OpenAI()
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": 'Say Hello World!'}],
        )
        # Any successful completion means the key works
        return response.choices[0].message.content is not None
    except Exception as e:
        st.error(f'Error with the API key: {e}')
        return False


def download_json(data):
    return json.dumps(data, indent=2)


def convert_cached_data_to_df(parser=None):
    """Rebuild the responses DataFrame from the cached parsed responses."""
    if st.session_state.get('parsed_responses'):
        if parser is None:
            # No parser handed in: build one just for its responses_to_df helper
            parser = CachedUAPParser(api_key=API_KEY, model='gpt-4o-mini')
        try:
            responses_df = parser.responses_to_df('sightingDetails', st.session_state['parsed_responses'])
        except Exception as e:
            st.warning(f"Error parsing with 'sightingDetails': {e}")
            responses_df = parser.responses_to_df(None, st.session_state['parsed_responses'])
        if not responses_df.empty:
            st.dataframe(responses_df)
            st.session_state['parsed_responses_df'] = responses_df.copy()
            st.success("Successfully converted cached data to DataFrame.")
        else:
            st.error("Failed to create DataFrame from cached responses.")
    else:
        st.warning("No cached data available. Please parse the dataset first.")
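# Why the quote-swap fallback in parse_responses exists (illustrative):
#
#   json.loads("{'shape': 'orb'}")                    # fails: JSON needs double quotes
#   json.loads("{'shape': 'orb'}".replace("'", '"'))  # -> {'shape': 'orb'}
#
# The swap is only a heuristic: it mangles values that contain apostrophes
# (e.g. "the witness's account"), and those responses end up counted as unparsed.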
def plot_line(df, x_column, y_columns, figsize=(12, 10), color='orange', title=None, rolling_mean_value=2):
    import matplotlib.cm as cm

    # Sort the dataframe by the date column
    df = df.sort_values(by=x_column)

    # Render datetime/timedelta x values as plain dates before plotting
    if np.issubdtype(df[x_column].dtype, np.datetime64) or np.issubdtype(df[x_column].dtype, np.timedelta64):
        df[x_column] = pd.to_datetime(df[x_column]).dt.date

    # Smooth each y_column with a rolling mean whose window is a fraction of
    # the series length
    if rolling_mean_value:
        window = max(1, len(df) // rolling_mean_value)
        df[y_columns] = df[y_columns].rolling(window).mean()

    fig, ax = plt.subplots(figsize=figsize)
    colors = cm.Oranges(np.linspace(0.2, 1, len(y_columns)))

    # Plot each y_column as a separate line with a different shade
    for i, y_column in enumerate(y_columns):
        df.plot(x=x_column, y=y_column, ax=ax, color=colors[i], label=y_column, linewidth=.5)

    # Rotate x-axis labels
    ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha='right')

    # Title, labels and legend
    ax.set_title(title or f'{", ".join(y_columns)} over {x_column}', color=color, fontweight='bold')
    ax.set_xlabel(x_column, color=color)
    ax.set_ylabel(', '.join(y_columns), color=color)
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    ax.title.set_color('orange')
    ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black',
              framealpha=.4, labelcolor='orange', edgecolor='orange')

    # Transparent background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig


def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None):
    fig, ax = plt.subplots(figsize=figsize)
    sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax)

    # Rotate x-axis labels
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_title(title if title else f'{y_column} by {x_column}', color=color, fontweight='bold')
    ax.set_xlabel(x_column, color=color)
    ax.set_ylabel(y_column, color=color)
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    ax.title.set_color('orange')

    # Transparent background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig
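# Note on plot_line's rolling_mean_value: the smoothing window is
# len(df) // rolling_mean_value, so the default of 2 averages over half the
# series. For lighter smoothing pass a larger value, e.g.
#
#   fig = plot_line(df, 'date', ['count'], rolling_mean_value=50)  # 1000 rows -> window of 20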
def plot_grouped_bar(df, x_columns, y_column, figsize=(12, 10), colors=None, title=None):
    fig, ax = plt.subplots(figsize=figsize)
    width = 0.8 / len(x_columns)  # the width of the bars
    x = np.arange(len(df), dtype=float)  # the label locations (float, so it can shift by width)

    for i, x_column in enumerate(x_columns):
        sns.barplot(data=df, x=x, y=y_column, color=colors[i] if colors else None,
                    ax=ax, width=width, label=x_column)
        x += width  # shift the next group of bars to the right

    ax.set_title(title if title else f'{y_column} by {", ".join(x_columns)}',
                 color='orange', fontweight='bold')
    ax.set_xlabel('Groups', color='orange')
    ax.set_ylabel(y_column, color='orange')
    ax.set_xticks(x - width * len(x_columns) / 2)
    ax.set_xticklabels(df.index)
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.title.set_color('orange')
    ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black',
              framealpha=.4, labelcolor='orange', edgecolor='orange')

    # Transparent background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig


@st.cache_data
def convert_df(df):
    # IMPORTANT: cache the conversion to prevent recomputation on every rerun
    try:
        csv = df.to_csv().encode("utf-8")
    except Exception:
        csv = df.to_csv().encode("utf-8-sig")
    return csv
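# Usage sketch: the cached CSV bytes plug straight into a download button,
# as done at the bottom of this script:
#
#   st.download_button("Save CSV", data=convert_df(df),
#                      file_name="uap_data.csv", mime="text/csv")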
def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns.

    Args:
        df (pd.DataFrame): Original dataframe

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    # Settings kept for the commented-out ChartGPT integration below
    title_font = "Arial"
    body_font = "Arial"
    title_size = 32
    colors = ["red", "green", "blue"]
    interpretation = False
    extract_docx = False
    title = "My Chart"
    regex = ".*"
    img_path = 'default_image.png'

    # try:
    #     modify = st.checkbox("Add filters on raw data")
    # except:
    #     try:
    #         modify = st.checkbox("Add filters on processed data")
    #     except:
    #         try:
    #             modify = st.checkbox("Add filters on parsed data")
    #         except:
    #             pass
    # if not modify:
    #     return df

    df_ = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    # modification_container = st.container()
    # with modification_container:
    to_filter_columns = st.multiselect("Filter dataframe on", df_.columns)
    date_column = None
    filtered_columns = []

    for column in to_filter_columns:
        left, right = st.columns((1, 20))
        # Treat columns with < 120 unique values as categorical if neither date nor numeric
        if is_categorical_dtype(df_[column]) or (
            df_[column].nunique() < 120
            and not is_datetime64_any_dtype(df_[column])
            and not is_numeric_dtype(df_[column])
        ):
            user_cat_input = right.multiselect(
                f"Values for {column}",
                df_[column].value_counts().index.tolist(),
                default=list(df_[column].value_counts().index),
            )
            df_ = df_[df_[column].isin(user_cat_input)]
            filtered_columns.append(column)
            with st.status(f"Category Distribution: {column}", expanded=False):
                st.pyplot(plot_treemap(df_, column))
        elif is_numeric_dtype(df_[column]):
            _min = float(df_[column].min())
            _max = float(df_[column].max())
            step = (_max - _min) / 100
            user_num_input = right.slider(
                f"Values for {column}",
                min_value=_min,
                max_value=_max,
                value=(_min, _max),
                step=step,
            )
            df_ = df_[df_[column].between(*user_num_input)]
            filtered_columns.append(column)
            # Chart_GPT = ChartGPT(df_, title_font, body_font, title_size,
            #                      colors, interpretation, extract_docx, img_path)
            with st.status(f"Numerical Distribution: {column}", expanded=False):
                st.pyplot(plot_hist(df_, column, bins=int(round(len(df_[column].unique()) - 1) / 2)))
        elif is_object_dtype(df_[column]):
            try:
                df_[column] = pd.to_datetime(df_[column], infer_datetime_format=True, errors='coerce')
            except Exception:
                try:
                    df_[column] = df_[column].apply(parser.parse)
                except Exception:
                    pass
            if is_datetime64_any_dtype(df_[column]):
                df_[column] = df_[column].dt.tz_localize(None)
                min_date = df_[column].min().date()
                max_date = df_[column].max().date()
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(min_date, max_date),
                    min_value=min_date,
                    max_value=max_date,
                )
                if len(user_date_input) == 2:
                    start_date, end_date = map(pd.to_datetime, user_date_input)

                    # Pick the time unit whose bucket count is closest to 36
                    time_units = {
                        'year': df_[column].dt.year,
                        'month': df_[column].dt.to_period('M'),
                        'day': df_[column].dt.date,
                    }
                    unique_counts = {unit: col.nunique() for unit, col in time_units.items()}
                    closest_to_36 = min(unique_counts, key=lambda k: abs(unique_counts[k] - 36))

                    # Group by that unit and count occurrences
                    grouped = df_.groupby(time_units[closest_to_36]).size().reset_index(name='count')
                    grouped.columns = [column, 'count']

                    # Build a complete date range so empty buckets show as zero
                    if closest_to_36 == 'year':
                        date_range = pd.date_range(start=f"{start_date.year}-01-01",
                                                   end=f"{end_date.year}-12-31", freq='YS')
                    elif closest_to_36 == 'month':
                        date_range = pd.date_range(start=start_date.replace(day=1),
                                                   end=end_date + pd.offsets.MonthEnd(0), freq='MS')
                    else:  # day
                        date_range = pd.date_range(start=start_date, end=end_date, freq='D')
                    complete_range = pd.DataFrame({column: date_range})

                    # Convert the range to the chosen granularity
                    if closest_to_36 == 'year':
                        complete_range[column] = complete_range[column].dt.year
                    elif closest_to_36 == 'month':
                        complete_range[column] = complete_range[column].dt.to_period('M')

                    # Merge the complete range with the grouped counts
                    final_data = pd.merge(complete_range, grouped, on=column, how='left').fillna(0)
                    with st.status(f"Date Distributions: {column}", expanded=False):
                        try:
                            st.pyplot(plot_bar(final_data, column, 'count'))
                        except Exception as e:
                            st.error(f"Error plotting bar chart: {e}")

                    df_ = df_.loc[df_[column].between(start_date, end_date)]
                    date_column = column

                if date_column and filtered_columns:
                    fig = fig2 = None
                    numeric_columns = [col for col in filtered_columns if is_numeric_dtype(df_[col])]
                    if numeric_columns:
                        fig = plot_line(df_, date_column, numeric_columns)
                    # Now deal with categorical columns
                    categorical_columns = [col for col in filtered_columns if is_categorical_dtype(df_[col])]
                    if categorical_columns:
                        fig2 = plot_bar(df_, date_column, categorical_columns[0])
                    with st.status(f"Date Distribution: {column}", expanded=False):
                        if fig is not None:
                            try:
                                st.pyplot(fig)
                            except Exception as e:
                                st.error(f"Error plotting line chart: {e}")
                        if fig2 is not None:
                            try:
                                st.pyplot(fig2)
                            except Exception as e:
                                st.error(f"Error plotting bar chart: {e}")
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    df_ = df_[df_[column].astype(str).str.contains(user_text_input)]

    # Report the filtered size as a share of the original
    st.write(f"{len(df_)} rows ({len(df_) / len(df) * 100:.2f}%)")
    return df_
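# Granularity heuristic illustration: for a date column spanning roughly ten
# years of daily sightings, the unique counts might be
# {'year': 10, 'month': 120, 'day': 3650}; |10 - 36| = 26 is the smallest gap,
# so the histogram is bucketed by year. Targeting 36 buckets simply keeps the
# bar chart readable.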
from config import FORMAT_LONG

OPENAI_KEY = st.secrets["OPENAI_KEY"]
GEMINI_KEY = st.secrets["GEMINI_KEY"]

with torch.no_grad():
    torch.cuda.empty_cache()

# st.set_page_config(
#     page_title="UAP ANALYSIS",
#     page_icon=":alien:",
#     layout="wide",
#     initial_sidebar_state="expanded",
# )

st.title('UAP Feature Extraction')

# Initialize session state
session_defaults = {
    'analyzers': [],
    'col_names': [],
    'clusters': {},
    'new_data': pd.DataFrame(),
    'dataset': pd.DataFrame(),
    'data_processed': False,
    'stage': 0,
    'filtered_data': None,
    'gemini_answer': None,
    'parsed_responses': None,
    'parsed_responses_df': None,
    'json_format': None,
    'api_key_valid': False,
    'previous_api_key': None,
}
for key, default in session_defaults.items():
    if key not in st.session_state:
        st.session_state[key] = default

# Unparsed data
# unparsed_tickbox = st.checkbox('Data Parsing')
# if unparsed_tickbox:
unparsed = st.file_uploader("Upload Raw DataFrame", type=["csv", "xlsx"])
if unparsed is not None:
    try:
        data = pd.read_csv(unparsed) if unparsed.type == "text/csv" else pd.read_excel(unparsed)
        filtered_data = filter_dataframe(data)
        st.dataframe(filtered_data)
    except Exception as e:
        st.error(f"An error occurred while reading the file: {e}")
        st.stop()  # nothing below can run without a readable DataFrame

    modify_json = st.checkbox('Custom JSON')
    API_KEY = st.text_input('OpenAI API Key', OPENAI_KEY, type='password',
                            help="Enter your OpenAI API key")
    if modify_json:
        FORMAT_LONG = st.text_area('Custom JSON', FORMAT_LONG, height=500)
        st.download_button("Save Format", FORMAT_LONG)
    try:
        json.loads(FORMAT_LONG)
        st.session_state['json_format'] = True
    except json.JSONDecodeError as e:
        st.error(f"Invalid JSON format: {str(e)}")
        st.session_state['json_format'] = False
        st.stop()  # Stop execution if the JSON is invalid

    # With the DataFrame in hand, let the user pick the free-text column
    col_unparsed = st.selectbox("Select column corresponding to text", data.columns)

    if st.button("Parse Dataset") and st.session_state['json_format']:
        if API_KEY:
            # Only validate if the API key has changed
            if API_KEY != st.session_state['previous_api_key']:
                if is_api_key_valid(API_KEY):
                    st.session_state['api_key_valid'] = True
                    st.session_state['previous_api_key'] = API_KEY
                    st.success("API key is valid!")
                else:
                    st.session_state['api_key_valid'] = False
                    st.error("Invalid API key. Please check and try again.")
            elif st.session_state['api_key_valid']:
                st.success("API key is valid!")
        if not API_KEY:  # or not st.session_state['api_key_valid']:
            st.warning("Please enter your API key to proceed.")
            st.stop()

        selected_column_data = filtered_data[col_unparsed].tolist()
        st.session_state.result = selected_column_data
        with st.status("Parsing...", expanded=True) as stat:
            try:
                st.write("Parsing descriptions...")
                # Named uap_parser so it does not shadow dateutil's `parser` import
                uap_parser = CachedUAPParser(api_key=API_KEY, model='gpt-4o-mini',
                                             col=st.session_state.result)
                descriptions = st.session_state.result
                format_long = FORMAT_LONG
                uap_parser.process_descriptions(descriptions, format_long)
                st.session_state['parsed_responses'] = uap_parser.parse_responses()
                try:
                    responses_df = uap_parser.responses_to_df('sightingDetails',
                                                              st.session_state['parsed_responses'])
                except Exception as e:
                    st.warning(f"Error parsing with 'sightingDetails': {e}")
                    responses_df = uap_parser.responses_to_df(None, st.session_state['parsed_responses'])
                if not responses_df.empty:
                    st.dataframe(responses_df)
                    st.session_state['parsed_responses_df'] = responses_df.copy()
                    stat.update(label="Parsing complete", state="complete", expanded=False)
                else:
                    st.error("Failed to create DataFrame from parsed responses.")
            except Exception as e:
                st.error(f"An error occurred during parsing: {str(e)}")

    # Download button for the parsed data
    if st.session_state['parsed_responses'] is not None:
        json_str = download_json(st.session_state['parsed_responses'])
        st.download_button(
            label="Download Parsed Data as JSON",
            data=json_str,
            file_name="parsed_responses.json",
            mime="application/json",
        )

    # Rebuild the DataFrame from the cached responses on demand
    if st.button("Convert Cached Data to DataFrame"):
        convert_cached_data_to_df()

    if st.session_state['parsed_responses_df'] is not None:
        st.download_button(
            label="Save CSV",
            data=convert_df(st.session_state['parsed_responses_df']),
            file_name="uap_data.csv",
            mime="text/csv",
        )

# except Exception as e:
#     stat.update(label=f"Parsing failed: {e}", state="error")
#     st.write("Parsing descriptions...")
#     st.update_status("Parsing descriptions...")
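# To run (assuming this script is saved as app.py; OPENAI_KEY and GEMINI_KEY are
# the names this script reads from st.secrets):
#
#   # .streamlit/secrets.toml
#   # OPENAI_KEY = "sk-..."
#   # GEMINI_KEY = "..."
#
#   streamlit run app.py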