import re

import gradio as gr
import torch
from tokenizers import Regex
from tokenizers.normalizers import Replace, Sequence, Strip
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Fine-tuned ModernBERT checkpoints for the three-model ensemble.
model1_path = "https://huggingface.co/spaces/SzegedAI/AI_Detector/resolve/main/modernbert.bin"
model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

try:
    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

    model_1 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
    model_1.load_state_dict(torch.hub.load_state_dict_from_url(model1_path, map_location=device, progress=True))
    model_1.to(device).eval()

    model_2 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
    model_2.load_state_dict(torch.hub.load_state_dict_from_url(model2_path, map_location=device, progress=True))
    model_2.to(device).eval()

    model_3 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
    model_3.load_state_dict(torch.hub.load_state_dict_from_url(model3_path, map_location=device, progress=True))
    model_3.to(device).eval()
except Exception as e:
    print(f"Error during model loading: {e}")
    print("Please ensure all model paths are correct, dependencies are installed, "
          "and you have an internet connection for remote models.")
    # If loading fails here, keep the handles as None so the Gradio interface
    # can show a persistent error (see the bottom of this file) instead of
    # crashing when a request arrives.
    tokenizer = None
    model_1, model_2, model_3 = None, None, None

# Index-to-name mapping for the 41 training classes. Index 24 ('human') is the
# only non-AI class; every other index names a generator model or family.
label_mapping = {
    0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
    6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
    11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
    14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
    18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o', 22: 'gpt_j',
    23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
    27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
    31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
    35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
    39: 'text-davinci-002', 40: 'text-davinci-003'
}


def clean_text(text: str) -> str:
    """Collapse runs of whitespace and drop stray spaces before punctuation."""
    text = re.sub(r'\s{2,}', ' ', text)
    text = re.sub(r'\s+([,.;:?!])', r'\1', text)
    return text


if tokenizer:  # Only extend the normalizer if the tokenizer loaded successfully.
    # Rejoin words hyphenated across a line break, then fold remaining newlines
    # into single spaces, so pasted multi-line text tokenizes as one paragraph.
    newline_to_space = Replace(Regex(r'\s*\n\s*'), " ")
    join_hyphen_break = Replace(Regex(r'(\w+)-\s*\n\s*(\w+)'), r"\1\2")
    tokenizer.backend_tokenizer.normalizer = Sequence([
        tokenizer.backend_tokenizer.normalizer,  # keep the existing normalizers
        join_hyphen_break,
        newline_to_space,
        Strip()
    ])

# --- End Model & Tokenizer Configuration ---

title_md = """
Developed by SzegedAI
""" description = """This tool utilizes the ModernBERT model to discern whether a given text is human-authored or AI-generated. It employs a soft voting ensemble of three models, amalgamating their outputs to enhance detection accuracy.
Paste your text into the field below to analyze its origin.
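
# Soft voting, as used in classify_text_interface below, averages the models'
# class-probability vectors instead of tallying their hard argmax votes. A
# minimal sketch of the idea (illustrative tensors only, not the real models):
#
#     logits = [torch.randn(1, 41) for _ in range(3)]      # one per model
#     probs = [torch.softmax(l, dim=1) for l in logits]    # each row sums to 1
#     ensemble = sum(probs) / len(probs)                   # still sums to 1
#     predicted_class = ensemble.argmax(dim=1)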

def classify_text_interface(text):
    if not all([tokenizer, model_1, model_2, model_3]):
        return "<p>Error: Models not loaded. Please check the console.</p>"

    cleaned_text = clean_text(text)
    if not cleaned_text.strip():  # check the cleaned text, not the raw input
        return "<p>Please enter some text to analyze.</p>"

    inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True,
                       padding=True, max_length=512).to(device)

    with torch.no_grad():
        logits_1 = model_1(**inputs).logits
        logits_2 = model_2(**inputs).logits
        logits_3 = model_3(**inputs).logits

    # Soft voting: average the three per-model probability distributions.
    softmax_1 = torch.softmax(logits_1, dim=1)
    softmax_2 = torch.softmax(logits_2, dim=1)
    softmax_3 = torch.softmax(logits_3, dim=1)
    averaged_probabilities = (softmax_1 + softmax_2 + softmax_3) / 3
    probabilities = averaged_probabilities[0]

    # Find the 'human' label index dynamically rather than hard-coding 24.
    human_label_index = -1
    for k, v in label_mapping.items():
        if v.lower() == 'human':
            human_label_index = k
            break

    ai_probs = probabilities.clone()
    if human_label_index != -1:
        ai_probs[human_label_index] = 0  # exclude 'human' from the AI total
        human_prob_value = probabilities[human_label_index].item() * 100
    else:
        human_prob_value = 0  # cannot happen with the current label_mapping
        print("Warning: 'human' label not found in label_mapping.")

    ai_total_prob = ai_probs.sum().item() * 100

    # Most likely AI source, taken over the non-human classes only.
    ai_argmax_index = torch.argmax(ai_probs).item()
    ai_argmax_model = label_mapping.get(ai_argmax_index, "Unknown AI")

    # Verdict: compare the model's direct human probability against the summed
    # probability mass of the 40 AI classes (clearer than the earlier
    # human_prob = 100 - ai_total_prob formulation, though equivalent here).
    if human_prob_value > ai_total_prob:
        result_message = (
            f'<p>The text is <span class="highlight-human">'
            f'{human_prob_value:.2f}% likely Human written</span>.</p>'
        )
    else:
        result_message = (
            f'<p>The text is <span class="highlight-ai">'
            f'{ai_total_prob:.2f}% likely AI generated</span>.<br>'
            f'Most Likely AI Source: {ai_argmax_model} (with '
            f'{probabilities[ai_argmax_index].item() * 100:.2f}% confidence '
            f'among AI models)</p>'
        )
    return result_message
" ) return result_message modern_css = """ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap'); :root { --primary-bg: #F8F9FA; --app-bg: #FFFFFF; --text-primary: #2C3E50; --text-secondary: #7F8C8D; --accent-color: #1ABC9C; --accent-color-darker: #16A085; --border-color: #E0E0E0; --input-bg: #FFFFFF; --input-focus-border: var(--accent-color); --human-color: #2ECC71; /* Green */ --human-bg: rgba(46, 204, 113, 0.1); --ai-color: #E74C3C; /* Red */ --ai-bg: rgba(231, 76, 60, 0.1); --shadow-color: rgba(44, 62, 80, 0.1); --container-max-width: 800px; /* Increased width */ --border-radius-md: 8px; --border-radius-lg: 12px; } body { font-family: 'Inter', sans-serif; background: linear-gradient(135deg, #f5f7fa 0%, #eef2f7 100%); color: var(--text-primary); margin: 0; padding: 20px; display: flex; justify-content: center; align-items: flex-start; min-height: 100vh; box-sizing: border-box; overflow-y: auto; } .gradio-container { background-color: var(--app-bg); border-radius: var(--border-radius-lg); padding: clamp(25px, 5vw, 40px); box-shadow: 0 8px 25px var(--shadow-color); max-width: var(--container-max-width); width: 100%; margin: 20px auto; border: none; } .form.svelte-633qhp, .block.svelte-11xb1hd, .gradio-html .block { /* More generic selector for Gradio HTML block */ background: none !important; border: none !important; box-shadow: none !important; padding: 0 !important; } /* Title and subtitle are now handled by Markdown with inline styles, h1 here is a fallback or for other h1s */ h1 { color: var(--text-primary); font-size: clamp(24px, 5vw, 30px); font-weight: 700; text-align: center; margin-bottom: 20px; /* Adjusted default h1 margin */ letter-spacing: -0.5px; } .app-description p { color: var(--text-secondary); font-size: clamp(14px, 2.5vw, 16px); line-height: 1.7; margin-bottom: 15px; } .app-description .instruction-text { font-weight: 500; color: var(--text-primary); margin-top: 20px; text-align: center; } .features-list { list-style: none; padding-left: 0; margin: 20px 0; } .features-list li { display: flex; align-items: center; font-size: clamp(14px, 2.5vw, 16px); color: var(--text-secondary); margin-bottom: 12px; line-height: 1.6; } .features-list .icon { margin-right: 12px; font-size: 1.2em; color: var(--accent-color); } .learn-more-link, .learn-more-link b { color: var(--accent-color) !important; text-decoration: none; font-weight: 600; } .learn-more-link:hover, .learn-more-link:hover b { color: var(--accent-color-darker) !important; text-decoration: underline; } #text_input_box textarea { background-color: var(--input-bg); border: 1px solid var(--border-color); border-radius: var(--border-radius-md); font-size: clamp(15px, 2.5vw, 16px); padding: 15px; width: 100%; box-sizing: border-box; color: var(--text-primary); transition: border-color 0.3s ease, box-shadow 0.3s ease; min-height: 120px; box-shadow: 0 2px 4px rgba(0,0,0,0.05); } #text_input_box textarea::placeholder { color: #B0BEC5; } #text_input_box textarea:focus { border-color: var(--input-focus-border); box-shadow: 0 0 0 3px rgba(26, 188, 156, 0.2); outline: none; } #result_output_box { background-color: var(--input-bg); /* Ensure background for the box */ border: 1px solid var(--border-color); border-radius: var(--border-radius-md); padding: 20px; margin-top: 25px; width: 100%; box-sizing: border-box; text-align: center; font-size: clamp(16px, 3vw, 17px); /* Slightly adjusted font size for results */ box-shadow: 0 4px 8px rgba(0,0,0,0.05); min-height: 80px; /* Give it some min 
height */ display: flex; /* For centering content if needed */ flex-direction: column; justify-content: center; } #result_output_box p { /* Style paragraphs inside the result box */ margin-bottom: 8px; /* Space between lines in result */ line-height: 1.6; } #result_output_box p:last-child { margin-bottom: 0; } .highlight-human, .highlight-ai { font-weight: 600; padding: 5px 10px; /* Adjusted padding */ border-radius: var(--border-radius-md); display: inline-block; font-size: 1.05em; /* Adjusted size */ } .highlight-human { color: var(--human-color); background-color: var(--human-bg); /* border: 1px solid var(--human-color); Removed border for cleaner look */ } .highlight-ai { color: var(--ai-color); background-color: var(--ai-bg); /* border: 1px solid var(--ai-color); Removed border for cleaner look */ } .tabs > div:first-child button { background-color: transparent !important; color: var(--text-secondary) !important; border: none !important; border-bottom: 2px solid transparent !important; border-radius: 0 !important; padding: 10px 15px !important; font-weight: 500 !important; transition: color 0.3s ease, border-bottom-color 0.3s ease !important; } .tabs > div:first-child button.selected { color: var(--accent-color) !important; border-bottom-color: var(--accent-color) !important; font-weight: 600 !important; } .gr-examples { padding: 15px !important; border: 1px solid var(--border-color) !important; border-radius: var(--border-radius-md) !important; background-color: #fdfdfd !important; margin-top: 10px; /* Add some space above examples */ } .gr-sample-textbox { border: 1px solid var(--border-color) !important; border-radius: var(--border-radius-md) !important; font-size: 14px !important; } .gr-accordion > .label-wrap button { /* Style accordion label */ font-weight: 500 !important; color: var(--text-primary) !important; } .footer-text, #bottom_text { text-align: center; margin-top: 40px; font-size: clamp(13px, 2vw, 14px); color: var(--text-secondary); } #bottom_text p { margin: 0; } @media (max-width: 768px) { body { padding: 10px; align-items: flex-start; } .gradio-container { padding: 20px; margin: 10px; } h1 { font-size: 22px; } /* Adjust for custom title markdown */ .app-description p, .features-list li { font-size: 14px; } #text_input_box textarea { font-size: 15px; min-height: 100px; } #result_output_box { font-size: 15px; padding: 15px; } } """ iface = gr.Blocks(css=modern_css, theme=gr.themes.Base(font=[gr.themes.GoogleFont("Inter"), "sans-serif"])) with iface: gr.Markdown(title_md) # Using combined Markdown for title and subtitle gr.Markdown(description) text_input = gr.Textbox( label="", placeholder="Type or paste your content here...", elem_id="text_input_box", lines=7 # Adjusted lines ) result_output = gr.HTML(elem_id="result_output_box") # Only set up the change function if models are loaded if all([tokenizer, model_1, model_2, model_3]): text_input.change(classify_text_interface, inputs=text_input, outputs=result_output) else: # Display a persistent error if models couldn't load gr.HTML("Application Error: Models could not be loaded. Please check the server console for details.
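
# The snippet above ends without an entry point; the launch call appears to
# have been truncated. A standard way to serve a Gradio Blocks app (an
# assumption, not confirmed by the original source) is:
iface.launch()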