Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import cohere | |
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from uap_analyzer import UAPParser, UAPAnalyzer, UAPVisualizer | |
# import ChartGen | |
# from ChartGen import ChartGPT | |
from Levenshtein import distance | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import confusion_matrix | |
from stqdm import stqdm | |
stqdm.pandas() | |
import streamlit.components.v1 as components | |
from dateutil import parser | |
from sentence_transformers import SentenceTransformer | |
import torch | |
import squarify | |
import matplotlib.colors as mcolors | |
import textwrap | |
import datamapplot | |
import json | |
st.set_option('deprecation.showPyplotGlobalUse', False) | |
from pandas.api.types import ( | |
is_categorical_dtype, | |
is_datetime64_any_dtype, | |
is_numeric_dtype, | |
is_object_dtype, | |
) | |
def plot_treemap(df, column, top_n=32):
    """Draw a treemap of the value distribution of ``df[column]``.

    The ``top_n`` most frequent values each get their own tile; all remaining
    values are pooled into one ``'Other'`` tile. Every tile is labelled with
    the value and its share of the rows.

    Args:
        df: DataFrame holding the data. It is left unmodified.
        column: Name of the column whose values are counted.
        top_n: Number of distinct values that keep an individual tile.

    Returns:
        The matplotlib Figure holding the treemap, with a transparent
        background. For an empty frame an empty figure is returned.
    """
    top_labels = df[column].value_counts().iloc[:top_n].index

    # Build the collapsed labels in a local Series. Writing a
    # '<column>_revised' column into ``df`` would leak into the caller's
    # frame (and into whatever the caller later displays or exports).
    collapsed = pd.Series(
        np.where(df[column].isin(top_labels), df[column], 'Other'),
        index=df.index,
    )
    counts = collapsed.value_counts()
    sizes = counts.values

    fig, ax = plt.subplots(figsize=(20, 12))
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    if sizes.size == 0:
        # Nothing to lay out; hand back a blank canvas.
        ax.axis('off')
        return fig

    # Orange gradient, darkest shade for the largest tile.
    colors = plt.cm.Oranges(np.linspace(0.3, 0.9, len(sizes)))[::-1]
    percents = sizes / sizes.sum()
    labels = [
        f'{label}\n {percent:.1%}'
        for label, percent in zip(counts.index, percents)
    ]

    # Draw on the axes created above rather than on whatever is "current".
    squarify.plot(
        sizes=sizes,
        label=labels,
        alpha=0.7,
        pad=True,
        color=colors,
        text_kwargs={'fontsize': 10},
        ax=ax,
    )

    for text, rect in zip(ax.texts, ax.patches):
        # White text on dark tiles, black text on light ones.
        r, g, b, _ = mcolors.to_rgba(rect.get_facecolor())
        brightness = np.average([r, g, b])
        text.set_color('white' if brightness < 0.5 else 'black')

        # Scale the font with the tile size and wrap long labels.
        coef = 0.8
        text.set_fontsize(np.sqrt(rect.get_width() * rect.get_height()) * coef)
        text.set_text(textwrap.fill(text.get_text(), width=20))

    ax.axis('off')
    ax.invert_yaxis()
    return fig
def plot_hist(df, column, bins=10, kde=True):
    """Plot an orange histogram of ``df[column]`` on a transparent figure.

    Args:
        df: DataFrame holding the data.
        column: Name of the column to plot.
        bins: Bin specification forwarded to ``seaborn.histplot``. An integer
            below 1 is raised to 1 so the plot can still be drawn.
        kde: Whether to overlay a kernel density estimate.

    Returns:
        The matplotlib Figure holding the histogram.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # Callers derive the bin count from the number of unique values, which
    # can come out as 0 for near-constant columns.
    if isinstance(bins, (int, np.integer)) and bins < 1:
        bins = 1

    # Honour the ``kde`` argument and draw on the axes created above.
    sns.histplot(data=df, x=column, kde=kde, bins=bins, color='orange', ax=ax)

    # Frame, labels, ticks and title all in orange.
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='both', colors='orange')
    ax.title.set_color('orange')

    # Transparent background.
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig
def plot_line(df, x_column, y_columns, figsize=(12, 10), color='orange', title=None, rolling_mean_value=2):
    """Plot one line per column in ``y_columns`` against ``x_column``.

    The frame is sorted by ``x_column`` first. When ``rolling_mean_value`` is
    truthy, each y column is smoothed with a rolling mean whose window is
    ``len(df) // rolling_mean_value`` rows (never less than one row).

    Args:
        df: DataFrame holding the data. It is not modified; the function
            works on a sorted copy.
        x_column: Name of the column used for the x axis.
        y_columns: Column name, or list of column names, to draw as lines.
        figsize: Figure size passed to ``plt.subplots``.
        color: Colour of the title, axis labels, ticks, frame and legend.
        title: Plot title; defaults to "<y columns> over <x column>".
        rolling_mean_value: Divisor that determines the rolling window, or a
            falsy value to disable smoothing.

    Returns:
        The matplotlib Figure holding the line chart.
    """
    if isinstance(y_columns, str):
        y_columns = [y_columns]
    y_columns = list(y_columns)

    # sort_values returns a new frame, so the smoothing below does not touch
    # the caller's data.
    df = df.sort_values(by=x_column)

    if rolling_mean_value:
        # A window of 0 rows (short frames) would blank out every value.
        window = max(1, len(df) // rolling_mean_value)
        df[y_columns] = df[y_columns].rolling(window).mean()

    fig, ax = plt.subplots(figsize=figsize)
    line_colors = plt.cm.Oranges(np.linspace(0.2, 1, len(y_columns)))

    for line_color, y_column in zip(line_colors, y_columns):
        df.plot(x=x_column, y=y_column, ax=ax, color=line_color, label=y_column, linewidth=.5)

    # Rotate the existing tick labels in place; re-setting them through
    # set_xticklabels would freeze their text at its pre-draw state.
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(30)
        tick_label.set_ha('right')

    ax.set_title(title or f'{", ".join(y_columns)} over {x_column}', color=color, fontweight='bold')
    ax.set_xlabel(x_column, color=color)
    ax.set_ylabel(', '.join(y_columns), color=color)

    for spine in ax.spines.values():
        spine.set_color(color)
    ax.tick_params(axis='both', colors=color)
    ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor=color, edgecolor=color)

    # Transparent background.
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig
def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None, rotation=45):
    """Draw a single-colour bar chart of ``y_column`` by ``x_column``.

    Args:
        df: DataFrame holding the data.
        x_column: Name of the column used for the x axis.
        y_column: Name of the column giving the bar heights.
        figsize: Figure size passed to ``plt.subplots``.
        color: Colour of the bars, title, axis labels, ticks and frame.
        title: Plot title; defaults to "<y_column> by <x_column>".
        rotation: Rotation of the x tick labels, in degrees.

    Returns:
        The matplotlib Figure holding the bar chart.
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax)

    ax.set_title(title if title else f'{y_column} by {x_column}', color=color, fontweight='bold')
    ax.set_xlabel(x_column, color=color)
    ax.set_ylabel(y_column, color=color)
    ax.tick_params(axis='both', colors=color)
    # Rotate on this axes explicitly instead of going through pyplot's
    # "current axes" state.
    ax.tick_params(axis='x', rotation=rotation)

    for spine in ax.spines.values():
        spine.set_color(color)

    # Transparent background.
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    # The bars carry no label, so an unconditional legend() only produces a
    # warning and an empty box; add one only when there is something to show.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor=color, edgecolor=color)
    return fig
def plot_grouped_bar(df, x_columns, y_column, figsize=(12, 10), colors=None, title=None):
    """Draw side-by-side bars, one series per entry of ``x_columns``.

    One group of bars is drawn per row of ``df``; the groups are labelled with
    ``df.index``. Within a group there is one bar per name in ``x_columns``,
    and each series appears in the legend under that name.

    NOTE(review): as in the previous implementation, the bar heights of every
    series are taken from ``df[y_column]``; the values of the ``x_columns``
    themselves are never read, so all series in a group are equally tall.
    Confirm the intended data mapping with the author.

    Args:
        df: DataFrame holding the data.
        x_columns: Names used as the series labels.
        y_column: Name of the column giving the bar heights.
        figsize: Figure size passed to ``plt.subplots``.
        colors: Optional sequence with one colour per series.
        title: Plot title; defaults to "<y_column> by <x columns>".

    Returns:
        The matplotlib Figure holding the grouped bar chart.

    Raises:
        ValueError: If ``x_columns`` is empty.
    """
    x_columns = list(x_columns)
    if not x_columns:
        raise ValueError("x_columns must contain at least one column name")

    fig, ax = plt.subplots(figsize=figsize)
    n_series = len(x_columns)
    width = 0.8 / n_series  # width of one bar inside a group

    # Float positions: adding a fractional offset in place to an integer
    # arange raises a casting error in NumPy.
    group_start = np.arange(len(df), dtype=float)
    heights = df[y_column].to_numpy()

    for i, x_column in enumerate(x_columns):
        # Plain matplotlib bars are used because seaborn's barplot treats x
        # as categorical and would ignore the per-series offset.
        ax.bar(
            group_start + i * width,
            heights,
            width=width,
            color=colors[i] if colors else None,
            label=x_column,
        )

    ax.set_title(title if title else f'{y_column} by {", ".join(x_columns)}', color='orange', fontweight='bold')
    ax.set_xlabel('Groups', color='orange')
    ax.set_ylabel(y_column, color='orange')

    # Put each tick under the centre of its group of bars.
    ax.set_xticks(group_start + width * (n_series - 1) / 2)
    ax.set_xticklabels(df.index)
    ax.tick_params(axis='both', colors='orange')

    # Transparent background.
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
    return fig
def _parse_datetime_column(values):
    """Return ``values`` as timezone-naive datetimes, or None if not dates."""
    try:
        converted = pd.to_datetime(values, errors='coerce')
    except (ValueError, TypeError, OverflowError):
        try:
            converted = values.apply(parser.parse)
        except (ValueError, TypeError, OverflowError):
            return None
    # errors='coerce' turns every unparseable value into NaT, so free-text
    # columns come back as an all-NaT datetime column. Treat those as text.
    if not is_datetime64_any_dtype(converted) or not converted.notna().any():
        return None
    return converted.dt.tz_localize(None)


def _count_by_time_unit(frame, column, start_date, end_date):
    """Count rows per year, month or day of ``column``, zero-filling gaps."""
    dates = frame[column]
    time_units = {
        'year': dates.dt.year,
        'month': dates.dt.to_period('M'),
        # normalize() keeps datetime64 dtype, so the keys can be merged with
        # the date_range built below (python ``date`` objects cannot).
        'day': dates.dt.normalize(),
    }
    # Pick the unit whose number of distinct buckets is closest to 36.
    unique_counts = {unit: keys.nunique() for unit, keys in time_units.items()}
    closest_to_36 = min(unique_counts, key=lambda unit: abs(unique_counts[unit] - 36))

    grouped = frame.groupby(time_units[closest_to_36]).size().reset_index(name='count')
    grouped.columns = [column, 'count']

    # Full range of buckets between the selected dates, so empty periods
    # show up as zero-height bars.
    if closest_to_36 == 'year':
        date_range = pd.date_range(start=f"{start_date.year}-01-01", end=f"{end_date.year}-12-31", freq='YS')
    elif closest_to_36 == 'month':
        date_range = pd.date_range(start=start_date.replace(day=1), end=end_date + pd.offsets.MonthEnd(0), freq='MS')
    else:  # day
        date_range = pd.date_range(start=start_date, end=end_date, freq='D')

    complete_range = pd.DataFrame({column: date_range})
    if closest_to_36 == 'year':
        complete_range[column] = complete_range[column].dt.year
    elif closest_to_36 == 'month':
        complete_range[column] = complete_range[column].dt.to_period('M')

    final_data = pd.merge(complete_range, grouped, on=column, how='left')
    final_data['count'] = final_data['count'].fillna(0)
    return final_data


def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    For every column picked in the "Filter dataframe on" widget a filter
    control is rendered, chosen by the column's content:

    * categorical dtype, or fewer than 120 distinct non-numeric, non-date
      values: multiselect of values plus a treemap of the selection;
    * numeric: range slider plus a histogram;
    * datetime, or text that parses as dates: date-range picker plus a bar
      chart of counts per year/month/day, followed by charts against the
      numeric / categorical columns filtered earlier in the loop;
    * any other text: substring / regex text box.

    Args:
        df (pd.DataFrame): Original dataframe
    Returns:
        pd.DataFrame: Filtered dataframe
    """
    import re  # local import: only needed for the regex fallback below

    df_ = df.copy()
    to_filter_columns = st.multiselect("Filter dataframe on", df_.columns)
    date_column = None
    filtered_columns = []

    for column in to_filter_columns:
        _, right = st.columns((1, 20))
        series = df_[column]
        is_categorical = isinstance(series.dtype, pd.CategoricalDtype)
        is_datetime = is_datetime64_any_dtype(series)

        # Treat columns with < 120 unique values as categorical if not date or numeric
        if is_categorical or (series.nunique() < 120 and not is_datetime and not is_numeric_dtype(series)):
            options = series.value_counts().index.tolist()
            user_cat_input = right.multiselect(
                f"Values for {column}",
                options,
                default=options,
            )
            df_ = df_[df_[column].isin(user_cat_input)]
            filtered_columns.append(column)
            with st.status(f"Category Distribution: {column}", expanded=False):
                # Pass a one-column frame so the plot helper cannot add
                # working columns to the frame that is returned to the user.
                st.pyplot(plot_treemap(df_[[column]], column))

        elif is_numeric_dtype(series):
            _min = float(series.min())
            _max = float(series.max())
            if _min < _max:
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    min_value=_min,
                    max_value=_max,
                    value=(_min, _max),
                    step=step,
                )
                df_ = df_[df_[column].between(*user_num_input)]
            else:
                # Constant or all-NaN column: a slider would get a zero step.
                right.caption(f"No range to filter for {column}")
            filtered_columns.append(column)
            with st.status(f"Numerical Distribution: {column}", expanded=False):
                # Roughly one bin per two distinct values, but never zero.
                bins = max(1, (len(df_[column].unique()) - 1) // 2)
                st.pyplot(plot_hist(df_, column, bins=bins))

        elif is_object_dtype(series) or is_datetime:
            parsed = _parse_datetime_column(series)
            if parsed is not None:
                df_[column] = parsed
                min_date = df_[column].min().date()
                max_date = df_[column].max().date()
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(min_date, max_date),
                    min_value=min_date,
                    max_value=max_date,
                )
                # The widget returns fewer than two dates while the user is
                # still picking the range.
                if len(user_date_input) == 2:
                    start_date, end_date = map(pd.to_datetime, user_date_input)
                    final_data = _count_by_time_unit(df_, column, start_date, end_date)
                    with st.status(f"Date Distributions: {column}", expanded=False):
                        try:
                            st.pyplot(plot_bar(final_data, column, 'count'))
                        except Exception as e:
                            st.error(f"Error plotting bar chart: {e}")
                    # Compare on whole days: end_date is midnight, and rows
                    # later on that same day must stay inside the range.
                    df_ = df_.loc[df_[column].dt.normalize().between(start_date, end_date)]

                date_column = column
                if date_column and filtered_columns:
                    numeric_columns = [col for col in filtered_columns if is_numeric_dtype(df_[col])]
                    if numeric_columns:
                        with st.status(f"Date Numerical Distributions: {column}", expanded=False):
                            try:
                                st.pyplot(plot_line(df_, date_column, numeric_columns))
                            except Exception as e:
                                st.error(f"Error plotting line chart: {e}")
                    # now to deal with categorical columns
                    categorical_columns = [
                        col for col in filtered_columns
                        if isinstance(df_[col].dtype, pd.CategoricalDtype)
                    ]
                    if categorical_columns:
                        with st.status(f"Date Categorical Distributions: {column}", expanded=False):
                            try:
                                st.pyplot(plot_grouped_bar(df_, categorical_columns, date_column))
                            except Exception as e:
                                st.error(f"Error plotting bar chart: {e}")
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    as_text = df_[column].astype(str)
                    try:
                        mask = as_text.str.contains(user_text_input)
                    except re.error:
                        # Not a valid pattern: fall back to a literal match.
                        mask = as_text.str.contains(user_text_input, regex=False)
                    df_ = df_[mask]

    # write len of df after filtering with % of original
    share = len(df_) / len(df) * 100 if len(df) else 0.0
    st.write(f"{len(df_)} rows ({share:.2f}%)")
    return df_
# Seed every session-state slot the app reads, leaving alone any value that
# is already present from an earlier rerun.
for _slot, _initial in (
    ('analyzers', []),
    ('col_names', []),
    ('clusters', {}),
    ('new_data', pd.DataFrame()),
    ('dataset', pd.DataFrame()),
    ('data_processed', False),
    ('stage', 0),
    ('filtered_data', None),
    ('gemini_answer', None),
    ('parsed_responses', None),
    ('json_format', None),
    ('api_key_valid', False),
    ('previous_api_key', None),
):
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial

# API keys are read from the Streamlit secrets store.
OPENAI_KEY = st.secrets["OPENAI_KEY"]
GEMINI_KEY = st.secrets["GEMINI_KEY"]
COHERE_KEY = st.secrets["COHERE_KEY"]
def load_data(file_path, key='df'):
    """Read the object stored under ``key`` from the HDF file at ``file_path``."""
    frame = pd.read_hdf(file_path, key=key)
    return frame
# Upload a table, let the user filter it, then rerank its rows against a
# free-text question with Cohere.
datasett = st.file_uploader("Upload Raw DataFrame", type=["csv", "xlsx"])
if datasett is not None:
    try:
        # Some browsers report CSV uploads with a spreadsheet MIME type, so
        # the file extension is checked as well.
        is_csv = datasett.type == "text/csv" or datasett.name.lower().endswith(".csv")
        data = pd.read_csv(datasett) if is_csv else pd.read_excel(datasett)
        filtered_data = filter_dataframe(data)
        st.session_state['parsed_responses'] = filtered_data
        st.dataframe(filtered_data)
    except Exception as e:
        st.error(f"An error occurred while reading the file: {e}")

# None until a file has been uploaded and parsed successfully.
parsed_responses = st.session_state['parsed_responses']

col1, col2 = st.columns(2)
with col1:
    columns_to_query = st.multiselect(
        label='Select columns to analyze',
        options=list(parsed_responses.columns) if parsed_responses is not None else [],
    )
with col2:
    COHERE_KEY = st.text_input('Cohere APIs Key', COHERE_KEY, type='password', help="Enter your Cohere API key")

question = st.text_input("Ask a question")

if parsed_responses is not None and question and COHERE_KEY:
    if not columns_to_query:
        st.warning("Select at least one column to analyze.")
    elif parsed_responses.empty:
        st.warning("No rows left to rerank after filtering.")
    else:
        try:
            co = cohere.Client(api_key=COHERE_KEY)
            documents = parsed_responses[columns_to_query].to_dict('records')
            # default=str: cells such as Timestamps are not JSON-serialisable;
            # their string form is what the reranker should see anyway.
            json_documents = [json.dumps(doc, default=str) for doc in documents]

            results = co.rerank(
                model="rerank-english-v3.0",
                query=question,
                documents=json_documents,
                top_n=5,
                return_documents=True,
            )
            st.subheader("Reranked Results:")

            # result.index is the position of the document in json_documents,
            # i.e. the row position in parsed_responses.
            reranked_indices = [result.index for result in results.results]
            reranked_scores = [result.relevance_score for result in results.results]

            reranked_df = parsed_responses.iloc[reranked_indices].copy()
            reranked_df['relevance_score'] = reranked_scores
            reranked_df['rank'] = range(1, len(reranked_indices) + 1)
            # Use the rank as the displayed index.
            reranked_df.set_index('rank', inplace=True)

            st.dataframe(reranked_df)
        except Exception as e:
            st.error(f"An error occurred during reranking: {e}")