from PIL import Image
import gradio as gr
from nsfw_image_detector import NSFWDetector
import torch

# Loaded once at import time and shared by every request handled by the app.
classifier_nsfw = NSFWDetector(dtype=torch.bfloat16, device="cpu")


# Define the inference function
def classify_image(image, confidence_level):
    """Score *image* for NSFW content and classify it at *confidence_level*.

    Args:
        image: The uploaded image as delivered by ``gr.Image(type="pil")``,
            or ``None`` when the user submits without uploading anything.
        confidence_level: The dropdown selection, one of ``"low"``,
            ``"medium"`` or ``"high"``.

    Returns:
        A ``(scores_text, classification_text)`` tuple filling the two output
        textboxes: per-level probability scores, and the True/False verdict
        (score for the selected level ``>= 0.5``).

    Raises:
        gr.Error: If no image or no confidence level was provided, so Gradio
            shows a readable message instead of a failure from the detector.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not confidence_level:
        raise gr.Error("Please select a confidence level.")

    # predict_proba returns a sequence; its first entry holds this image's
    # level -> score mapping.
    proba_dict = classifier_nsfw.predict_proba(image)[0]

    # NOTE(review): the mapping's keys expose `.value` (enum-like), yet they are
    # looked up with the plain dropdown string here; this assumes the keys
    # compare equal to their string values (e.g. a str-based Enum) — confirm.
    is_nsfw_method = proba_dict[confidence_level] >= 0.5

    # Format NSFW probability scores, one "Level: score" line per level.
    nsfw_proba_str = "NSFW Probability Scores:\n" + "".join(
        f"{level.value.title()}: {score:.4f}\n" for level, score in proba_dict.items()
    )

    # Format NSFW classification
    verdict = "🔴 True" if is_nsfw_method else "🟢 False"
    is_nsfw_str = f"NSFW Classification ({confidence_level.title()}):\n{verdict}"

    return nsfw_proba_str, is_nsfw_str


# Create Gradio interface
demo = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Dropdown(
            choices=["low", "medium", "high"],
            value="medium",
            label="Low is the most restrictive, high is the least restrictive",
        ),
    ],
    outputs=[
        gr.Textbox(label="NSFW Categories Scores", lines=3),
        gr.Textbox(label="NSFW Classification", lines=2),
    ],
    title="NSFW Image Classifier",
    description="Upload an image and select a confidence level to get a prediction using the Freepik/nsfw_image_detector model.",
)

# Launch app
demo.launch()