# Scraped from a Hugging Face Spaces page (Space status: Sleeping).
import gradio as gr | |
import pandas as pd | |
from datasets import load_dataset, get_dataset_split_names | |
from huggingface_hub import HfApi | |
import os | |
import pathlib | |
import uuid | |
import logging | |
import threading | |
import time | |
import socket | |
import uvicorn | |
# --- Setup Logging ---
# Configure the root logger at import time: INFO and above, formatted as
# "<timestamp> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# --- Embedding Atlas Imports --- | |
from embedding_atlas.data_source import DataSource | |
from embedding_atlas.server import make_server | |
from embedding_atlas.projection import compute_text_projection | |
from embedding_atlas.utils import Hasher | |
# --- Helper functions --- | |
def find_column_name(existing_names, candidate):
    """Return a column name derived from *candidate* that is not in *existing_names*.

    *candidate* itself is returned when it is unused; otherwise the first free
    name among ``candidate_1``, ``candidate_2``, ... is returned.

    Args:
        existing_names: Any container supporting ``in`` (e.g. ``df.columns``).
        candidate: The preferred column name.

    Returns:
        A name that does not occur in *existing_names*.
    """
    if candidate in existing_names:
        suffix = 1
        while f"{candidate}_{suffix}" in existing_names:
            suffix += 1
        return f"{candidate}_{suffix}"
    return candidate
def find_available_port(start_port: int, max_attempts: int = 100, host: str = "127.0.0.1") -> int:
    """Find a TCP port on *host* that can currently be bound.

    Ports ``start_port`` .. ``start_port + max_attempts - 1`` (capped at 65535)
    are probed in order by binding a throw-away socket. Binding is used instead
    of ``connect_ex`` because a port may be held by a socket that is bound but
    not listening: a connect attempt is refused (so the port *looks* free), yet
    a server could not bind there.

    The result is only a snapshot; another process may claim the port before
    the caller binds it.

    Args:
        start_port: First port number to try.
        max_attempts: Maximum number of consecutive ports to try.
        host: Interface address to probe. Defaults to the loopback address the
            Atlas server is started on.

    Returns:
        The first port in the range that could be bound.

    Raises:
        RuntimeError: If no port in the probed range could be bound.
    """
    # Never probe past the highest valid TCP port (bind would raise OverflowError).
    end_port = min(start_port + max_attempts, 65536)
    for port in range(start_port, end_port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind((host, port))
            except OSError:
                # Already in use (or otherwise unbindable): try the next one.
                continue
        # The probe socket is closed on leaving the `with`, releasing the port.
        logging.info(f"Found available port: {port}")
        return port
    raise RuntimeError("Could not find an available port.")
def run_atlas_server(app, port):
    """Serve the Atlas ASGI *app* with Uvicorn on the loopback interface at *port*.

    Used as the target of a background (daemon) thread by the atlas generator.
    Uvicorn's own log output is limited to warnings and above.
    """
    host = "127.0.0.1"
    logging.info(f"Starting Atlas server on http://{host}:{port}")
    uvicorn.run(app, log_level="warning", host=host, port=port)
# --- Hugging Face API Helpers --- | |
# Shared Hugging Face Hub client, used below to list a user's/org's datasets.
hf_api: HfApi = HfApi()
def get_user_datasets(username: str):
    """List the public datasets owned by a Hugging Face user or organization.

    Args:
        username: Hub user/org name as typed into the UI. Surrounding
            whitespace is ignored; ``None`` or a blank value yields an empty
            dropdown.

    Returns:
        A ``gr.update`` for the dataset dropdown. On success it holds the
        sorted ids of the non-private datasets, with nothing selected. When
        the name is blank or the Hub request fails, it is an empty,
        non-interactive dropdown.
    """
    logging.info(f"Fetching datasets for user: {username}")
    # Drop stray whitespace (e.g. from pasting) so it is not sent as part of
    # the author name.
    username = (username or "").strip()
    if not username:
        return gr.update(choices=[], value=None, interactive=False)
    try:
        datasets = hf_api.list_datasets(author=username, full=True)
        dataset_ids = [d.id for d in datasets if not d.private]
        logging.info(f"Found {len(dataset_ids)} datasets.")
        return gr.update(choices=sorted(dataset_ids), value=None, interactive=True)
    except Exception as e:
        # Best-effort UI callback: record the traceback and disable the
        # dropdown instead of surfacing an error to the user.
        logging.error(f"Failed to fetch datasets: {e}", exc_info=True)
        return gr.update(choices=[], value=None, interactive=False)
def get_dataset_splits(dataset_id: str):
    """Build the split-dropdown update for *dataset_id*.

    Returns a ``gr.update`` listing the dataset's split names with the first
    one pre-selected. A missing id, or any error while looking up the splits
    (which is logged), yields an empty, non-interactive dropdown instead.
    """
    logging.info(f"Fetching splits for: {dataset_id}")
    if dataset_id:
        try:
            split_names = get_dataset_split_names(dataset_id)
            logging.info(f"Found splits: {split_names}")
            default_split = split_names[0] if split_names else None
            return gr.update(choices=split_names, value=default_split, interactive=True)
        except Exception as err:
            logging.error(f"Failed to fetch splits: {err}")
    return gr.update(choices=[], value=None, interactive=False)
def _pick_text_column(first_row):
    """Guess which column of a sample row (mapping column -> value) holds the text; None if there are no columns."""
    columns = list(first_row.keys())
    if not columns:
        return None
    preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt']
    # 1) Exact match on a well-known text column name, in preference order.
    for name in preferred_cols:
        if name in columns:
            return name
    # 2) Same names, ignoring case (e.g. "Text", "Question").
    for name in preferred_cols:
        for col in columns:
            if isinstance(col, str) and col.lower() == name:
                return col
    # 3) First column whose sample value is a string. A non-string column
    #    (ids, labels, numbers) is unlikely to be usable as the text to embed.
    #    Fall back to the first column only if no string column exists.
    return next((col for col in columns if isinstance(first_row[col], str)), columns[0])


def get_split_columns(dataset_id: str, split: str):
    """Build the text-column dropdown update for ``dataset_id``/``split``.

    Only the first row is read, via a streaming load, to discover the column
    names, so the whole split is not downloaded. The pre-selected column is
    chosen in this order:
      1. a well-known text column name (exact, then case-insensitive);
      2. the first string-valued column;
      3. the first column.

    Returns:
        A ``gr.update`` listing the columns with the guessed text column
        selected. If an id/split is missing, or anything fails (including an
        empty split), the error is logged and the dropdown is emptied and
        disabled.
    """
    logging.info(f"Fetching columns for: {dataset_id}/{split}")
    if not dataset_id or not split:
        return gr.update(choices=[], value=None, interactive=False)
    try:
        dataset_sample = load_dataset(dataset_id, split=split, streaming=True)
        # Raises StopIteration for an empty split; handled by the except below.
        first_row = next(iter(dataset_sample))
        columns = list(first_row.keys())
        logging.info(f"Found columns: {columns}")
        best_col = _pick_text_column(first_row)
        return gr.update(choices=columns, value=best_col, interactive=True)
    except Exception as e:
        logging.error(f"Failed to get columns: {e}", exc_info=True)
        return gr.update(choices=[], value=None, interactive=False)
# --- Main Atlas Generation Logic --- | |
def _load_sampled_dataframe(dataset_name, split, sample_size):
    """Load a dataset split as a DataFrame, down-sampled to ``sample_size`` rows when 0 < sample_size < row count."""
    try:
        df = load_dataset(dataset_name, split=split).to_pandas()
    except Exception as e:
        logging.error(f"Failed to load dataset: {e}", exc_info=True)
        raise gr.Error(f"Could not load '{dataset_name}' ({split}): {e}") from e
    if df.empty:
        raise gr.Error(f"The split '{split}' of '{dataset_name}' contains no rows.")
    if 0 < sample_size < len(df):
        # Fixed random_state keeps the sample reproducible across runs.
        df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
    return df


def _wait_for_server(port, host="127.0.0.1", timeout=10.0, interval=0.1):
    """Poll until a TCP connection to host:port succeeds; return False if ``timeout`` seconds pass first."""
    deadline = time.monotonic() + timeout
    while True:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(1.0)
            if s.connect_ex((host, port)) == 0:
                return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


def generate_atlas(
    dataset_name: str,
    split: str,
    text_column: str,
    sample_size: int,
    model_name: str,
    umap_neighbors: int,
    umap_min_dist: float,
    progress=gr.Progress(track_tqdm=True)
):
    """Build an Embedding Atlas view for a Hub dataset split and serve it locally.

    Steps:
      1. Load the split, optionally down-sampled.
      2. Compute embeddings plus a 2-D UMAP projection and neighbour columns
         with ``compute_text_projection``.
      3. Wrap the result in an Atlas ``DataSource``.
      4. Start an Atlas server for it in a daemon thread on a free port
         (searched from 8001 upwards).

    Args:
        dataset_name: Hub dataset id.
        split: Split name to load.
        text_column: Column whose text is embedded.
        sample_size: Number of rows to sample; 0 (or >= the row count) keeps
            every row.
        model_name: Embedding model name passed to ``compute_text_projection``.
        umap_neighbors: UMAP ``n_neighbors``.
        umap_min_dist: UMAP ``min_dist``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A ``gr.HTML`` containing an iframe that points at the Atlas server.

    Raises:
        gr.Error: If a required selection is missing, the dataset cannot be
            loaded or is empty, the text column is absent, or the server
            thread dies before accepting connections.
    """
    if not all([dataset_name, split, text_column]):
        raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.")
    # Gradio sliders may deliver floats; pandas `sample(n=...)` needs an int,
    # and the UMAP neighbour setting is a count.
    sample_size = int(sample_size or 0)
    umap_neighbors = int(umap_neighbors)

    progress(0, desc="Loading dataset...")
    df = _load_sampled_dataframe(dataset_name, split, sample_size)
    if text_column not in df.columns:
        raise gr.Error(f"Column '{text_column}' was not found in the loaded data.")

    progress(0.2, desc="Computing embeddings and UMAP...")
    # Pick output column names that cannot collide with existing dataset columns.
    x_col = find_column_name(df.columns, "projection_x")
    y_col = find_column_name(df.columns, "projection_y")
    neighbors_col = find_column_name(df.columns, "__neighbors")
    compute_text_projection(
        df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name,
        umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
    )

    progress(0.8, desc="Preparing Atlas data source...")
    id_col = find_column_name(df.columns, "_row_index")
    df[id_col] = range(df.shape[0])
    metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}}
    # The uuid makes every generation's identifier unique, even for identical settings.
    hasher = Hasher()
    hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}-{uuid.uuid4()}")
    identifier = hasher.hexdigest()
    atlas_dataset = DataSource(identifier, df, metadata)

    progress(0.9, desc="Starting Atlas server...")
    static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve())
    atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")
    # NOTE(review): every call starts a new server thread that is never
    # stopped, so each generation keeps its port and DataFrame alive for the
    # life of the process. Consider tracking and shutting down the previous one.
    port = find_available_port(start_port=8001)
    thread = threading.Thread(target=run_atlas_server, args=(atlas_app, port), daemon=True)
    thread.start()

    # Wait until the server accepts connections instead of sleeping a fixed time.
    if not _wait_for_server(port):
        if not thread.is_alive():
            raise gr.Error(f"The Atlas server failed to start on port {port}.")
        logging.warning(f"Atlas server on port {port} did not respond within the timeout.")

    # NOTE(review): the iframe URL is resolved by the viewer's browser, so
    # 127.0.0.1 only works when the browser runs on the same host as this app.
    # A remote deployment presumably needs the Atlas app proxied or mounted
    # under the Gradio server. TODO confirm.
    iframe_html = f"<iframe src='http://127.0.0.1:{port}' width='100%' height='800px' frameborder='0'></iframe>"
    return gr.HTML(iframe_html)
# --- Gradio UI Definition --- | |
with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
    gr.Markdown("# Embedding Atlas Explorer")

    with gr.Row():
        # Left column: data selection and visualization settings.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Data")
            hf_user_input = gr.Textbox(label="Hugging Face User or Org Name", value="Trendyol", placeholder="e.g., 'gradio' or 'google'")
            dataset_input = gr.Dropdown(label="Select a Dataset", interactive=False)
            split_input = gr.Dropdown(label="Select a Split", interactive=False)
            text_column_input = gr.Dropdown(label="Select a Text Column", interactive=False)

            gr.Markdown("### 2. Configure Visualization")
            # generate_atlas only down-samples when the value is > 0, so 0
            # means "load the whole split"; tell the user.
            sample_size_input = gr.Slider(
                label="Number of Samples", minimum=0, maximum=10000, value=2000, step=100,
                info="0 = use every row in the split.",
            )
            with gr.Accordion("Advanced Settings", open=False):
                model_input = gr.Dropdown(label="Embedding Model", choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"], value="all-MiniLM-L6-v2")
                umap_neighbors_input = gr.Slider(label="UMAP Neighbors", minimum=2, maximum=100, value=15, step=1, info="Controls local vs. global structure.")
                umap_min_dist_input = gr.Slider(label="UMAP Min Distance", minimum=0.0, maximum=0.99, value=0.1, step=0.01, info="Controls how tightly points are packed.")
            generate_button = gr.Button("Generate Atlas", variant="primary")

        # Right column: the generated atlas (placeholder until generation).
        with gr.Column(scale=3):
            gr.Markdown("### 3. Explore Atlas")
            output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")

    # --- Event Listeners ---
    # Cascade: user/org name -> dataset list -> split list -> text-column list.
    hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
    dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
    split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
    generate_button.click(
        fn=generate_atlas,
        inputs=[dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input],
        outputs=[output_html],
    )
    # Populate the dataset list for the default user name when the page loads.
    app.load(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
if __name__ == "__main__":
    # Launch the Gradio UI only when run as a script, not when imported.
    app.launch(debug=True)