File size: 1,545 Bytes
70ba646
 
45eaad4
 
70ba646
638ef29
45eaad4
70ba646
 
45eaad4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638ef29
70ba646
 
 
 
45eaad4
 
 
 
 
 
 
 
 
638ef29
 
45eaad4
70ba646
45eaad4
70ba646
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from PIL import Image
import gradio as gr
from nsfw_image_detector import NSFWDetector
import torch


# Build the detector a single time at module import; classify_image() reads
# this global, so every request reuses the same loaded instance.
classifier_nsfw = NSFWDetector(device="cpu", dtype=torch.bfloat16)

# Define the inference function
def classify_image(image, confidence_level):
    # Get predictions from both models
    result_nsfw_proba = classifier_nsfw.predict_proba(image)
    is_nsfw_method = result_nsfw_proba[0][confidence_level] >= 0.5
    
    # Format NSFW probability scores
    proba_dict = result_nsfw_proba[0]
    nsfw_proba_str = "NSFW Probability Scores:\n"
    for level, score in proba_dict.items():
        nsfw_proba_str += f"{level.value.title()}: {score:.4f}\n"
    
    # Format NSFW classification
    is_nsfw_str = f"NSFW Classification ({confidence_level.title()}):\n"
    is_nsfw_str += "🔴 True" if is_nsfw_method else "🟢 False"
    
    
    return nsfw_proba_str, is_nsfw_str

# Gradio UI: an image upload and a confidence-level dropdown feed
# classify_image, whose two returned strings fill the two output textboxes.
demo = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Dropdown(
            choices=["low", "medium", "high"],
            value="medium",
            label="Low is the most restrictive, high is the least restrictive",
        ),
    ],
    outputs=[
        gr.Textbox(label="NSFW Categories Scores", lines=3),
        gr.Textbox(label="NSFW Classification", lines=2),
    ],
    title="NSFW Image Classifier",
    description="Upload an image and select a confidence level to get a prediction using the Freepik/nsfw_image_detector model.",
)

# Start the server only when this file is executed as a script
# (e.g. ``python app.py``); merely importing the module no longer launches it.
if __name__ == "__main__":
    demo.launch()