import io
import json
import os
import re
import time
from difflib import get_close_matches
from functools import partial
from itertools import islice
from multiprocessing.pool import ThreadPool
from pathlib import Path
from queue import Queue, Empty
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, TypeVar, Union

import gradio as gr
import pandas as pd
import requests
from huggingface_hub import InferenceClient, create_repo, DatasetCard
from huggingface_hub.utils import HfHubHTTPError

# --- Configuration ---
model_id = "microsoft/Phi-3-mini-4k-instruct"
client = InferenceClient(model_id)
save_dataset_hf_token = os.environ.get("SAVE_DATASET_HF_TOKEN")
MAX_TOTAL_NB_ITEMS = 100
MAX_NB_ITEMS_PER_GENERATION_CALL = 10
NUM_ROWS = 100
NUM_VARIANTS = 10
NAMESPACE = "infinite-dataset-hub"
URL = "https://huggingface.co/spaces/infinite-dataset-hub/infinite-dataset-hub"

# --- Prompt Templates ---
GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
    "A Machine Learning Practitioner is looking for a dataset that matches '{search_query}'. "
    f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality datasets that don't exist but sound plausible and would "
    "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
    "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated with "
    "the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n"
    "1. DatasetName1 (tag1, tag2, tag3)\n2. DatasetName2 (tag1, tag2, tag3)"
)

GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
    "An ML practitioner is looking for a dataset CSV after the query '{search_query}'. "
    "Generate the first 5 rows of a plausible and quality CSV for the dataset '{dataset_name}'. "
    "You can get inspiration from related keywords '{tags}' but most importantly the dataset should correspond to the query '{search_query}'. "
    "Focus on quality text content and use a 'label' or 'labels' column if it makes sense (invent labels, avoid reusing the keywords, be accurate while labelling texts). "
    "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
)

GENERATE_MORE_ROWS = "Can you give me 10 additional samples in CSV format as well? Use the same CSV header '{csv_header}'."
GENERATE_VARIANTS_WITH_RARITY_AND_LABEL = "Focus on generating samples for the label '{label}' and ideally generate {rarity} samples."
GENERATE_VARIANTS_WITH_RARITY = "Focus on generating {rarity} samples."

# --- Default Datasets for Landing Page ---
landing_page_datasets_generated_text = """
1. NewsEventsPredict (classification, media, trend)
2. FinancialForecast (economy, stocks, regression)
3. HealthMonitor (science, real-time, anomaly detection)
4. SportsAnalysis (classification, performance, player tracking)
5. SciLiteracyTools (language modeling, science literacy, text classification)
6. RetailSalesAnalyzer (consumer behavior, sales trend, segmentation)
7. SocialSentimentEcho (social media, emotion analysis, clustering)
8. NewsEventTracker (classification, public awareness, topical clustering)
9. HealthVitalSigns (anomaly detection, biometrics, prediction)
10. GameStockPredict (classification, finance, sports contingency)
"""
default_output = landing_page_datasets_generated_text.strip().split("\n")
assert len(default_output) == MAX_NB_ITEMS_PER_GENERATION_CALL

# --- Dataset Card Template ---
DATASET_CARD_CONTENT = """
---
license: mit
tags:
- infinite-dataset-hub
- synthetic
---

{title}

_Note: This is an AI-generated dataset so its content may be inaccurate or false_

{content}

**Source of the data:**

The dataset was generated using the [Infinite Dataset Hub]({url}) and {model_id} using the query '{search_query}':

- **Dataset Generation Page**: {dataset_url}
- **Model**: https://huggingface.co/{model_id}
- **More Datasets**: https://huggingface.co/datasets?other=infinite-dataset-hub
"""

# --- Gradio HTML ---
html = """
Infinite Dataset Hub

πŸ€— Infinite Dataset Hub ♾️

Generate datasets from AI and real-world data sources

Using: Real + AI Data

Toggle to switch between data sources

""" # --- Gradio CSS --- css = """ a { color: var(--body-text-color); } .datasetButton { justify-content: start; justify-content: left; } .tags { font-size: var(--button-small-text-size); color: var(--body-text-color-subdued); } .topButton { justify-content: start; justify-content: left; text-align: left; background: transparent; box-shadow: none; padding-bottom: 0; } .topButton::before { content: url("data:image/svg+xml,%3Csvg style='color: rgb(209 213 219)' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink' aria-hidden='true' focusable='false' role='img' width='1em' height='1em' preserveAspectRatio='xMidYMid meet' viewBox='0 0 25 25'%3E%3Cellipse cx='12.5' cy='5' fill='currentColor' fill-opacity='0.25' rx='7.5' ry='2'%3E%3C/ellipse%3E%3Cpath d='M12.5 15C16.6421 15 20 14.1046 20 13V20C20 21.1046 16.6421 22 12.5 22C8.35786 22 5 21.1046 5 20V13C5 14.1046 8.35786 15 12.5 15Z' fill='currentColor' opacity='0.5'%3E%3C/path%3E%3Cpath d='M12.5 7C16.6421 7 20 6.10457 20 5V11.5C20 12.6046 16.6421 13.5 12.5 13.5C8.35786 13.5 5 12.6046 5 11.5V5C5 6.10457 8.35786 7 12.5 7Z' fill='currentColor' opacity='0.5'%3E%3C/path%3E%3Cpath d='M5.23628 12C5.08204 12.1598 5 12.8273 5 13C5 14.1046 8.35786 15 12.5 15C16.6421 15 20 14.1046 20 13C20 12.8273 19.918 12.1598 19.7637 12C18.9311 12.8626 15.9947 13.5 12.5 13.5C9.0053 13.5 6.06886 12.8626 5.23628 12Z' fill='currentColor'%3E%3C/path%3E%3C/svg%3E"); margin-right: .25rem; margin-left: -.125rem; margin-top: .25rem; } .bottomButton { justify-content: start; justify-content: left; text-align: left; background: transparent; box-shadow: none; font-size: var(--button-small-text-size); color: var(--body-text-color-subdued); padding-top: 0; align-items: baseline; } .bottomButton::before { content: 'tags:'; margin-right: .25rem; } .buttonsGroup { background: transparent; } .buttonsGroup:hover { background: var(--input-background-fill); } .buttonsGroup div { background: transparent; } .insivibleButtonGroup { display: none; } @keyframes placeHolderShimmer { 0%{ background-position: -468px 0 } 100%{ background-position: 468px 0 } } .linear-background { animation-duration: 1s; animation-fill-mode: forwards; animation-iteration-count: infinite; animation-name: placeHolderShimmer; animation-timing-function: linear; background-image: linear-gradient(to right, var(--body-text-color-subdued) 8%, #dddddd11 18%, var(--body-text-color-subdued) 33%); background-size: 1000px 104px; color: transparent; background-clip: text; } .settings { background: transparent; } .settings button span { color: var(--body-text-color-subdued); } """ # --- Knowledge Base --- class KnowledgeBase: """Manages known entities (materials, colors) and patterns for data refinement.""" def __init__(self): self.materials: Set[str] = {'Metal', 'Wood', 'Plastic', 'Aluminum', 'Bronze', 'Steel', 'Glass', 'Leather', 'Fabric'} self.colors: Set[str] = {'Red', 'Black', 'White', 'Silver', 'Bronze', 'Yellow', 'Blue', 'Green', 'Gray', 'Brown'} self.patterns: Dict[str, List[str]] = {} self.source_data: Dict[str, Any] = {} def load_source(self, source_type: str, source_path: str) -> None: """Loads data from various sources and extracts knowledge.""" try: if source_type == 'csv_url': response = requests.get(source_path, timeout=10) response.raise_for_status() df = pd.read_csv(io.StringIO(response.text)) elif source_type == 'xlsx_url': response = requests.get(source_path, timeout=10) response.raise_for_status() df = pd.read_excel(io.BytesIO(response.content)) elif source_type == 'local_csv': df = 
pd.read_csv(source_path) elif source_type == 'local_xlsx': df = pd.read_excel(source_path) else: raise ValueError(f"Unsupported source type: {source_type}") self._extract_knowledge(df) self.source_data[source_path] = df.to_dict('records') except requests.exceptions.RequestException as e: raise ConnectionError(f"Failed to fetch data from URL: {e}") except ValueError as e: raise e except Exception as e: raise RuntimeError(f"Error loading source {source_path}: {str(e)}") def _extract_knowledge(self, df: pd.DataFrame) -> None: """Extracts known materials, colors, and column patterns.""" for column in df.columns: if 'material' in column.lower(): values = df[column].dropna().unique() self.materials.update(v.title() for v in values if isinstance(v, str)) elif 'color' in column.lower(): values = df[column].dropna().unique() self.colors.update(v.title() for v in values if isinstance(v, str)) if df[column].dtype == 'object': # Store string patterns for fuzzy matching patterns = df[column].dropna().astype(str).tolist() self.patterns[column] = patterns def get_closest_match(self, value: str, field_type: str) -> Optional[str]: """Finds the closest known value (material or color) for fuzzy matching.""" known_values = getattr(self, field_type + 's', set()) if not known_values: return None matches = get_close_matches(value.title(), list(known_values), n=1, cutoff=0.8) return matches[0] if matches else None knowledge_base = KnowledgeBase() # Global instance for refinement # --- Data Refinement Utilities --- def split_compound_field(field: str) -> List[str]: """Splits strings like 'Red, Blue' into ['Red', 'Blue'].""" parts = re.split(r'[,;\n]+', field) return list(set(p.strip().title() for p in parts if p.strip())) def normalize_value(value: Any, field_name: str, mode: str = 'sourceless', kb: Optional[KnowledgeBase] = None) -> Any: """Normalizes a single data value based on field name and refinement mode.""" if not isinstance(value, str): return value value = re.sub(r'\s+', ' ', value.strip()) # Normalize whitespace value = value.replace('_', ' ') # Replace underscores # Field-specific normalization logic if any(term in field_name.lower() for term in ['material']): parts = split_compound_field(value) if mode == 'sourced' and kb: known = [kb.get_closest_match(p, 'material') or p.title() for p in parts] else: known = [m for m in parts if m in kb.materials] if kb else parts return known[0] if len(known) == 1 else known elif any(term in field_name.lower() for term in ['color']): parts = split_compound_field(value) if mode == 'sourced' and kb: known = [kb.get_closest_match(p, 'color') or p.title() for p in parts] else: known = [c for c in parts if c in kb.colors] if kb else parts return known[0] if len(known) == 1 else known elif any(term in field_name.lower() for term in ['date', 'time']): return value # Placeholder elif any(term in field_name.lower() for term in ['type', 'status', 'category', 'description']): return value.title() # Title case for descriptive fields return value def clean_record(record: Dict[str, Any], mode: str = 'sourceless', kb: Optional[KnowledgeBase] = None) -> Dict[str, Any]: """Cleans and normalizes a single record, handling nesting and compound fields.""" cleaned = {} compound_fields_to_split = {} # Pass 1: Normalize values and identify compound fields for key, value in record.items(): clean_key = key.strip().lower().replace(" ", "_") if isinstance(value, str): # Detect potential compound fields for material in knowledge_base.materials: if material.lower() in value.lower(): 
compound_fields_to_split[clean_key] = value break # Recursively clean nested structures if isinstance(value, list): cleaned[clean_key] = [normalize_value(v, clean_key, mode, kb) for v in value] elif isinstance(value, dict): cleaned[clean_key] = clean_record(value, mode, kb) else: cleaned[clean_key] = normalize_value(value, clean_key, mode, kb) # Pass 2: Split identified compound fields for key, value in compound_fields_to_split.items(): parts = split_compound_field(value) materials = [p for p in parts if p in knowledge_base.materials] if materials: cleaned['material'] = materials[0] if len(materials) == 1 else materials remaining = [p for p in parts if p not in materials] if remaining: cleaned['condition'] = ' '.join(remaining) elif key not in cleaned: # If not processed and no known materials found cleaned[key] = value return cleaned def refine_data_generic(dataset: List[Dict[str, Any]], mode: str = 'sourceless', kb: Optional[KnowledgeBase] = None) -> List[Dict[str, Any]]: """Applies generic data refinement to a list of records, with optional knowledge base guidance.""" if mode == 'sourced' and kb and kb.patterns: # Apply fuzzy matching if sourced for record in dataset: for field, patterns in kb.patterns.items(): if field in record and isinstance(record[field], str): value = str(record[field]) matches = get_close_matches(value, patterns, n=1, cutoff=0.8) if matches: record[field] = matches[0] return [clean_record(entry, mode, kb) for entry in dataset] def refine_preview_data(df: pd.DataFrame, mode: str = 'sourceless') -> pd.DataFrame: """Refines the preview DataFrame based on the selected mode.""" # Remove common auto-generated index columns cols_to_drop = [] for col_name, values in df.to_dict(orient="series").items(): try: if all(isinstance(v, int) and v == i for i, (v, _) in enumerate(zip(values, df.index))): cols_to_drop.append(col_name) elif all(isinstance(v, int) and v == i + 1 for i, (v, _) in enumerate(zip(values, df.index))): cols_to_drop.append(col_name) except Exception: pass # Ignore non-sequential columns if cols_to_drop: df = df.drop(columns=cols_to_drop) records = df.to_dict('records') refined_records = refine_data_generic(records, mode=mode, kb=knowledge_base) return pd.DataFrame(refined_records) def detect_anomalies(record: Dict[str, Any]) -> List[str]: """Detects potential data quality issues (e.g., verbosity, missing values).""" flags = [] for k, v in record.items(): if isinstance(v, str): if len(v) > 300: flags.append(f"{k}: Too verbose.") if v.lower() in ['n/a', 'none', 'undefined', 'null', '']: flags.append(f"{k}: Missing value.") return flags def parse_preview_df(content: str) -> tuple[str, pd.DataFrame]: """Extracts CSV from response, parses, refines, and adds quality flags.""" csv_lines = [] in_csv_block = False for line in content.split("\n"): # Extract lines within CSV code blocks if line.strip().startswith("```csv") or line.strip().startswith("```"): in_csv_block = True; continue if line.strip().startswith("```"): in_csv_block = False; continue if in_csv_block: csv_lines.append(line) csv_content = "\n".join(csv_lines) if not csv_content: raise ValueError("No CSV content found.") csv_header = csv_content.split("\n")[0] if csv_content else "" df = parse_csv_df(csv_content) refined_df = refine_preview_data(df, mode='sourceless') # Initial refinement # Add quality flags refined_records = refined_df.to_dict('records') for record in refined_records: flags = detect_anomalies(record) if flags: record['_quality_flags'] = flags return csv_header, 
pd.DataFrame(refined_records) def parse_csv_df(csv: str, csv_header: Optional[str] = None) -> pd.DataFrame: """Safely parses CSV data using pandas with error handling and common fixes.""" csv = re.sub(r'''(?!")$$(["'][\w\s]+["'][, ]*)+$$(?!")''', lambda m: '"' + m.group(0).replace('"', "'") + '"', csv) # Fix unquoted lists if csv_header and csv.strip() and not csv.strip().startswith(csv_header.split(',')[0]): csv = csv_header + "\n" + csv # Prepend header if missing try: return pd.read_csv(io.StringIO(csv), skipinitialspace=True) except Exception as e: raise ValueError(f"Pandas CSV parsing error: {e}") # --- LLM Interaction Utilities --- T = TypeVar("T") def batched(it: Iterable[T], n: int) -> Iterator[list[T]]: """Yields chunks of size n from an iterable.""" it = iter(it) while batch := list(islice(it, n)): yield batch def stream_response(msg: str, history: list[Dict[str, str]] = [], max_tokens=500) -> Iterator[str]: """Streams responses from the LLM client with retry logic.""" messages = [{"role": m["role"], "content": m["content"]} for m in history] messages.append({"role": "user", "content": msg}) for attempt in range(3): # Retry mechanism try: for chunk in client.chat_completion(messages=messages, max_tokens=max_tokens, stream=True, top_p=0.8, seed=42): content = chunk.choices[0].delta.content if content: yield content break # Success except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: print(f"LLM connection error (attempt {attempt+1}): {e}. Retrying in {2**attempt}s...") time.sleep(2**attempt) except Exception as e: print(f"Unexpected LLM error (attempt {attempt+1}): {e}. Retrying...") time.sleep(2**attempt) def generate_dataset_names(search_query: str, history: list[Dict[str, str]], is_real_data: bool = False, engine: Optional[str] = None) -> Iterator[str]: """Generates dataset names based on a search query using the LLM.""" query = search_query[:1000] if search_query else "" if is_real_data and engine: prompt = ( f"@Claude-3.7-Sonnet You are a data specialist who can transform real search results into structured datasets. " f"A user is searching for data about: \"{query}\" " f"Imagine you've queried {engine} and received real search results. Create a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} specific datasets that could be created from these search results. " f"For each dataset: 1. Give it a clear, specific name related to the search topic. 2. Include 3-5 relevant tags in parentheses, with one tag specifying the ML task type (classification, regression, clustering, etc.). " f"Format each dataset as: 1. DatasetName (tag1, tag2, ml_task_tag). Make these datasets sound like real collections that could be created from {engine} search results on \"{query}\"." 
) else: prompt = GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=query) full_response = "" for token in stream_response(prompt, history): full_response += token yield token # Yield tokens for real-time display print(f"Generated dataset names for query '{search_query}'.") history.append({"role": "assistant", "content": full_response}) # Update history # No return needed as history is modified in place def generate_dataset_content(search_query: str, dataset_name: str, tags: str, history: list[Dict[str, str]], is_real_data: bool = False, engine: Optional[str] = None) -> Iterator[str]: """Generates the description and CSV preview for a dataset.""" query = search_query[:1000] if search_query else "" if is_real_data and engine: prompt = ( f"@Claude-3.7-Sonnet You're a specialist in converting web search results into structured data. " f"Based on search results from {engine} about \"{query}\", create a preview of the dataset \"{dataset_name}\" with tags \"{tags}\". " f"First, write a detailed description of what this dataset contains, its structure, and how it was constructed from web search results. " f"Then, generate a realistic 5-row CSV preview that resembles data you might get if you scraped and structured real results from {engine}. " f"Format your response with: **Dataset Description:** [detailed description] **CSV Content Preview:** ```csv [CSV header and 5 rows of realistic data] ``` " f"Include relevant columns for the dataset type, with proper labels/categories where appropriate. The data should look like it came from real sources." ) else: prompt = GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format( search_query=query, dataset_name=dataset_name, tags=tags ) full_response = "" for token in stream_response(prompt, history): full_response += token yield token print(f"Generated content for dataset '{dataset_name}'.") history.append({"role": "assistant", "content": full_response}) # Update history def _write_generator_to_queue(queue: Queue, func: Callable, kwargs: dict) -> None: """Helper to run a generator and put results (or errors) into a queue.""" try: for i, result in enumerate(func(**kwargs)): queue.put((i, result)) except Exception as e: queue.put((-1, str(e))) # Signal error with index -1 finally: queue.put(None) # Signal completion def iflatmap_unordered(func: Callable, kwargs_iterable: Iterable[dict]) -> Iterable[Any]: """Runs generator functions concurrently and yields results as they complete.""" queue = Queue() pool_size = min(len(kwargs_iterable), os.cpu_count() or 4) with ThreadPool(pool_size) as pool: async_results = [pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable] completed_generators = 0 while completed_generators < len(async_results): try: result = queue.get(timeout=0.1) if result is None: # Generator finished completed_generators += 1 continue index, data = result if index == -1: # Error occurred print(f"Generator error: {data}") continue # Skip this result yield data # Yield successful result except Empty: # Timeout occurred, check if all threads are done if all(res.ready() for res in async_results) and queue.empty(): break for res in async_results: res.get(timeout=0.1) # Ensure threads finish and raise exceptions def generate_partial_dataset( title: str, content: str, search_query: str, variant: str, csv_header: str, output: list[Optional[dict]], indices_to_generate: list[int], history: list[Dict[str, str]], is_real_data: bool = False, engine: Optional[str] = None ) -> Iterator[int]: 
"""Generates a batch of dataset rows for a specific variant.""" dataset_name, tags = title.strip("# ").split("\ntags:", 1) dataset_name, tags = dataset_name.strip(), tags.strip() prompt = GENERATE_MORE_ROWS.format(csv_header=csv_header) + " " + variant # Construct initial messages for context initial_prompt = "" if is_real_data and engine: initial_prompt = ( f"@Claude-3.7-Sonnet You're a specialist in converting web search results into structured data. " f"Based on search results from {engine} about \"{search_query}\", create a preview of the dataset \"{dataset_name}\" with tags \"{tags}\". " f"First, write a detailed description of what this dataset contains, its structure, and how it was constructed from web search results. " f"Then, generate a realistic 5-row CSV preview that resembles data you might get if you scraped and structured real results from {engine}. " f"Format your response with: **Dataset Description:** [detailed description] **CSV Content Preview:** ```csv [CSV header and 5 rows of realistic data] ``` " f"Include relevant columns for the dataset type, with proper labels/categories where appropriate. The data should look like it came from real sources." ) else: initial_prompt = GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format( search_query=search_query, dataset_name=dataset_name, tags=tags ) messages = [ {"role": "user", "content": initial_prompt}, {"role": "assistant", "content": title + "\n\n" + content}, {"role": "user", "content": prompt}, ] generated_samples = 0 current_csv_chunk = "" in_csv_block = False for attempt in range(3): # Retry logic try: for chunk in client.chat_completion(messages=messages, max_tokens=1500, stream=True, top_p=0.8, seed=42): token = chunk.choices[0].delta.content if not token: continue current_csv_chunk += token # Detect CSV block start/end if token.strip().startswith("```csv") or token.strip().startswith("```"): in_csv_block = True continue if token.strip().startswith("```"): in_csv_block = False if current_csv_chunk.strip(): # Process accumulated chunk if block just ended try: temp_df = parse_csv_df(current_csv_chunk.strip(), csv_header=csv_header) new_rows = temp_df.iloc[generated_samples:].to_dict('records') for i, record in enumerate(new_rows): if generated_samples >= len(indices_to_generate): break refined_record = refine_data_generic([record])[0] flags = detect_anomalies(refined_record) if flags: refined_record['_quality_flags'] = flags output_index = indices_to_generate[generated_samples] if output_index < len(output): output[output_index] = refined_record generated_samples += 1 yield 1 # Signal progress except ValueError as e: print(f"CSV parsing error: {e}") except Exception as e: print(f"CSV chunk processing error: {e}") finally: current_csv_chunk = "" # Reset chunk continue if in_csv_block: # Process incrementally if inside CSV block try: temp_df = parse_csv_df(current_csv_chunk.strip(), csv_header=csv_header) new_rows = temp_df.iloc[generated_samples:].to_dict('records') for i, record in enumerate(new_rows): if generated_samples >= len(indices_to_generate): break refined_record = refine_data_generic([record])[0] flags = detect_anomalies(refined_record) if flags: refined_record['_quality_flags'] = flags output_index = indices_to_generate[generated_samples] if output_index < len(output): output[output_index] = refined_record generated_samples += 1 yield 1 except ValueError: pass # CSV not complete except Exception as e: print(f"Incremental CSV processing error: {e}") if generated_samples >= 
len(indices_to_generate): break # Target reached print(f"Retrying generation for variant '{variant}' (attempt {attempt+1})...") time.sleep(2**attempt) except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: print(f"Connection error (attempt {attempt+1}): {e}. Retrying...") time.sleep(2**attempt) except Exception as e: print(f"Unexpected error (attempt {attempt+1}): {e}. Retrying...") time.sleep(2**attempt) def generate_variants(preview_df: pd.DataFrame) -> Iterator[str]: """Generates diverse prompts for creating dataset variants.""" label_cols = [col for col in preview_df.columns if "label" in col.lower()] labels = preview_df[label_cols[0]].unique() if label_cols and len(preview_df[label_cols[0]].unique()) > 1 else [] if labels: # Prioritize label-based generation rarities = ["pretty obvious", "common/regular", "unexpected but useful", "uncommon but still plausible", "rare/niche but still plausible"] for rarity in rarities: for label in labels: yield GENERATE_VARIANTS_WITH_RARITY_AND_LABEL.format(rarity=rarity, label=label) else: # Fallback to general rarity prompts rarities = ["obvious", "expected", "common", "regular", "unexpected but useful", "original but useful", "specific but not far-fetched", "uncommon but still plausible", "rare but still plausible", "very niche but still plausible"] for rarity in rarities: yield GENERATE_VARIANTS_WITH_RARITY.format(rarity=rarity) # --- Gradio Interface --- def whoami(token: str) -> Dict[str, Any]: """Fetches user information from Hugging Face Hub API.""" try: response = requests.get("https://huggingface.co/api/users/me", headers={"Authorization": f"Bearer {token}"}, timeout=5) response.raise_for_status() return response.json() except (requests.exceptions.RequestException, ValueError) as e: print(f"Error fetching user info: {e}") return {"name": "User", "orgs": []} def get_repo_visibility(repo_id: str, token: str) -> str: """Determines if a Hugging Face repository is public or private.""" try: response = requests.get(f"https://huggingface.co/api/repos/{repo_id}", headers={"Authorization": f"Bearer {token}"}, timeout=5) response.raise_for_status() return "public" if not response.json().get("private", False) else "private" except HfHubHTTPError as e: if e.response.status_code == 404: return "public" # Assume public if repo doesn't exist print(f"Error checking repo visibility for {repo_id}: {e}") return "public" except Exception as e: print(f"Unexpected error checking repo visibility for {repo_id}: {e}") return "public" with gr.Blocks(css=css) as demo: generated_texts_state = gr.State((landing_page_datasets_generated_text,)) # State for generated dataset names current_dataset_state = gr.State(None) # State to hold current dataset details for generation is_real_data_state = gr.State(True) # State to track if real data is being used current_engine_state = gr.State(None) # State to track the current search engine selected_engines_state = gr.State(["DuckDuckGo.com", "Bing.com", "Search.Yahoo.com", "Search.Brave.com", "Ecosia.org"]) # Default selected engines searchEngines = ["AlltheInternet.com", "DuckDuckGo.com", "Google.com", "Bing.com", "Search.Yahoo.com", "Startpage.com", "Qwant.com", "Ecosia.org", "WolframAlpha.com", "Mojeek.co.uk", "Search.Brave.com", "Yandex.com", "Baidu.com", "Gibiru.com", "MetaGer.org", "Swisscows.com", "Presearch.com", "Ekoru.org", "Search.Lilo.org"] # --- Search Page UI --- with gr.Column(visible=True, elem_id="search-page") as search_page: gr.Markdown("# πŸ€— Infinite Dataset Hub ♾️\n\nAn endless 
catalog of datasets, created just for you by an AI model.") with gr.Row(): search_bar = gr.Textbox(max_lines=1, placeholder="Search datasets, get infinite results", show_label=False, container=False, scale=9) search_button = gr.Button("πŸ”", variant="primary", scale=1) button_groups: list[gr.Group] = [] # Holds the groups for dataset buttons buttons: list[gr.Button] = [] # Holds the actual dataset name and tag buttons for i in range(MAX_TOTAL_NB_ITEMS): if i < len(default_output): # Use default datasets initially line = default_output[i] try: dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1) except ValueError: dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" ", 1)[0], "" group_classes, name_classes, tag_classes = "buttonsGroup", "topButton", "bottomButton" else: # Placeholders for future datasets dataset_name, tags = "⬜⬜⬜⬜⬜⬜", "β–‘β–‘β–‘β–‘, β–‘β–‘β–‘β–‘, β–‘β–‘β–‘β–‘" group_classes, name_classes, tag_classes = "buttonsGroup insivibleButtonGroup", "topButton linear-background", "bottomButton linear-background" with gr.Group(elem_classes=group_classes) as button_group: button_groups.append(button_group) dataset_btn = gr.Button(dataset_name, elem_classes=name_classes) tags_btn = gr.Button(tags, elem_classes=tag_classes) buttons.append(dataset_btn) buttons.append(tags_btn) load_more_datasets = gr.Button("Load more datasets") gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_") # --- Settings Panel --- with gr.Column(scale=4, min_width="200px"): with gr.Accordion("Settings", open=False, elem_classes="settings"): gr.Markdown("Manage your Hugging Face account and dataset saving options.") gr.LoginButton() select_namespace_dropdown = gr.Dropdown(choices=[NAMESPACE], value=NAMESPACE, label="Hugging Face Namespace", visible=False) gr.Markdown("Dataset Generation Mode") refinement_mode = gr.Radio( ["sourceless", "sourced"], value="sourceless", label="Refinement Mode", info="Sourceless: AI generates data freely. Sourced: AI uses loaded data for context and refinement." ) with gr.Group(visible=False) as source_group: # Dynamic section for source loading source_type = gr.Dropdown( choices=["csv_url", "xlsx_url", "local_csv", "local_xlsx"], value="csv_url", label="Source Type", info="Select the format of your data source." ) source_path = gr.Textbox( label="Source Path/URL", placeholder="Enter URL or local file path", info="Provide the location of your dataset file." ) load_source_button = gr.Button("Load Source Data", icon="https://huggingface.co/datasets/huggingface/badges/resolve/main/badge-files/data.svg") source_status = gr.Markdown("", visible=False) visibility_radio = gr.Radio( ["public", "private"], value="public", container=False, interactive=False, label="Dataset Visibility", info="Set visibility for datasets saved to Hugging Face Hub." ) # Search Engine Settings gr.Markdown("Search Engine Configuration") data_source_toggle = gr.Checkbox(label="Use Real Search Data", value=True, info="Toggle to include results from real search engines.") engine_settings_button = gr.Button("Configure Search Engines", icon="https://img.icons8.com/ios-filled/50/000000/settings--v1.png", size="sm") # Engine Selection Modal with gr.Modal("Search Engine Settings", id="engine-modal") as engine_modal: gr.Markdown("Select which search engines to use for real data retrieval. 
A diverse selection improves results.") engine_options_html_comp = gr.HTML(elem_id="engine-options") with gr.Row(): select_all_engines_btn = gr.Button("Select All") deselect_all_engines_btn = gr.Button("Deselect All") save_engines_btn = gr.Button("Save Settings", variant="primary") # --- Dataset Detail Page UI --- with gr.Column(visible=False, elem_id="dataset-page") as dataset_page: gr.Markdown("# πŸ€— Infinite Dataset Hub ♾️\n\nAn endless catalog of datasets, created just for you.") dataset_title_md = gr.Markdown() # Dataset name and tags dataset_source_badge = gr.Markdown() # Badge indicating real/AI data dataset_source_info = gr.Markdown() # Details about the data source dataset_description_md = gr.Markdown() # Dataset description preview_table_comp = gr.DataFrame(visible=False, interactive=False, wrap=True) # Holds the preview CSV with gr.Row(): generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary") save_dataset_button = gr.Button("πŸ’Ύ Save Dataset", variant="primary", visible=False) open_dataset_message = gr.Markdown("", visible=False) # Confirmation message dataset_share_button = gr.Button("Share Dataset URL") dataset_share_textbox = gr.Textbox(visible=False, show_copy_button=True, label="Copy this URL:", interactive=False, show_label=True) full_dataset_section = gr.Column(visible=False) # Container for full dataset and downloads full_table_comp = gr.DataFrame(visible=False, interactive=False, wrap=True) with gr.Row(): download_csv_button = gr.Button("Download CSV") download_json_button = gr.Button("Download JSON") download_parquet_button = gr.Button("Download Parquet") back_button = gr.Button("< Back", size="sm") # --- Event Handlers --- # Search Logic def _update_search_results(search_query: str, current_generated_texts: tuple[str], is_real_data: bool, engine: Optional[str]): """Handles dataset search and UI updates.""" # Reset UI to loading state yield {btn: gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background") for btn in buttons[::2]} yield {btn: gr.Button("β–‘β–‘β–‘β–‘, β–‘β–‘β–‘β–‘, β–‘β–‘β–‘β–‘", elem_classes="bottomButton linear-background") for btn in buttons[1::2]} yield {group: gr.Group(elem_classes="buttonsGroup insivibleButtonGroup") for group in button_groups} generated_count = 0 new_texts = "" try: # Generate dataset names from LLM for line in generate_dataset_names(search_query, [], is_real_data=is_real_data, engine=engine): if "I'm sorry" in line or "policy" in line: raise gr.Error("Inappropriate content detected.") if generated_count >= MAX_NB_ITEMS_PER_GENERATION_CALL: break match = re.match(r"^\s*\d+\.\s+(.+?)\s+$$(.+?)$$", line) # Parse line format if match: dataset_name, tags = match.groups() dataset_name, tags = dataset_name.strip(), tags.strip() new_texts += line # Update buttons with generated data yield { buttons[2 * generated_count]: gr.Button(dataset_name, elem_classes="topButton"), buttons[2 * generated_count + 1]: gr.Button(tags, elem_classes="bottomButton"), } generated_count += 1 # Update state and make new buttons visible new_history = (current_generated_texts + (new_texts,)) if current_generated_texts else (landing_page_datasets_generated_text + "\n" + new_texts,) yield {generated_texts_state: new_history} yield {group: gr.Group(elem_classes="buttonsGroup") for group in button_groups[:generated_count]} except gr.Error as e: raise e # Propagate Gradio errors except Exception as e: raise gr.Error(f"Failed to generate datasets: {str(e)}") # Attach search handlers search_button.click( _update_search_results, 
inputs=[search_bar, generated_texts_state, is_real_data_state, current_engine_state], outputs=buttons + [generated_texts_state] + button_groups ) search_bar.submit( _update_search_results, inputs=[search_bar, generated_texts_state, is_real_data_state, current_engine_state], outputs=buttons + [generated_texts_state] + button_groups ) # Load More Datasets load_more_datasets.click( _update_search_results, inputs=[search_bar, generated_texts_state, is_real_data_state, current_engine_state], outputs=buttons + [generated_texts_state] + button_groups ) # Display Single Dataset Details def _show_dataset_details(search_query, dataset_name, tags, is_real_data, engine): """Switches to detail view and loads dataset content.""" yield { search_page: gr.Column(visible=False), dataset_page: gr.Column(visible=True), dataset_title_md: f"# {dataset_name}\n\n tags: {tags}", dataset_share_textbox: gr.Textbox(visible=False), full_dataset_section: gr.Column(visible=False), save_dataset_button: gr.Button(visible=False), open_dataset_message: gr.Markdown("", visible=False) } # Update source badge and info if is_real_data: badge_html = gr.Markdown(f'Real Data', visible=True) info_html = gr.Markdown(f'This dataset is based on real information queried from {engine} for the search term "{search_query}". The data has been structured for machine learning use.', visible=True) else: badge_html = gr.Markdown('AI-Generated', visible=True) info_html = gr.Markdown(f'This is an AI-generated dataset created using {model_id}. The content is synthetic and designed to represent plausible data related to "{search_query}".', visible=True) yield {dataset_source_badge: badge_html, dataset_source_info: info_html} # Stream content generation for content_chunk in generate_dataset_content(search_query, dataset_name, tags, [], is_real_data=is_real_data, engine=engine): yield {dataset_description_md: content_chunk} # Link buttons to the detail view function def _show_dataset_from_button_wrapper(search_query, *buttons_values): # Determine which button was clicked to get the index clicked_button_index = -1 for i, btn_val in enumerate(buttons_values): if btn_val is not None and btn_val != "": # Assuming non-empty value indicates the clicked button's text clicked_button_index = i break if clicked_button_index == -1: return # Should not happen if events are correctly wired # Determine if it was a name button (even index) or tag button (odd index) dataset_index = clicked_button_index // 2 dataset_name, tags = buttons_values[2 * dataset_index], buttons_values[2 * dataset_index + 1] is_real_data = current_engine_state.value is not None # Infer from engine state engine = current_engine_state.value if is_real_data else None yield from _show_dataset_details(search_query, dataset_name, tags, is_real_data, engine) # Wire up click events for all dataset name and tag buttons for i, (name_btn, tag_btn) in enumerate(batched(buttons, 2)): name_btn.click( partial(_show_dataset_from_button_wrapper), inputs=[search_bar, *buttons], outputs=[search_page, dataset_page, dataset_title_md, dataset_description_md, dataset_source_badge, dataset_source_info, dataset_share_textbox, full_dataset_section, save_dataset_button, open_dataset_message] ) tag_btn.click( partial(_show_dataset_from_button_wrapper), inputs=[search_bar, *buttons], outputs=[search_page, dataset_page, dataset_title_md, dataset_description_md, dataset_source_badge, dataset_source_info, dataset_share_textbox, full_dataset_section, save_dataset_button, open_dataset_message] ) # Back Button Navigation 
back_button.click(lambda: (gr.Column(visible=True), gr.Column(visible=False)), outputs=[search_page, dataset_page], js=""" function() { if ('parentIFrame' in window) { window.parentIFrame.scrollTo({top: 0, behavior:'smooth'}); } else { window.scrollTo({ top: 0, behavior: 'smooth' }); } return Array.from(arguments); } """) # Full Dataset Generation @generate_full_dataset_button.click( inputs=[dataset_title_md, dataset_description_md, search_bar, select_namespace_dropdown, visibility_radio, refinement_mode, is_real_data_state, current_engine_state], outputs=[full_table_comp, generate_full_dataset_button, save_dataset_button, full_dataset_section] ) def _generate_full_dataset(title_md, content_md, search_query, namespace, visibility, mode, is_real_data, engine): # Extract dataset name and tags from the markdown title try: dataset_name = title_md.split('\n')[0].strip('# ') tags = title_md.split('tags:', 1)[1].strip() except IndexError: raise gr.Error("Could not parse dataset title.") try: csv_header, preview_df = parse_preview_df(content_md) except ValueError as e: raise gr.Error(f"Failed to parse preview: {e}") refined_preview_df = refine_preview_data(preview_df, mode) columns = list(refined_preview_df) output_data: list[Optional[dict]] = [None] * NUM_ROWS # Initialize output structure initial_rows = refined_preview_df.to_dict('records') for i, record in enumerate(initial_rows): if i < NUM_ROWS: output_data[i] = {"idx": i, **record} # Update UI: show preview, disable generate, show save button yield { full_table_comp: gr.DataFrame(pd.DataFrame([r for r in output_data if r]), visible=True), generate_full_dataset_button: gr.Button(interactive=False), save_dataset_button: gr.Button(f"πŸ’Ύ Save {namespace}/{dataset_name}" + (" (private)" if visibility != "public" else ""), visible=True, interactive=False), full_dataset_section: gr.Column(visible=True) } # Prepare generation tasks for variants generation_tasks = [] variants = islice(generate_variants(refined_preview_df), NUM_VARIANTS) for i, variant in enumerate(variants): indices = list(range(len(initial_rows) + i, NUM_ROWS, NUM_VARIANTS)) if indices: # Only create task if there are rows to generate generation_tasks.append({ "func": generate_partial_dataset, "kwargs": { "title": title_md, "content": content_md, "search_query": search_query, "variant": variant, "csv_header": csv_header, "output": output_data, "indices_to_generate": indices, "history": [], # Use fresh history for each variant task "is_real_data": is_real_data, "engine": engine } }) # Execute tasks in parallel and update UI progressively for _ in iflatmap_unordered(lambda **kw: kw.pop('func')(**kw), generation_tasks): yield {full_table_comp: pd.DataFrame([r for r in output_data if r])} # Update DataFrame display yield {save_dataset_button: gr.Button(interactive=True)} # Enable save button print(f"Full dataset generation complete for {dataset_name}.") # Save Dataset to Hugging Face Hub @save_dataset_button.click( inputs=[dataset_title_md, dataset_description_md, search_bar, full_table_comp, select_namespace_dropdown, visibility_radio], outputs=[save_dataset_button, open_dataset_message] ) def _save_dataset(title_md, content_md, search_query, df, namespace, visibility, oauth_token): # Extract dataset name and tags from the markdown title try: dataset_name = title_md.split('\n')[0].strip('# ') tags = title_md.split('tags:', 1)[1].strip() except IndexError: raise gr.Error("Could not parse dataset title.") token = oauth_token.token if oauth_token else save_dataset_hf_token if not token: 
raise gr.Error("Login required or set SAVE_DATASET_HF_TOKEN.") repo_id = f"{namespace}/{dataset_name}" dataset_url_params = f"q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}" dataset_url = f"{URL}?{dataset_url_params}" gr.Info("Saving dataset...") yield {save_dataset_button: gr.Button(interactive=False)} # Disable button during save try: create_repo(repo_id=repo_id, repo_type="dataset", private=visibility!="public", exist_ok=True, token=token) df.to_csv(f"hf://datasets/{repo_id}/data.csv", storage_options={"token": token}, index=False) card_content = DATASET_CARD_CONTENT.format(title=title_md, content=content_md, url=URL, dataset_url=dataset_url, model_id=model_id, search_query=search_query) DatasetCard(card_content).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token) success_msg = f"# πŸŽ‰ Yay! Dataset saved to [{repo_id}](https://huggingface.co/datasets/{repo_id})!\n\n_PS: Check Settings to manage your saved datasets._" gr.Info("Dataset saved successfully.") yield {open_dataset_message: gr.Markdown(success_msg, visible=True)} except HfHubHTTPError as e: raise gr.Error(f"HF Hub error: {e.message}") except Exception as e: raise gr.Error(f"Save failed: {str(e)}") finally: yield {save_dataset_button: gr.Button(interactive=True)} # Re-enable button # Shareable URL Generation @dataset_share_button.click(inputs=[dataset_title_md, search_bar], outputs=[dataset_share_textbox]) def _show_share_url(title_md, search_query): try: dataset_name = title_md.split('\n')[0].strip('# ') tags = title_md.split('tags:', 1)[1].strip() except IndexError: raise gr.Error("Could not parse dataset title.") share_url = f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}" return gr.Textbox(share_url, visible=True) # Settings Toggles refinement_mode.change(lambda mode: gr.Group(visible=(mode == "sourced")), outputs=[source_group]) data_source_toggle.change(lambda value: (gr.State(value), gr.State(value if value else None)), inputs=[data_source_toggle], outputs=[is_real_data_state, current_engine_state]) @load_source_button.click(inputs=[source_type, source_path], outputs=[source_status]) def _load_source_data(source_type, source_path): if not source_path: raise gr.Error("Source path/URL is required.") try: knowledge_base.load_source(source_type, source_path) gr.Info("Source data loaded.") return gr.Markdown("βœ… Source loaded successfully", visible=True) except (ConnectionError, ValueError, RuntimeError) as e: raise gr.Error(f"Failed to load source: {str(e)}") # Engine Settings Modal Logic def _populate_engine_options(selected_engines): engine_options_html = "" for engine in searchEngines: is_checked = "checked" if engine in selected_engines else "" engine_options_html += f"""
""" return gr.HTML(engine_options_html) def _save_engine_settings(selected_engines_json): selected_engines = json.loads(selected_engines_json) if not selected_engines: gr.Warning("At least one search engine must be selected. Using DuckDuckGo as default.") selected_engines = ["DuckDuckGo.com"] current_engine = selected_engines[0] if selected_engines else None return gr.State(selected_engines), gr.State(current_engine), gr.Info(f"Updated search engines. Using {len(selected_engines)} engines.") # Initialize engine options component engine_options_html_comp = _populate_engine_options(selected_engines_state.value) # Update engine options when the modal is opened engine_settings_button.click(lambda: engine_options_html_comp.update(_populate_engine_options(selected_engines_state.value)), outputs=[engine_options_html_comp]) select_all_engines_btn.click(lambda: engine_options_html_comp.update(_populate_engine_options(searchEngines)), outputs=[engine_options_html_comp]) deselect_all_engines_btn.click(lambda: engine_options_html_comp.update(_populate_engine_options([])), outputs=[engine_options_html_comp]) save_engines_btn.click( _save_engine_settings, inputs=[gr.JSON(elem_id="engine-options")], # Capture checked engines from modal outputs=[selected_engines_state, current_engine_state, gr.Info()] ) engine_settings_button.click(lambda: engine_modal.update(visible=True), outputs=[engine_modal]) # Close modal on save or when clicking outside (implicit via Gradio's modal handling) # Initial App Load Logic @demo.load(outputs=([search_page, dataset_page, dataset_title_md, dataset_description_md, dataset_source_badge, dataset_source_info, dataset_share_textbox, full_dataset_section, save_dataset_button, open_dataset_message, search_bar] + # Outputs for detail page and search bar buttons + [generated_texts_state] + # Outputs for search results buttons and state [select_namespace_dropdown, visibility_radio, source_group, data_source_toggle, current_engine_state, selected_engines_state, engine_options_html_comp])) # Outputs for settings def _load_app(request: gr.Request, oauth_token: Optional[gr.OAuthToken]): # Handle user login and namespace selection if oauth_token: try: user_info = whoami(oauth_token.token) namespaces = [user_info["name"]] + [org["name"] for org in user_info.get("orgs", [])] yield { select_namespace_dropdown: gr.Dropdown(choices=namespaces, value=user_info["name"], visible=True), visibility_radio: gr.Radio(interactive=True), } except Exception: # Fallback if user info fails yield { select_namespace_dropdown: gr.Dropdown(choices=[NAMESPACE], value=NAMESPACE, visible=True), visibility_radio: gr.Radio(interactive=True), } else: # Default settings if not logged in yield { select_namespace_dropdown: gr.Dropdown(choices=[NAMESPACE], value=NAMESPACE, visible=True), visibility_radio: gr.Radio(interactive=False), } # Handle URL parameters for direct search or dataset loading query_params = dict(request.query_params) if "dataset" in query_params: is_real = query_params.get("engine") is not None engine = query_params.get("engine") yield from _show_dataset_details(query_params.get("q", query_params["dataset"]), query_params["dataset"], query_params.get("tags", ""), is_real, engine) yield {is_real_data_state: is_real, current_engine_state: engine} elif "q" in query_params: search_query = query_params["q"] is_real = query_params.get("engine") is not None engine = query_params.get("engine") yield {search_bar: search_query} yield {is_real_data_state: is_real, current_engine_state: engine} yield from 
_update_search_results(search_query, (), is_real, engine) else: yield {search_page: gr.Column(visible=True)} # Show search page by default # Initialize with default datasets initial_outputs = {} for i, line in enumerate(default_output): try: dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1) except ValueError: dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" ", 1)[0], "" initial_outputs[buttons[2 * i]] = gr.Button(dataset_name, elem_classes="topButton") initial_outputs[buttons[2 * i + 1]] = gr.Button(tags, elem_classes="bottomButton") initial_outputs[button_groups[i]] = gr.Group(elem_classes="buttonsGroup") yield initial_outputs yield {generated_texts_state: (landing_page_datasets_generated_text,)} # Initialize engine settings UI yield { data_source_toggle: gr.Checkbox(value=is_real_data_state.value), engine_options_html_comp: _populate_engine_options(selected_engines_state.value) } if __name__ == "__main__": demo.launch(share=False, server_name="0.0.0.0")
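
# --- Illustrative usage sketches ---
# The examples below are not wired into the Gradio app and are never called by
# it; they are small, self-contained sketches of the techniques used above,
# with made-up inputs.

# Example: pulling a CSV preview out of an LLM reply.
# A minimal sketch of the idea behind parse_preview_df: locate the fenced
# ```csv block in a model response and load it with pandas. The response text
# is hypothetical, and a single regex stands in for the line-by-line scan used
# above.
def _example_extract_csv_preview() -> None:
    response = (
        "**Dataset Description:** A tiny example dataset.\n\n"
        "**CSV Content Preview:**\n"
        "```csv\n"
        "text,label\n"
        "Stocks rallied after the earnings call,finance\n"
        "The striker scored twice in the derby,sports\n"
        "```\n"
    )
    match = re.search(r"```csv\s*\n(.*?)```", response, flags=re.DOTALL)
    if match is None:
        raise ValueError("No CSV code block found in the response.")
    df = pd.read_csv(io.StringIO(match.group(1)), skipinitialspace=True)
    print(df.to_string(index=False))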
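
# Example: parsing "N. DatasetName (tag1, tag2, tag3)" lines.
# A minimal sketch of the line format requested by
# GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY and of how a name/tags pair can be
# pulled out of it with a regex that spells out the literal parentheses.
def _example_parse_dataset_line() -> None:
    line = "3. HealthMonitor (science, real-time, anomaly detection)"
    match = re.match(r"^\s*\d+\.\s+(.+?)\s+\((.+?)\)\s*$", line)
    if match:
        dataset_name, tags = match.group(1), match.group(2)
        print(dataset_name)  # HealthMonitor
        print(tags)          # science, real-time, anomaly detection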
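
# Example: fuzzy normalization against a known vocabulary.
# A minimal sketch of the idea behind KnowledgeBase.get_closest_match: snap a
# noisy value onto a known set with difflib. The vocabulary and inputs are
# hypothetical.
def _example_fuzzy_normalization() -> None:
    known_materials = ["Metal", "Wood", "Plastic", "Aluminum", "Steel"]
    for raw in ("alumnium", "steel ", "unobtainium"):
        candidate = raw.strip().title()
        matches = get_close_matches(candidate, known_materials, n=1, cutoff=0.8)
        print(raw, "->", matches[0] if matches else None)
    # alumnium    -> Aluminum  (close enough at cutoff 0.8)
    # steel       -> Steel
    # unobtainium -> None      (no sufficiently close match)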
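
# Example: chunking with `batched`.
# A minimal sketch of the helper defined above, which the UI also uses to pair
# each dataset-name button with its tags button. The sample data is made up.
def _example_batched() -> None:
    names = [f"Dataset{i}" for i in range(7)]
    for chunk in batched(names, 3):
        print(chunk)
    # ['Dataset0', 'Dataset1', 'Dataset2']
    # ['Dataset3', 'Dataset4', 'Dataset5']
    # ['Dataset6']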
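
# Example: retrying a streaming call with exponential backoff.
# A minimal sketch of the retry loop in stream_response: try up to three
# times and sleep 2**attempt seconds between attempts. The flaky generator is
# a stand-in for client.chat_completion(..., stream=True).
def _example_stream_with_retry() -> None:
    attempts = {"count": 0}

    def flaky_stream() -> Iterator[str]:
        attempts["count"] += 1
        if attempts["count"] < 2:
            raise ConnectionError("transient network error")
        yield from ("Hello", ", ", "world")

    for attempt in range(3):
        try:
            for token in flaky_stream():
                print(token, end="")
            print()
            break  # success, stop retrying
        except ConnectionError as error:
            print(f"attempt {attempt + 1} failed ({error}); retrying...")
            time.sleep(2 ** attempt)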
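
# Example: merging results from concurrent generators.
# A minimal, self-contained sketch of the pattern behind iflatmap_unordered:
# run several generators on a thread pool, push whatever they yield onto a
# shared queue, and consume in completion order. The sleeping workers stand in
# for streaming LLM calls.
def _example_unordered_merge() -> None:
    def slow_counter(name: str, delay: float) -> Iterator[str]:
        for i in range(3):
            time.sleep(delay)
            yield f"{name}-{i}"

    def drain_into_queue(results_queue: Queue, name: str, delay: float) -> None:
        for item in slow_counter(name, delay):
            results_queue.put(item)
        results_queue.put(None)  # sentinel: this generator is done

    tasks = [("fast", 0.01), ("slow", 0.03)]
    results_queue: Queue = Queue()
    with ThreadPool(len(tasks)) as pool:
        for name, delay in tasks:
            pool.apply_async(drain_into_queue, (results_queue, name, delay))
        finished = 0
        while finished < len(tasks):
            item = results_queue.get()
            if item is None:
                finished += 1
            else:
                print(item)  # items arrive interleaved, fastest worker first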
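
# Example: how generated rows are divided between variant prompts.
# A minimal sketch of the index arithmetic in _generate_full_dataset: after
# the preview rows, variant i fills every NUM_VARIANTS-th slot, so the
# variants partition the remaining indices without overlap. The numbers here
# are smaller than the app's NUM_ROWS/NUM_VARIANTS for readability.
def _example_variant_striding(num_rows: int = 20, num_variants: int = 4, preview_rows: int = 5) -> None:
    filled: list[int] = []
    for i in range(num_variants):
        indices = list(range(preview_rows + i, num_rows, num_variants))
        print(f"variant {i}: {indices}")
        filled.extend(indices)
    # Together the variants cover preview_rows..num_rows-1 exactly once.
    assert sorted(filled) == list(range(preview_rows, num_rows))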
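
# Example: the save-to-Hub flow used by _save_dataset.
# A minimal sketch, assuming a valid write token and a hypothetical repo id.
# It mirrors the calls above: create the dataset repo, upload the CSV through
# the hf:// filesystem, then push a dataset card. It performs real network
# writes, so it is not meant to be run casually.
def _example_save_to_hub(df: pd.DataFrame, token: str, repo_id: str = "your-namespace/your-dataset") -> None:
    create_repo(repo_id=repo_id, repo_type="dataset", private=False, exist_ok=True, token=token)
    df.to_csv(f"hf://datasets/{repo_id}/data.csv", index=False, storage_options={"token": token})
    card = DatasetCard("---\nlicense: mit\ntags:\n- synthetic\n---\n\n# Example dataset card\n")
    card.push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)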