import gradio as gr import pandas as pd from datasets import load_dataset, get_dataset_split_names from huggingface_hub import HfApi import os import pathlib import uuid import logging import threading import time import socket import uvicorn # --- Setup Logging --- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # --- Embedding Atlas Imports --- from embedding_atlas.data_source import DataSource from embedding_atlas.server import make_server from embedding_atlas.projection import compute_text_projection from embedding_atlas.utils import Hasher # --- Helper functions --- def find_column_name(existing_names, candidate): if candidate not in existing_names: return candidate index = 1 while True: s = f"{candidate}_{index}" if s not in existing_names: return s index += 1 def find_available_port(start_port: int, max_attempts: int = 100): """Finds an available TCP port on the host.""" for port in range(start_port, start_port + max_attempts): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: if s.connect_ex(('127.0.0.1', port)) != 0: logging.info(f"Found available port: {port}") return port raise RuntimeError("Could not find an available port.") def run_atlas_server(app, port): """Target function for the background thread to run the Uvicorn server.""" logging.info(f"Starting Atlas server on http://127.0.0.1:{port}") uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning") # --- Hugging Face API Helpers --- hf_api = HfApi() def get_user_datasets(username: str): logging.info(f"Fetching datasets for user: {username}") if not username: return gr.update(choices=[], value=None, interactive=False) try: datasets = hf_api.list_datasets(author=username, full=True) dataset_ids = [d.id for d in datasets if not d.private] logging.info(f"Found {len(dataset_ids)} datasets.") return gr.update(choices=sorted(dataset_ids), value=None, interactive=True) except Exception as e: logging.error(f"Failed to fetch datasets: {e}") return gr.update(choices=[], value=None, interactive=False) def get_dataset_splits(dataset_id: str): logging.info(f"Fetching splits for: {dataset_id}") if not dataset_id: return gr.update(choices=[], value=None, interactive=False) try: splits = get_dataset_split_names(dataset_id) logging.info(f"Found splits: {splits}") return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True) except Exception as e: logging.error(f"Failed to fetch splits: {e}") return gr.update(choices=[], value=None, interactive=False) def get_split_columns(dataset_id: str, split: str): logging.info(f"Fetching columns for: {dataset_id}/{split}") if not dataset_id or not split: return gr.update(choices=[], value=None, interactive=False) try: dataset_sample = load_dataset(dataset_id, split=split, streaming=True) first_row = next(iter(dataset_sample)) columns = list(first_row.keys()) logging.info(f"Found columns: {columns}") preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt'] best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None) return gr.update(choices=columns, value=best_col, interactive=True) except Exception as e: logging.error(f"Failed to get columns: {e}", exc_info=True) return gr.update(choices=[], value=None, interactive=False) # --- Main Atlas Generation Logic --- def generate_atlas( dataset_name: str, split: str, text_column: str, sample_size: int, model_name: str, umap_neighbors: int, umap_min_dist: float, progress=gr.Progress(track_tqdm=True) ): if not all([dataset_name, split, text_column]): raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.") progress(0, desc="Loading dataset...") df = load_dataset(dataset_name, split=split).to_pandas() if sample_size > 0 and sample_size < len(df): df = df.sample(n=sample_size, random_state=42).reset_index(drop=True) progress(0.2, desc="Computing embeddings and UMAP...") x_col = find_column_name(df.columns, "projection_x") y_col = find_column_name(df.columns, "projection_y") neighbors_col = find_column_name(df.columns, "__neighbors") compute_text_projection( df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name, umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42}, ) progress(0.8, desc="Preparing Atlas data source...") id_col = find_column_name(df.columns, "_row_index") df[id_col] = range(df.shape[0]) metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}} hasher = Hasher() hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}-{uuid.uuid4()}") identifier = hasher.hexdigest() atlas_dataset = DataSource(identifier, df, metadata) progress(0.9, desc="Starting Atlas server...") static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve()) atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm") # Find an open port and run the server in a background thread port = find_available_port(start_port=8001) thread = threading.Thread(target=run_atlas_server, args=(atlas_app, port), daemon=True) thread.start() # Give the server a moment to start up time.sleep(2) iframe_html = f"" return gr.HTML(iframe_html) # --- Gradio UI Definition --- with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app: # UI elements... gr.Markdown("# Embedding Atlas Explorer") # ... (rest of the UI is the same as before) ... with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 1. Select Data") hf_user_input = gr.Textbox(label="Hugging Face User or Org Name", value="Trendyol", placeholder="e.g., 'gradio' or 'google'") dataset_input = gr.Dropdown(label="Select a Dataset", interactive=False) split_input = gr.Dropdown(label="Select a Split", interactive=False) text_column_input = gr.Dropdown(label="Select a Text Column", interactive=False) gr.Markdown("### 2. Configure Visualization") sample_size_input = gr.Slider(label="Number of Samples", minimum=0, maximum=10000, value=2000, step=100) with gr.Accordion("Advanced Settings", open=False): model_input = gr.Dropdown(label="Embedding Model", choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"], value="all-MiniLM-L6-v2") umap_neighbors_input = gr.Slider(label="UMAP Neighbors", minimum=2, maximum=100, value=15, step=1, info="Controls local vs. global structure.") umap_min_dist_input = gr.Slider(label="UMAP Min Distance", minimum=0.0, maximum=0.99, value=0.1, step=0.01, info="Controls how tightly points are packed.") generate_button = gr.Button("Generate Atlas", variant="primary") with gr.Column(scale=3): gr.Markdown("### 3. Explore Atlas") output_html = gr.HTML("
Atlas will be displayed here after generation.