import gradio as gr from datasets import load_dataset, Dataset from difflib import ndiff import pandas as pd from semhash import SemHash from semhash.datamodels import DeduplicationResult from model2vec import StaticModel # Default parameters default_dataset_name = "SetFit/amazon_massive_scenario_en-US" default_dataset1_split = "train" default_dataset2_split = "test" default_text_column = "text" default_threshold = 0.9 # Load the model to use model = StaticModel.from_pretrained("minishlab/potion-base-8M") def display_word_differences(x: str, y: str) -> str: """ Display the word-level differences between two texts, formatted to avoid misinterpretation of Markdown syntax. """ diff = ndiff(x.split(), y.split()) formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-"))) return f"```\n{formatted_diff}\n```" def load_dataset_texts( dataset_name: str, dataset_split: str, text_column: str ) -> tuple[list[str], Dataset]: """Load texts from a specified dataset split.""" ds = load_dataset(dataset_name, split=dataset_split) return [example[text_column] for example in ds], ds def deduplicate_single_dataset( texts: list[str], threshold: float ) -> DeduplicationResult: """ Deduplicate within a single dataset using SemHash, treating each text as a raw string record. """ # Build a SemHash index from the raw texts semhash = SemHash.from_records(records=texts, model=model) # Deduplicate the entire dataset return semhash.self_deduplicate(threshold=threshold) def deduplicate_two_datasets( texts1: list[str], texts2: list[str], threshold: float ) -> DeduplicationResult: """Deduplicate dataset2 against dataset1, both as raw strings, using SemHash.""" # Build SemHash index on dataset1 semhash = SemHash.from_records(records=texts1, model=model) # Deduplicate texts2 against dataset1 return semhash.deduplicate(records=texts2, threshold=threshold) def create_deduplicated_dataset( original_dataset: Dataset, deduplicated_texts: list[str], text_column: str ) -> Dataset: """Create a new dataset with only the deduplicated texts.""" # Create a mapping from text to original row text_to_row = {row[text_column]: row for row in original_dataset} # Build new dataset with deduplicated texts deduplicated_rows = [] for text in deduplicated_texts: if text in text_to_row: deduplicated_rows.append(text_to_row[text]) return Dataset.from_list(deduplicated_rows) def perform_deduplication( deduplication_type: str, dataset1_name: str, dataset1_split: str, dataset1_text_column: str, dataset2_name: str = "", dataset2_split: str = "", dataset2_text_column: str = "", threshold: float = default_threshold, progress: gr.Progress = gr.Progress(track_tqdm=True), ): """ Perform deduplication on one or two datasets using SemHash. This function streams status updates to Gradio for user feedback. """ try: threshold = float(threshold) # Load Dataset 1 texts1, dataset1 = load_dataset_texts( dataset1_name, dataset1_split, dataset1_text_column ) if deduplication_type == "Single dataset": # Single-dataset deduplication result = deduplicate_single_dataset(texts1, threshold=threshold) # Sort all duplicates by score (ascending for least similar) for duprec in result.duplicates: duprec.duplicates.sort(key=lambda x: x[1]) # Create deduplicated dataset deduplicated_dataset = create_deduplicated_dataset( dataset1, result.deduplicated, dataset1_text_column ) # Summarize results num_duplicates = len(result.duplicates) deduplicated_count = len(result.deduplicated) total_docs = len(texts1) # Create examples table examples_table = None if num_duplicates > 0: # Only show duplicates that actually have near-duplicate records duplicates_with_data = [ duprec for duprec in result.duplicates if duprec.duplicates ] # sort duplicates by score (ascending for least similar) for duprec in result.duplicates: duprec.duplicates.sort(key=lambda x: x[1]) if duplicates_with_data: # Create table data for the 5 least similar examples table_data = [] for duprec in duplicates_with_data[:5]: dup_text = duprec.record orig_text, score = duprec.duplicates[0] table_data.append( [ orig_text[:200] + "..." if len(orig_text) > 200 else orig_text, dup_text[:200] + "..." if len(dup_text) > 200 else dup_text, f"{score:.4f}", ] ) examples_table = pd.DataFrame( table_data, columns=["Original Text", "Duplicate Text", "Similarity Score"], ) # Show success info with stats gr.Info( f"Deduplication completed! Found {num_duplicates} duplicates. " f"Dataset reduced from {total_docs} to {deduplicated_count} unique documents." ) # Return table with visibility update if examples_table is not None and not examples_table.empty: return deduplicated_dataset, gr.update( visible=True, value=examples_table ) else: return deduplicated_dataset, gr.update(visible=False) else: # Cross-dataset deduplication texts2, dataset2 = load_dataset_texts( dataset2_name, dataset2_split, dataset2_text_column ) result = deduplicate_two_datasets(texts1, texts2, threshold=threshold) # Sort duplicates by score (ascending for least similar) for duprec in result.duplicates: duprec.duplicates.sort(key=lambda x: x[1]) # Create deduplicated dataset from dataset2 deduplicated_dataset = create_deduplicated_dataset( dataset2, result.deduplicated, dataset2_text_column ) num_duplicates = len(result.duplicates) total_docs2 = len(texts2) deduplicated_count = len(result.deduplicated) # Create examples table examples_table = None if num_duplicates > 0: # Again, only show duplicates that have records duplicates_with_data = [ duprec for duprec in result.duplicates if duprec.duplicates ] if duplicates_with_data: # Create table data for the 5 least similar examples table_data = [] for duprec in duplicates_with_data[:5]: dup_text = duprec.record orig_text, score = duprec.duplicates[0] table_data.append( [ orig_text[:200] + "..." if len(orig_text) > 200 else orig_text, dup_text[:200] + "..." if len(dup_text) > 200 else dup_text, f"{score:.4f}", ] ) examples_table = pd.DataFrame( table_data, columns=[ "Original Text (Dataset 1)", "Duplicate Text (Dataset 2)", "Similarity Score", ], ) # Show success info with stats gr.Info( f"Deduplication completed! Found {num_duplicates} duplicates in Dataset 2. " f"Dataset reduced from {total_docs2} to {deduplicated_count} unique documents." ) # Return table with visibility update if examples_table is not None and not examples_table.empty: return deduplicated_dataset, gr.update( visible=True, value=examples_table ) else: return deduplicated_dataset, gr.update(visible=False) except Exception as e: gr.Error(f"An error occurred during deduplication: {str(e)}") return None, gr.update(visible=False) def push_to_hub( deduplicated_dataset: Dataset, output_dataset_name: str, oauth_profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, progress: gr.Progress = gr.Progress(), ) -> str: """Push the deduplicated dataset to Hugging Face Hub.""" if oauth_token is None: raise gr.Error("Please log in with Hugging Face to push datasets to the Hub.") if not output_dataset_name.strip(): raise gr.Error("Please provide a dataset name.") if deduplicated_dataset is None: raise gr.Error( "No deduplicated dataset available. Please run deduplication first." ) try: progress(0.1, desc="Preparing dataset...") # Determine the full dataset name (username/dataset_name) username = oauth_profile.username if oauth_profile else None if "/" not in output_dataset_name and username: full_dataset_name = f"{username}/{output_dataset_name}" else: full_dataset_name = output_dataset_name progress(0.3, desc="Pushing to Hub...") # Push to hub using the OAuth token deduplicated_dataset.push_to_hub( full_dataset_name, token=oauth_token.token, private=False ) progress(1.0, desc="Complete!") gr.Info( f"Successfully pushed deduplicated dataset with {len(deduplicated_dataset)} rows to the Hub!" ) return ( f"✅ **Dataset published:** [{full_dataset_name}]" f"(https://huggingface.co/datasets/{full_dataset_name})" ) except Exception as e: raise gr.Error(f"Failed to push dataset to Hub: {str(e)}") def get_user_info(oauth_profile: gr.OAuthProfile | None) -> str: """Display user login status.""" if oauth_profile is None: return "Not logged in. Please log in to push datasets to the Hub." return f"Logged in as: **{oauth_profile.username}**" def update_push_button_state(oauth_profile: gr.OAuthProfile | None): """Update the push button state based on login status.""" is_logged_in = oauth_profile is not None return gr.update(interactive=is_logged_in) # --- Gradio App --- with gr.Blocks( theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }" ) as demo: gr.Markdown("# SemDedup-My-Dataset: Semantic Text Deduplication Using SemHash") gr.Markdown(""" This demo showcases **semantic deduplication** using [SemHash](https://github.com/MinishLab/semhash) for HuggingFace datasets, using a [Model2Vec](https://github.com/MinishLab/model2vec) encoder. It can be used to identify duplicate texts within a **single dataset** or across **two datasets**. You can adjust the similarity threshold to control the strictness of the deduplication. """) deduplication_type = gr.Radio( choices=["Cross-dataset", "Single dataset"], label="Deduplication Type", value="Cross-dataset", # default ) with gr.Row(): dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name") dataset1_split = gr.Textbox( value=default_dataset1_split, label="Dataset 1 Split" ) dataset1_text_column = gr.Textbox( value=default_text_column, label="Text Column Name" ) dataset2_inputs = gr.Column(visible=True) with dataset2_inputs: with gr.Row(): dataset2_name = gr.Textbox( value=default_dataset_name, label="Dataset 2 Name" ) dataset2_split = gr.Textbox( value=default_dataset2_split, label="Dataset 2 Split" ) dataset2_text_column = gr.Textbox( value=default_text_column, label="Text Column Name" ) threshold = gr.Slider( 0.0, 1.0, value=default_threshold, label="Similarity Threshold" ) with gr.Row(): compute_button = gr.Button("Deduplicate", variant="primary") status_output = gr.Markdown(elem_id="status_output") # Examples table examples_table = gr.Dataframe( headers=["Original Text", "Duplicate Text", "Similarity Score"], datatype=["str", "str", "str"], ) # Hidden state to store the deduplicated dataset deduplicated_dataset_state = gr.State() # Output dataset configuration gr.Markdown("## Push Deduplicated Dataset to Hub") with gr.Row(): with gr.Column(): output_dataset_name = gr.Textbox( label="Output Dataset Name", placeholder="my-deduplicated-dataset", info="Will be saved as username/dataset-name", ) with gr.Column(): push_button = gr.Button( "Push to Hub", variant="secondary", interactive=False ) login_button = gr.LoginButton() # Login section - moved below push to hub with gr.Row(): user_info = gr.Markdown() push_output = gr.Markdown() def update_visibility(choice: str): return gr.update(visible=(choice == "Cross-dataset")) deduplication_type.change( update_visibility, inputs=deduplication_type, outputs=dataset2_inputs ) # Update user info and button state when page loads or login status changes demo.load(get_user_info, inputs=None, outputs=user_info) demo.load(update_push_button_state, inputs=None, outputs=push_button) login_button.click(get_user_info, inputs=None, outputs=user_info) login_button.click(update_push_button_state, inputs=None, outputs=push_button) compute_button.click( fn=perform_deduplication, inputs=[ deduplication_type, dataset1_name, dataset1_split, dataset1_text_column, dataset2_name, dataset2_split, dataset2_text_column, threshold, ], outputs=[deduplicated_dataset_state, examples_table], ) push_button.click( fn=push_to_hub, inputs=[ deduplicated_dataset_state, output_dataset_name, ], outputs=push_output, ) demo.launch()