|
from io import BytesIO |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import plotly.express as px |
|
import streamlit as st |
|
from hilbertcurve.hilbertcurve import HilbertCurve |
|
from sklearn.cluster import KMeans |
|
|
|
|
|
def cluster_sites_hilbert_curve_same_size(
    df: pd.DataFrame,
    lat_col: str,
    lon_col: str,
    region_col: str,
    max_sites: int = 25,
    mix_regions: bool = False,
) -> pd.DataFrame:
    """Split sites into contiguous, near-equal-sized clusters along a Hilbert curve.

    Within each region (or across the whole frame when ``mix_regions`` is
    True), sites are ordered by their distance along a 2-D Hilbert curve over
    a quantised (lat, lon) grid. This ordering tends to keep nearby sites
    adjacent. The ordered sequence is then cut into ``ceil(n / max_sites)``
    consecutive chunks whose sizes differ by at most one.

    Args:
        df: Input sites. It is not modified.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        region_col: Name of the region column. Unused when ``mix_regions`` is
            True. Rows with a blank region form their own group.
        max_sites: Maximum number of sites per cluster. Must be >= 1.
        mix_regions: If True, cluster all sites together regardless of region.

    Returns:
        The rows of ``df`` grouped by cluster, with an added string column
        ``"Cluster"`` (``"C0"``, ``"C1"``, ...). IDs are unique across regions.

    Raises:
        ValueError: If ``max_sites`` is less than 1, or if a coordinate is
            missing or not convertible to float.
    """
    if max_sites < 1:
        raise ValueError(f"max_sites must be at least 1, got {max_sites!r}")

    if df.empty:
        # Nothing to cluster: return an empty frame with the expected schema.
        return df.iloc[:0].assign(Cluster=pd.Series(dtype=object))

    if mix_regions:
        grouped = [("All", df)]
    else:
        # dropna=False keeps sites whose region is blank. The default would
        # silently drop them from the result.
        grouped = df.groupby(region_col, dropna=False)

    bits = 16  # curve order: coordinates are quantised to a 2**16 x 2**16 grid
    grid_max = 2**bits - 1
    hilbert_curve = HilbertCurve(bits, 2)

    clusters = []
    cluster_id = 0
    for region, group in grouped:
        if group.empty:
            continue

        lat = group[lat_col].to_numpy(dtype=float)
        lon = group[lon_col].to_numpy(dtype=float)
        if np.isnan(lat).any() or np.isnan(lon).any():
            raise ValueError(
                f"Missing latitude/longitude values in region {region!r}"
            )

        # Scale both axes by the same factor so the grid keeps the region's
        # aspect ratio. Per-axis scaling would stretch the narrower extent and
        # distort which sites the curve treats as neighbours.
        lat_min, lon_min = lat.min(), lon.min()
        span = max(lat.max() - lat_min, lon.max() - lon_min)
        scale = grid_max / span if span > 0 else 0.0
        x = np.clip(((lat - lat_min) * scale).astype(np.int64), 0, grid_max)
        y = np.clip(((lon - lon_min) * scale).astype(np.int64), 0, grid_max)

        # Keep the quantised coordinates in local arrays instead of temporary
        # DataFrame columns, so user columns named "x"/"y"/"hilbert" survive.
        distances = [
            hilbert_curve.distance_from_point(point)
            for point in np.column_stack((x, y)).tolist()
        ]
        # Stable sort: sites in the same grid cell keep their input order.
        ordered = group.iloc[np.argsort(distances, kind="stable")]

        # Cut the curve into the fewest chunks that respect max_sites, with
        # sizes differing by at most one (no tiny leftover cluster at the end).
        n_chunks = -(-len(ordered) // max_sites)  # ceil division
        for positions in np.array_split(np.arange(len(ordered)), n_chunks):
            cluster = ordered.iloc[positions].copy()
            cluster["Cluster"] = f"C{cluster_id}"
            clusters.append(cluster)
            cluster_id += 1

    return pd.concat(clusters)
|
|
|
|
|
def cluster_sites_kmeans_lower_to_fixed_size(
    df: pd.DataFrame,
    lat_col: str,
    lon_col: str,
    region_col: str,
    max_sites: int = 25,
    mix_regions: bool = False,
) -> pd.DataFrame:
    """Cluster sites with repeated k-means so that no cluster exceeds ``max_sites``.

    Within each region (or over the whole frame when ``mix_regions`` is True),
    k-means (``random_state=42``, ``n_init=10``) is run on the (lat, lon)
    values with ``ceil(n / max_sites)`` clusters. Every non-empty k-means
    cluster with at most ``max_sites`` sites is accepted. The remaining sites
    are re-clustered in the next round. Once at most ``max_sites`` sites
    remain, they form the final cluster of that region. Sizes are bounded by
    ``max_sites`` but not equal.

    Args:
        df: Input sites. It is not modified.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        region_col: Name of the region column. Unused when ``mix_regions`` is
            True. Rows with a blank region form their own group.
        max_sites: Maximum number of sites per cluster. Must be >= 1.
        mix_regions: If True, cluster all sites together regardless of region.

    Returns:
        The rows of ``df`` grouped by cluster, with an added string column
        ``"Cluster"`` (``"C0"``, ``"C1"``, ...). IDs are unique across regions.

    Raises:
        ValueError: If ``max_sites`` is less than 1.
    """
    if max_sites < 1:
        raise ValueError(f"max_sites must be at least 1, got {max_sites!r}")

    if mix_regions:
        grouped = [("All", df)]
    else:
        # dropna=False keeps sites whose region is blank instead of dropping them.
        grouped = df.groupby(region_col, dropna=False)

    clusters = []
    for _, group in grouped:
        remaining = group
        while len(remaining) > max_sites:
            n_clusters = -(-len(remaining) // max_sites)  # ceil division
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            labels = kmeans.fit_predict(remaining[[lat_col, lon_col]].to_numpy())

            # Work with row positions rather than index labels, which may repeat.
            members_by_label = [
                np.flatnonzero(labels == label) for label in range(n_clusters)
            ]
            accepted = [m for m in members_by_label if 0 < len(m) <= max_sites]
            if not accepted:
                # With n_clusters = ceil(n / max_sites), some non-empty cluster
                # must fit unless k-means left clusters empty (e.g. many sites
                # share the same coordinates). Re-running would then never make
                # progress, so split each oversized cluster into near-equal
                # chunks instead.
                for members in members_by_label:
                    if len(members) > 0:
                        n_parts = -(-len(members) // max_sites)
                        accepted.extend(np.array_split(members, n_parts))

            clusters.extend(remaining.iloc[positions] for positions in accepted)
            keep = np.ones(len(remaining), dtype=bool)
            keep[np.concatenate(accepted)] = False
            remaining = remaining.iloc[np.flatnonzero(keep)]

        if len(remaining) > 0:
            # At most max_sites sites are left: they form the last cluster.
            clusters.append(remaining)

    if not clusters:
        # Empty input: return an empty frame with the expected schema.
        return df.iloc[:0].assign(Cluster=pd.Series(dtype=object))

    # IDs are assigned in the order clusters were accepted.
    return pd.concat(
        [
            cluster.assign(Cluster=f"C{cluster_id}")
            for cluster_id, cluster in enumerate(clusters)
        ]
    )
|
|
|
|
|
def to_excel(df: pd.DataFrame) -> bytes:
    """Serialise ``df`` to an in-memory .xlsx workbook and return its bytes.

    The frame is written without its index to a single sheet named
    ``"Clusters"`` using the xlsxwriter engine.
    """
    with BytesIO() as buffer:
        # The writer must be closed (end of the inner block) before the
        # buffer holds the complete workbook.
        with pd.ExcelWriter(buffer, engine="xlsxwriter") as writer:
            df.to_excel(writer, sheet_name="Clusters", index=False)
        return buffer.getvalue()
|
|
|
|
|
st.title("Automatic Site Clustering App")

st.write(
    """This app allows you to cluster sites based on their latitude and longitude.

**Please choose a file containing the latitude, longitude, region and site code columns.**
"""
)

clustering_sample_file_path = "samples/Site_Clustering.xlsx"

# Read the sample inside a context manager so the file handle is closed, and
# keep the app usable if the sample is missing from the deployment.
try:
    with open(clustering_sample_file_path, "rb") as sample_file:
        sample_bytes = sample_file.read()
except FileNotFoundError:
    sample_bytes = None

if sample_bytes is None:
    st.warning(f"Sample file not found: {clustering_sample_file_path}")
else:
    st.download_button(
        label="Download Clustering Sample File",
        data=sample_bytes,
        file_name="Site_Clustering.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )

uploaded_file = st.file_uploader("Upload your Excel file ", type=["xlsx"])

if uploaded_file is not None:
    df = pd.read_excel(uploaded_file)
    st.write("Sample of uploaded data:", df.head())

    columns = df.columns.tolist()

    with st.form("clustering_form"):
        lat_col = st.selectbox("Select Latitude column", columns)
        lon_col = st.selectbox("Select Longitude column", columns)
        region_col = st.selectbox("Select Region column", columns)
        code_col = st.selectbox("Select Site Code column", columns)
        max_sites = st.number_input(
            "Max sites per cluster", min_value=5, max_value=100, value=25
        )
        cluster_method = st.selectbox(
            "Select clustering method",
            [
                "Uniform number of sites for each cluster",
                "Number of sites Lower than max but not uniform",
            ],
        )
        mix_regions = st.checkbox(
            "Allow mixing different regions in clusters", value=False
        )
        submitted = st.form_submit_button("Run Clustering")

    if submitted:
        # Both clustering functions need numeric, non-missing coordinates.
        # Report bad rows here instead of failing later with a raw traceback.
        lat_values = pd.to_numeric(df[lat_col], errors="coerce")
        lon_values = pd.to_numeric(df[lon_col], errors="coerce")
        n_invalid = int((lat_values.isna() | lon_values.isna()).sum())
        if n_invalid:
            st.error(
                f"{n_invalid} row(s) have a missing or non-numeric value in "
                f"'{lat_col}' or '{lon_col}'. "
                "Please fix them and upload the file again."
            )
            st.stop()
        df[lat_col] = lat_values
        df[lon_col] = lon_values

        # The selectbox offers exactly two options, so `else` covers the second
        # and guarantees clustered_df is always bound.
        if cluster_method == "Uniform number of sites for each cluster":
            clustered_df = cluster_sites_hilbert_curve_same_size(
                df, lat_col, lon_col, region_col, max_sites, mix_regions
            )
        else:
            clustered_df = cluster_sites_kmeans_lower_to_fixed_size(
                df, lat_col, lon_col, region_col, max_sites, mix_regions
            )
        st.success("Clustering completed!")

        # Cluster IDs look like "C0", "C1", ...; order them numerically so
        # "C10" is not plotted between "C1" and "C2".
        cluster_order = sorted(
            clustered_df["Cluster"].unique(), key=lambda cid: int(cid[1:])
        )

        cluster_size = clustered_df["Cluster"].value_counts().reindex(cluster_order)
        fig = px.bar(cluster_size, x=cluster_size.index, y=cluster_size.values)
        fig.update_layout(title="Cluster Size")
        st.plotly_chart(fig)

        # dropna=False keeps sites with a blank region in the chart.
        cluster_size_per_region = (
            clustered_df.groupby([region_col, "Cluster"], dropna=False)
            .size()
            .reset_index(name="count")
        )
        fig = px.bar(
            cluster_size_per_region,
            x="Cluster",
            y="count",
            color=region_col,
            category_orders={"Cluster": cluster_order},
        )
        fig.update_layout(title="Cluster Size per Region")
        st.plotly_chart(fig)

        # scatter_map draws on the MapLibre `map` subplot, so the basemap is set
        # with map_style (mapbox_style only targets the separate mapbox subplot).
        # Marker size is set on the traces instead of through a helper column,
        # so no extra "size" column leaks into the hover text or the export.
        fig = px.scatter_map(
            clustered_df,
            lat=lat_col,
            lon=lon_col,
            color="Cluster",
            hover_name=code_col,
            hover_data=[region_col],
            category_orders={"Cluster": cluster_order},
            zoom=5,
            height=600,
            map_style="open-street-map",
        )
        fig.update_traces(marker=dict(size=15))
        st.plotly_chart(fig)

        st.download_button(
            label="Download clustered Excel file",
            data=to_excel(clustered_df),
            file_name="clustered_sites.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            on_click="ignore",
            type="primary",
        )
|
|