Spaces:
Sleeping
Sleeping
from PIL import Image | |
import gradio as gr | |
from nsfw_image_detector import NSFWDetector | |
import torch | |
# Build the detector once, at module import time, so every call to
# classify_image reuses this same instance rather than constructing a new one.
classifier_nsfw = NSFWDetector(
    device="cpu",
    dtype=torch.bfloat16,
)
# Define the inference function
def classify_image(image, confidence_level, *, threshold=0.5):
    """Run the NSFW detector on one image and format the results for display.

    Args:
        image: The value of the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an uploaded image.
        confidence_level: Level name selected in the dropdown (``"low"``,
            ``"medium"`` or ``"high"``); the score for this level decides the
            True/False verdict.
        threshold: Minimum score for ``confidence_level`` at which the image
            is reported as NSFW. Defaults to 0.5, the previously hard-coded
            value. Keyword-only so Gradio's two positional inputs are unaffected.

    Returns:
        A ``(scores_text, verdict_text)`` tuple of strings, one per output
        textbox.

    Raises:
        KeyError: If ``confidence_level`` does not match any level name
            returned by the detector.
    """
    # Gradio hands us None when nothing was uploaded; passing that to the
    # detector would fail with an opaque traceback, so answer politely instead.
    if image is None:
        return "Please upload an image.", "NSFW Classification:\nNo image provided."

    # Run the single detector model. The result is indexed at [0];
    # presumably one {level: probability} mapping per input image.
    proba_dict = classifier_nsfw.predict_proba(image)[0]

    # Key the scores by plain level name. The keys are expected to be enum
    # members exposing ``.value`` (NOTE(review): confirm against
    # nsfw_image_detector); falling back to the key itself means lookup by the
    # dropdown's string works whether or not that enum mixes in ``str``.
    scores = {
        getattr(level, "value", level): score
        for level, score in proba_dict.items()
    }

    is_nsfw = scores[confidence_level] >= threshold

    # Format the probability scores: a header line, then one line per level,
    # each terminated by a newline.
    lines = ["NSFW Probability Scores:"]
    lines.extend(f"{name.title()}: {score:.4f}" for name, score in scores.items())
    nsfw_proba_str = "\n".join(lines) + "\n"

    # Format the verdict; red/green circle emoji mark True/False.
    verdict = "🔴 True" if is_nsfw else "🟢 False"
    is_nsfw_str = f"NSFW Classification ({confidence_level.title()}):\n{verdict}"
    return nsfw_proba_str, is_nsfw_str
# Assemble the Gradio UI: an image upload and a sensitivity dropdown feed
# classify_image, whose two returned strings fill the two output textboxes.
_image_input = gr.Image(type="pil", label="Upload an image")
_level_input = gr.Dropdown(
    choices=["low", "medium", "high"],
    value="medium",
    label="Low is the most restrictive, high is the least restrictive",
)
_scores_output = gr.Textbox(label="NSFW Categories Scores", lines=3)
_verdict_output = gr.Textbox(label="NSFW Classification", lines=2)

demo = gr.Interface(
    fn=classify_image,
    inputs=[_image_input, _level_input],
    outputs=[_scores_output, _verdict_output],
    title="NSFW Image Classifier",
    description=(
        "Upload an image and select a confidence level to get a prediction "
        "using the Freepik/nsfw_image_detector model."
    ),
)

# Start serving the app.
demo.launch()