# UAP-Data-Analysis-Tool / magnetic.py
import math
import os
import json
import datetime
from datetime import timedelta
import textwrap

import numpy as np
import pandas as pd
import requests
import torch
import squarify
import datamapplot
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.image as mpimg  # alternative to PIL
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
from PIL import Image
from IPython.display import Image as image_display, display
from dateutil import parser
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
from Levenshtein import distance
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sentence_transformers import SentenceTransformer
from stqdm import stqdm

import streamlit as st
import streamlit.components.v1 as components

stqdm.pandas()
path = os.getcwd()
if 'form_submitted' not in st.session_state:
st.session_state['form_submitted'] = False
st.title('Magnetic Correlations Dashboard')
st.set_option('deprecation.showPyplotGlobalUse', False)
from pandas.api.types import (
is_categorical_dtype,
is_datetime64_any_dtype,
is_numeric_dtype,
is_object_dtype,
)
def plot_treemap(df, column, top_n=32):
# Get the value counts and the top N labels
value_counts = df[column].value_counts()
top_labels = value_counts.iloc[:top_n].index
# Use np.where to replace all values not in the top N with 'Other'
revised_column = f'{column}_revised'
df[revised_column] = np.where(df[column].isin(top_labels), df[column], 'Other')
# Get the value counts including the 'Other' category
sizes = df[revised_column].value_counts().values
labels = df[revised_column].value_counts().index
# Get a gradient of colors
# colors = list(mcolors.TABLEAU_COLORS.values())
n_colors = len(sizes)
colors = plt.cm.Oranges(np.linspace(0.3, 0.9, n_colors))[::-1]
# Get % of each category
percents = sizes / sizes.sum()
# Prepare labels with percentages
labels = [f'{label}\n {percent:.1%}' for label, percent in zip(labels, percents)]
fig, ax = plt.subplots(figsize=(20, 12))
# Plot the treemap
    squarify.plot(sizes=sizes, label=labels, alpha=0.7, pad=True, color=colors, text_kwargs={'fontsize': 10}, ax=ax)
# Iterate over text elements and rectangles (patches) in the axes for color adjustment
for text, rect in zip(ax.texts, ax.patches):
background_color = rect.get_facecolor()
r, g, b, _ = mcolors.to_rgba(background_color)
brightness = np.average([r, g, b])
        text.set_color('white' if brightness < 0.5 else 'black')
    return fig
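# Usage sketch (the column name is illustrative, not from this dataset):
#   st.pyplot(plot_treemap(df, 'Shape'))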
def plot_hist(df, column, bins=10, kde=True):
fig, ax = plt.subplots(figsize=(12, 6))
    sns.histplot(data=df, x=column, kde=kde, bins=bins, color='orange', ax=ax)
    # Style the frame, ticks, and labels in orange
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    ax.title.set_color('orange')
# Set transparent background
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
return fig
def plot_line(df, x_column, y_columns, figsize=(12, 10), color='orange', title=None, rolling_mean_value=2):
import matplotlib.cm as cm
    # Sort the dataframe by the date column and work on a copy
    df = df.sort_values(by=x_column).copy()
    # Normalize datetime x values to plain dates *before* plotting
    if np.issubdtype(df[x_column].dtype, np.datetime64):
        df[x_column] = pd.to_datetime(df[x_column]).dt.date
    # Smooth each y_column with a rolling mean over len(df) // rolling_mean_value samples
    if rolling_mean_value:
        df[y_columns] = df[y_columns].rolling(len(df) // rolling_mean_value).mean()
    # Create the plot
    fig, ax = plt.subplots(figsize=figsize)
    colors = cm.Oranges(np.linspace(0.2, 1, len(y_columns)))
    # Plot each y_column as a separate line with a different color
    for i, y_column in enumerate(y_columns):
        df.plot(x=x_column, y=y_column, ax=ax, color=colors[i], label=y_column, linewidth=.5)
    # Rotate x-axis labels
    plt.setp(ax.get_xticklabels(), rotation=30, ha='right')
# Set title, labels, and legend
ax.set_title(title or f'{", ".join(y_columns)} over {x_column}', color=color, fontweight='bold')
ax.set_xlabel(x_column, color=color)
ax.set_ylabel(', '.join(y_columns), color=color)
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    ax.title.set_color('orange')
ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
# Remove background
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
return fig
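# Usage sketch (column names are hypothetical):
#   fig = plot_line(df, 'EventDate', ['X', 'Y'], rolling_mean_value=4)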
def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None, rotation=45):
fig, ax = plt.subplots(figsize=figsize)
sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax)
ax.set_title(title if title else f'{y_column} by {x_column}', color=color, fontweight='bold')
ax.set_xlabel(x_column, color=color)
ax.set_ylabel(y_column, color=color)
    ax.tick_params(axis='x', colors=color)
    ax.tick_params(axis='y', colors=color)
    plt.xticks(rotation=rotation)
    # Remove background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    for spine in ax.spines.values():
        spine.set_color(color)
    ax.xaxis.label.set_color(color)
    ax.yaxis.label.set_color(color)
    ax.title.set_color(color)
return fig
def plot_grouped_bar(df, x_columns, y_column, figsize=(12, 10), colors=None, title=None):
fig, ax = plt.subplots(figsize=figsize)
    width = 0.8 / len(x_columns)  # the width of each bar within a group
    x = np.arange(len(df), dtype=float)  # the label locations
    for i, x_column in enumerate(x_columns):
        # matplotlib bar() used directly so each series can be offset by one bar-width
        ax.bar(x + i * width, df[y_column], width=width, color=colors[i] if colors else None, label=x_column)
    ax.set_title(title if title else f'{y_column} by {", ".join(x_columns)}', color='orange', fontweight='bold')
    ax.set_xlabel('Groups', color='orange')
    ax.set_ylabel(y_column, color='orange')
    ax.set_xticks(x + width * (len(x_columns) - 1) / 2)
    ax.set_xticklabels(df.index)
ax.tick_params(axis='x', colors='orange')
ax.tick_params(axis='y', colors='orange')
# Remove background
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.title.set_color('orange')
ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
return fig
def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
"""
Adds a UI on top of a dataframe to let viewers filter columns
Args:
df (pd.DataFrame): Original dataframe
Returns:
pd.DataFrame: Filtered dataframe
"""
df_ = df.copy()
# Try to convert datetimes into a standard format (datetime, no timezone)
to_filter_columns = st.multiselect("Filter dataframe on", df_.columns)
date_column = None
filtered_columns = []
for column in to_filter_columns:
left, right = st.columns((1, 20))
        # Treat columns with < 120 unique values as categorical if not date or numeric
        if is_categorical_dtype(df_[column]) or (df_[column].nunique() < 120 and not is_datetime64_any_dtype(df_[column]) and not is_numeric_dtype(df_[column])):
user_cat_input = right.multiselect(
f"Values for {column}",
df_[column].value_counts().index.tolist(),
default=list(df_[column].value_counts().index)
)
df_ = df_[df_[column].isin(user_cat_input)]
filtered_columns.append(column)
with st.status(f"Category Distribution: {column}", expanded=False) as stat:
st.pyplot(plot_treemap(df_, column))
elif is_numeric_dtype(df_[column]):
_min = float(df_[column].min())
_max = float(df_[column].max())
            step = (_max - _min) / 100 or 1.0  # avoid a zero step when the column is constant
user_num_input = right.slider(
f"Values for {column}",
min_value=_min,
max_value=_max,
value=(_min, _max),
step=step,
)
df_ = df_[df_[column].between(*user_num_input)]
filtered_columns.append(column)
            with st.status(f"Numerical Distribution: {column}", expanded=False) as stat_:
                # At most one bin per two distinct values, and never fewer than one
                st.pyplot(plot_hist(df_, column, bins=max(1, (df_[column].nunique() - 1) // 2)))
elif is_object_dtype(df_[column]):
try:
                df_[column] = pd.to_datetime(df_[column], errors='coerce')
except Exception:
try:
df_[column] = df_[column].apply(parser.parse)
except Exception:
pass
if is_datetime64_any_dtype(df_[column]):
df_[column] = df_[column].dt.tz_localize(None)
min_date = df_[column].min().date()
max_date = df_[column].max().date()
user_date_input = right.date_input(
f"Values for {column}",
value=(min_date, max_date),
min_value=min_date,
max_value=max_date,
)
if len(user_date_input) == 2:
user_date_input = tuple(map(pd.to_datetime, user_date_input))
start_date, end_date = user_date_input
# Determine the most appropriate time unit for plot
time_units = {
'year': df_[column].dt.year,
'month': df_[column].dt.to_period('M'),
'day': df_[column].dt.date
}
unique_counts = {unit: col.nunique() for unit, col in time_units.items()}
closest_to_36 = min(unique_counts, key=lambda k: abs(unique_counts[k] - 36))
# Group by the most appropriate time unit and count occurrences
grouped = df_.groupby(time_units[closest_to_36]).size().reset_index(name='count')
grouped.columns = [column, 'count']
# Create a complete date range
if closest_to_36 == 'year':
date_range = pd.date_range(start=f"{start_date.year}-01-01", end=f"{end_date.year}-12-31", freq='YS')
elif closest_to_36 == 'month':
date_range = pd.date_range(start=start_date.replace(day=1), end=end_date + pd.offsets.MonthEnd(0), freq='MS')
else: # day
date_range = pd.date_range(start=start_date, end=end_date, freq='D')
# Create a DataFrame with the complete date range
complete_range = pd.DataFrame({column: date_range})
# Convert the date column to the appropriate format based on closest_to_36
if closest_to_36 == 'year':
complete_range[column] = complete_range[column].dt.year
elif closest_to_36 == 'month':
complete_range[column] = complete_range[column].dt.to_period('M')
# Merge the complete range with the grouped data
final_data = pd.merge(complete_range, grouped, on=column, how='left').fillna(0)
with st.status(f"Date Distributions: {column}", expanded=False) as stat:
try:
st.pyplot(plot_bar(final_data, column, 'count'))
except Exception as e:
st.error(f"Error plotting bar chart: {e}")
df_ = df_.loc[df_[column].between(start_date, end_date)]
date_column = column
                    if date_column and filtered_columns:
                        numeric_columns = [col for col in filtered_columns if is_numeric_dtype(df_[col])]
                        categorical_columns = [col for col in filtered_columns if is_categorical_dtype(df_[col])]
                        with st.status(f"Date Distribution: {column}", expanded=False) as stat:
                            if numeric_columns:
                                try:
                                    st.pyplot(plot_line(df_, date_column, numeric_columns))
                                except Exception as e:
                                    st.error(f"Error plotting line chart: {e}")
                            if categorical_columns:
                                try:
                                    st.pyplot(plot_bar(df_, date_column, categorical_columns[0]))
                                except Exception as e:
                                    st.error(f"Error plotting bar chart: {e}")
else:
user_text_input = right.text_input(
f"Substring or regex in {column}",
)
if user_text_input:
                    df_ = df_[df_[column].astype(str).str.contains(user_text_input, na=False)]
# write len of df after filtering with % of original
st.write(f"{len(df_)} rows ({len(df_) / len(df) * 100:.2f}%)")
return df_
def get_stations():
    base_url = 'https://imag-data.bgs.ac.uk/GIN_V1/GINServices?Request=GetCapabilities&format=json'
    response = requests.get(base_url)
    response.raise_for_status()
    data = response.json()
dataframe_stations = pd.DataFrame.from_dict(data['ObservatoryList'])
return dataframe_stations
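# Quick usage sketch, assuming the GIN capabilities response keeps exposing
# 'ObservatoryList' with the IagaCode/Name/Latitude/Longitude fields used below:
#   stations = get_stations()
#   stations[['IagaCode', 'Name', 'Latitude', 'Longitude']].head()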
def get_haversine_distance(lat1, lon1, lat2, lon2):
    R = 6371  # mean Earth radius in km
dlat = math.radians(lat2 - lat1)
dlon = math.radians(lon2 - lon1)
a = math.sin(dlat/2) * math.sin(dlat/2) + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon/2) * math.sin(dlon/2)
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))
d = R * c
return d
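# Sanity-check sketch: London (51.5074, -0.1278) to Paris (48.8566, 2.3522)
# should come out at roughly 344 km:
#   get_haversine_distance(51.5074, -0.1278, 48.8566, 2.3522)  # ~343.6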
def compare_stations(test_lat_lon, data_table, distance=1000, closest=False):
table_updated = pd.DataFrame()
distances = dict()
for lat,lon,names in data_table[['Latitude', 'Longitude', 'Name']].values:
harv_distance = get_haversine_distance(test_lat_lon[0], test_lat_lon[1], lat, lon)
if harv_distance < distance:
#print(f"Station {names} is at {round(harv_distance,2)} km from the test point")
table_updated = pd.concat([table_updated, data_table[data_table['Name'] == names]])
distances[names] = harv_distance
    if closest and distances:
        closest_station = min(distances, key=distances.get)
        table_updated = data_table[data_table['Name'] == closest_station].copy()
        table_updated['Distance'] = distances[closest_station]
    return table_updated
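# Usage sketch (coordinates are hypothetical): nearest observatory within
# 1000 km of a point at lat 40.0, lon -105.0:
#   nearest = compare_stations((40.0, -105.0), get_stations(), distance=1000, closest=True)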
def get_data(IagaCode, start_date, end_date):
try:
start_date_ = datetime.datetime.strptime(start_date, '%Y-%m-%d')
except ValueError as e:
print(f"Error: {e}")
start_date_ = pd.to_datetime(start_date)
try:
end_date_ = datetime.datetime.strptime(end_date, '%Y-%m-%d')
except ValueError as e:
print(f"Error: {e}")
end_date_ = pd.to_datetime(end_date)
duration = end_date_ - start_date_
# Define the parameters for the request
    params = {
        'Request': 'GetData',
        'format': 'json',
        'testObsys': '0',
        'observatoryIagaCode': IagaCode,
        'samplesPerDay': 'minute',
        'publicationState': 'Best available',
        'dataStartDate': start_date,
        # duration in days (end - start)
        'dataDuration': duration.days,
        'traceList': '1234',
        'colourTraces': 'true',
        'pictureSize': 'Automatic',
        'dataScale': 'Automatic',
        'pdfSize': '21,29.7',
    }
    base_url_json = 'https://imag-data.bgs.ac.uk/GIN_V1/GINServices?Request=GetData&format=json'
    answer = None
    response = requests.get(base_url_json, params=params)
    if response.status_code == 200:
        content_type = response.headers.get('Content-Type', '')
        if 'json' in content_type:
            answer = response.json()
        else:
            print(f"Unexpected content type: {content_type}")
    else:
        print(f"Failed to retrieve data. HTTP Status code: {response.status_code}")
        print("Response content:")
        print(response.content.decode('utf-8'))
    return answer
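# Usage sketch (observatory code and dates are illustrative): one day of
# minute-resolution data from Hartland ('HAD'):
#   data = get_data('HAD', '2023-01-01', '2023-01-02')
#   # expect keys like 'datetime', 'X', 'Y', 'Z', 'S' when data is returned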
def clean_uap_data(dataset, lat, lon, date):
    # Keep only rows with usable coordinates and dates
    processed = dataset[dataset[[lat, lon, date]].notnull().all(axis=1)].copy()
    # Convert lat/lon columns to floats, coercing bad values to NaN
    processed[lat] = pd.to_numeric(processed[lat], errors='coerce')
    processed[lon] = pd.to_numeric(processed[lon], errors='coerce')
    # Drop dates before pandas' Timestamp.min (1677-09-22)
    processed = processed[processed[date] >= '1677-09-22']
    # Drop rows where lat/lon conversion failed (i.e., became NaN)
    processed = processed.dropna(subset=[lat, lon])
    return processed
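# Usage sketch (column names are hypothetical and depend on the uploaded file):
#   cases = clean_uap_data(df, 'Lat', 'Long', 'EventDate')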
def plot_overlapped_timeseries(data_list, event_times, window_hours=12, save_path=None):
fig, axs = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
fig.patch.set_alpha(0) # Make figure background transparent
components = ['X', 'Y', 'Z', 'S']
colors = ['red', 'green', 'blue', 'black']
for i, component in enumerate(components):
axs[i].patch.set_alpha(0) # Make subplot background transparent
axs[i].set_ylabel(component, color='orange')
axs[i].grid(True, color='orange', alpha=0.3)
for spine in axs[i].spines.values():
spine.set_color('orange')
axs[i].tick_params(axis='both', colors='orange') # Change tick color
        axs[i].set_title(component, color='orange')
for j, (df, event_time) in enumerate(zip(data_list, event_times)):
            # Drop any timezone info so datetimes and event times are comparable
            df['datetime'] = pd.to_datetime(df['datetime']).dt.tz_localize(None)
            event_time = pd.to_datetime(event_time).tz_localize(None)
# Calculate time difference from event
df['time_diff'] = (df['datetime'] - event_time).dt.total_seconds() / 3600 # Convert to hours
            # Filter data within the specified window (copy to avoid SettingWithCopyWarning)
            df_window = df[(df['time_diff'] >= -window_hours) & (df['time_diff'] <= window_hours)].copy()
            # Normalize the component (z-score) so events share a scale
            df_window[component] = (df_window[component] - df_window[component].mean()) / df_window[component].std()
axs[i].plot(df_window['time_diff'], df_window[component], color=colors[i], alpha=0.7, label=f'Event {j+1}', linewidth=1)
axs[i].axvline(x=0, color='red', linewidth=2, linestyle='--', label='Event Time')
axs[i].set_xlim(-window_hours, window_hours)
#axs[i].legend(loc='upper left', bbox_to_anchor=(1, 1))
axs[-1].set_xlabel('Hours from Event', color='orange')
fig.suptitle('Overlapped Time Series of Components', fontsize=16, color='orange')
plt.tight_layout()
plt.subplots_adjust(top=0.95, right=0.85)
if save_path:
fig.savefig(save_path, transparent=True, bbox_inches='tight')
plt.close(fig)
return save_path
else:
return fig
def plot_average_timeseries(data_list, event_times, window_hours=12, save_path=None):
fig, axs = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
fig.patch.set_alpha(0) # Make figure background transparent
components = ['X', 'Y', 'Z', 'S']
colors = ['red', 'green', 'blue', 'black']
for i, component in enumerate(components):
axs[i].patch.set_alpha(0)
axs[i].set_ylabel(component, color='orange')
axs[i].grid(True, color='orange', alpha=0.3)
for spine in axs[i].spines.values():
spine.set_color('orange')
axs[i].tick_params(axis='both', colors='orange')
all_data = []
time_diffs = []
for j, (df, event_time) in enumerate(zip(data_list, event_times)):
            # Drop any timezone info so datetimes and event times are comparable
            df['datetime'] = pd.to_datetime(df['datetime']).dt.tz_localize(None)
            event_time = pd.to_datetime(event_time).tz_localize(None)
# Calculate time difference from event
df['time_diff'] = (df['datetime'] - event_time).dt.total_seconds() / 3600 # Convert to hours
            # Filter data within the specified window (copy to avoid SettingWithCopyWarning)
            df_window = df[(df['time_diff'] >= -window_hours) & (df['time_diff'] <= window_hours)].copy()
            # Center the component data around zero
            df_window[component] = df_window[component] - df_window[component].mean()
            all_data.append(df_window[component].values)
            time_diffs.append(df_window['time_diff'].values)
        # Calculate average and standard deviation across events
        # (these fail if the event windows have unequal lengths, hence the guards)
        try:
            avg_data = np.mean(all_data, axis=0)
        except Exception:
            avg_data = np.zeros_like(all_data[0])
        try:
            std_data = np.std(all_data, axis=0)
        except Exception:
            std_data = np.zeros_like(avg_data)
        # Plot average line
        axs[i].plot(time_diffs[0], avg_data, color=colors[i], label='Average')
        # Plot standard deviation as shaded region
        try:
            axs[i].fill_between(time_diffs[0], avg_data - std_data, avg_data + std_data, color=colors[i], alpha=0.2)
        except Exception:
            pass
        axs[i].axvline(x=0, color='red', linewidth=2, linestyle='--', label='Event Time')
        axs[i].set_xlim(-window_hours, window_hours)
        axs[i].legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
    axs[-1].set_xlabel('Hours from Event', color='orange')
    fig.suptitle('Average Time Series of Components', fontsize=16, color='orange')
    plt.tight_layout()
    plt.subplots_adjust(top=0.95, right=0.85)
if save_path:
fig.savefig(save_path, transparent=True, bbox_inches='tight')
plt.close(fig)
return save_path
else:
return fig
def align_series(reference, series):
reference = reference.flatten()
series = series.flatten()
_, path = fastdtw(reference, series, dist=euclidean)
aligned = np.zeros(len(reference))
for ref_idx, series_idx in path:
aligned[ref_idx] = series[series_idx]
return aligned
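# Illustrative sketch: aligning a shifted copy of a series back onto the
# original (synthetic data, not from the app):
#   ref = np.sin(np.linspace(0, 6, 100))
#   aligned = align_series(ref, np.roll(ref, 5))  # tracks ref sample-by-sample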
def plot_average_timeseries_with_dtw(data_list, event_times, window_hours=12, save_path=None):
fig, axs = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
fig.patch.set_alpha(0) # Make figure background transparent
components = ['X', 'Y', 'Z', 'S']
colors = ['red', 'green', 'blue', 'black']
fig.text(0.02, 0.5, 'Geomagnetic Variation (nT)', va='center', rotation='vertical', color='orange')
for i, component in enumerate(components):
axs[i].patch.set_alpha(0)
axs[i].set_ylabel(component, color='orange', rotation=90)
axs[i].grid(True, color='orange', alpha=0.3)
for spine in axs[i].spines.values():
spine.set_color('orange')
axs[i].tick_params(axis='both', colors='orange')
all_aligned_data = []
reference_df = None
for j, (df, event_time) in enumerate(zip(data_list, event_times)):
df['datetime'] = pd.to_datetime(df['datetime']).dt.tz_localize(None)
event_time = pd.to_datetime(event_time).tz_localize(None)
df['time_diff'] = (df['datetime'] - event_time).dt.total_seconds() / 3600
            df_window = df[(df['time_diff'] >= -window_hours) & (df['time_diff'] <= window_hours)].copy()
            # Center the component around zero (std scaling left off, as in the average plot)
            df_window[component] = df_window[component] - df_window[component].mean()
if reference_df is None:
reference_df = df_window
all_aligned_data.append(reference_df[component].values)
else:
                try:
                    aligned_series = align_series(reference_df[component].values, df_window[component].values)
                    all_aligned_data.append(aligned_series)
                except Exception:
                    pass
        # Calculate average and standard deviation of aligned data
        all_aligned_data = np.array(all_aligned_data)
        avg_data = np.mean(all_aligned_data, axis=0)
        # Per-sample standard deviation; None when nothing could be aligned
        std_data = np.std(all_aligned_data, axis=0) if len(all_aligned_data) else None
# Plot average line
axs[i].plot(reference_df['time_diff'], avg_data, color=colors[i], label='Average')
# Plot standard deviation as shaded region
try:
axs[i].fill_between(reference_df['time_diff'], avg_data - std_data, avg_data + std_data, color=colors[i], alpha=0.2)
except TypeError as e:
#print(f"Error: {e}")
pass
axs[i].axvline(x=0, color='red', linewidth=2, linestyle='--', label='Event Time')
axs[i].set_xlim(-window_hours, window_hours)
axs[i].legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.2, labelcolor='orange', edgecolor='orange')
axs[-1].set_xlabel('Hours from Event', color='orange')
fig.suptitle('Average Time Series of Components (FastDTW Aligned)', fontsize=16, color='orange')
plt.tight_layout()
plt.subplots_adjust(top=0.85, right=0.85, left=0.1)
if save_path:
fig.savefig(save_path, transparent=True, bbox_inches='tight')
plt.close(fig)
return save_path
else:
return fig
def plot_data_custom(df, date, save_path=None, subtitle=None):
df['datetime'] = pd.to_datetime(df['datetime'])
event = pd.to_datetime(date)
window = timedelta(hours=12)
x_min = event - window
x_max = event + window
fig, axs = plt.subplots(4, 1, figsize=(12, 12), sharex=True)
fig.patch.set_alpha(0) # Make figure background transparent
components = ['X', 'Y', 'Z', 'S']
colors = ['red', 'green', 'blue', 'black']
fig.text(0.02, 0.5, 'Geomagnetic Variation (nT)', va='center', rotation='vertical', color='orange')
for i, component in enumerate(components):
axs[i].plot(df['datetime'], df[component], label=component, color=colors[i])
axs[i].axvline(x=event, color='red', linewidth=2, label='Event', linestyle='--')
axs[i].set_ylabel(component, color='orange', rotation=90)
axs[i].set_xlim(x_min, x_max)
axs[i].legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.2, labelcolor='orange', edgecolor='orange')
axs[i].grid(True, color='orange', alpha=0.3)
axs[i].patch.set_alpha(0) # Make subplot background transparent
for spine in axs[i].spines.values():
spine.set_color('orange')
axs[i].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
axs[i].xaxis.set_major_locator(mdates.HourLocator(interval=1))
axs[i].tick_params(axis='both', colors='orange')
plt.setp(axs[-1].xaxis.get_majorticklabels(), rotation=45)
axs[-1].set_xlabel('Hours', color='orange')
fig.suptitle(f'Time Series of Components with Event Marks\n{subtitle}', fontsize=12, color='orange')
plt.tight_layout()
#plt.subplots_adjust(top=0.85)
plt.subplots_adjust(top=0.85, right=0.85, left=0.1)
if save_path:
fig.savefig(save_path, transparent=True)
plt.close(fig)
return save_path
else:
return fig
def batch_requests(stations, dataset, lon, lat, date, distance=100):
results = {"station": [], "data": [], "image": [], "custom_image": []}
all_data = []
all_event_times = []
for lon_, lat_, date_ in dataset[[lon, lat, date]].values:
test_lat_lon = (lat_, lon_)
        try:
            str_date = pd.to_datetime(date_).strftime('%Y-%m-%dT%H:%M:%S')
        except Exception:
            str_date = date_
twelve_hours = pd.Timedelta(hours=12)
forty_eight_hours = pd.Timedelta(hours=48)
        try:
            str_date_start = (pd.to_datetime(str_date) - twelve_hours).strftime('%Y-%m-%dT%H:%M:%S')
            str_date_end = (pd.to_datetime(str_date) + forty_eight_hours).strftime('%Y-%m-%dT%H:%M:%S')
        except Exception as e:
            print(f"Error: {e}")
            continue  # cannot query without a valid time window
try:
new_dataset = compare_stations(test_lat_lon, stations, distance=distance, closest=True)
station_name = new_dataset['Name']
station_distance = new_dataset['Distance']
test_ = get_data(new_dataset.iloc[0]['IagaCode'], str_date_start, str_date_end)
if test_ and any(test_.get(key) for key in ['X', 'Y', 'Z', 'S']):
plotted = pd.DataFrame({
'datetime': test_['datetime'],
'X': test_.get('X', []),
'Y': test_.get('Y', []),
'Z': test_.get('Z', []),
'S': test_.get('S', []),
})
if plotted[['X', 'Y', 'Z', 'S']].any().any():
all_data.append(plotted)
all_event_times.append(pd.to_datetime(date_))
additional_data = f"Date: {date_}\nLat/Lon: {lat_}, {lon_}\nClosest station: {station_name.values[0]}\nDistance: {round(station_distance.values[0], 2)} km"
fig = plot_data_custom(plotted, date=pd.to_datetime(date_), save_path=None, subtitle=additional_data)
                    with st.status(f'Magnetic Data: {date_}', expanded=False) as status:
                        st.pyplot(fig)
                        status.update(label=f'Magnetic Data: {date_} - Finished!', state='complete')
else:
print(f"No data for X, Y, Z, or S for date: {date_}")
        except Exception:
            # Skip events with no nearby station or no retrievable data
            pass
    if all_data:
        fig_overlapped = plot_overlapped_timeseries(all_data, all_event_times)
        with st.status('Overlapped Time Series', expanded=False):
            st.pyplot(fig_overlapped)
        plt.close(fig_overlapped)
        fig_average_aligned = plot_average_timeseries_with_dtw(all_data, all_event_times)
        with st.status('Dynamic Time Warping Data', expanded=False):
            st.pyplot(fig_average_aligned)
    return results
df = pd.DataFrame()
# Upload dataset
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
if uploaded_file is not None:
if uploaded_file.name.endswith('.csv'):
df = pd.read_csv(uploaded_file)
else:
df = pd.read_excel(uploaded_file)
stations = get_stations()
st.write("Dataset Loaded:")
df = filter_dataframe(df)
st.dataframe(df)
# Select columns
with st.form(border=True, key='Select Columns for Analysis'):
lon_col = st.selectbox("Select Longitude Column", df.columns)
lat_col = st.selectbox("Select Latitude Column", df.columns)
date_col = st.selectbox("Select Date Column", df.columns)
        distance = st.number_input("Enter Distance (km)", min_value=0, value=100)
if st.form_submit_button("Process Data"):
cases = clean_uap_data(df, lat_col, lon_col, date_col)
results = batch_requests(stations, cases, lon_col, lat_col, date_col, distance=distance)