|
from io import BytesIO
import urllib.parse
import urllib.request

import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
|
|
|
|
|
# Trained Keras classifier used by classify_pil_image below. compile=False
# skips compiling the model (no optimizer/loss is needed for predict-only use).
MODEL_PATH = "my_model.h5"
model = load_model(MODEL_PATH, compile=False)
|
|
|
|
|
def classify_pil_image(pil_img):
    """Run the loaded Keras model on one PIL image and return per-class scores.

    Args:
        pil_img: A PIL image of any mode; it is converted to RGB and resized
            to 224x224 before inference.

    Returns:
        A dict mapping "CART", "NSFW" and "SFW" to the model's scores
        (outputs 0, 1 and 2 of the first prediction row) as Python floats.

    Raises:
        IndexError: If the model produces fewer than three outputs.
    """
    labels = ("CART", "NSFW", "SFW")
    # Force 3 channels so RGBA / grayscale / palette images yield an
    # (224, 224, 3) array, the same layout as the callers' RGB images.
    resized = pil_img.convert("RGB").resize((224, 224))
    # Add a batch axis and scale pixels to [0, 1]
    # (presumably matching the training preprocessing — confirm).
    batch = np.expand_dims(image.img_to_array(resized), axis=0) / 255.0
    # verbose=0: predict() otherwise may print a progress bar for every request.
    scores = model.predict(batch, verbose=0)[0]
    return {label: float(scores[i]) for i, label in enumerate(labels)}
|
|
|
|
|
def classify_uploaded_image(file):
    """Classify an image supplied by the ``gr.Image(type="numpy")`` upload input.

    Args:
        file: The uploaded image as a NumPy array, or ``None`` when the user
            submits without choosing an image (Gradio passes ``None`` for an
            empty image input).

    Returns:
        The {"CART", "NSFW", "SFW"} score dict from ``classify_pil_image``.

    Raises:
        gr.Error: If no image was provided or conversion/classification fails.
            Gradio shows the message to the user; returning an error dict
            instead would be rejected by the ``gr.Label`` output, which expects
            numeric confidences.
    """
    if file is None:
        raise gr.Error("Upload error: please upload an image first.")
    try:
        pil_img = Image.fromarray(file).convert("RGB")
        return classify_pil_image(pil_img)
    except Exception as e:  # UI boundary: surface any failure to the user
        raise gr.Error(f"Upload error: {e}") from e
|
|
|
|
|
def classify_from_url(url):
    """Download an image from an http(s) URL and classify it.

    Args:
        url: Text pasted by the user into the URL textbox (may be empty or
            ``None``; surrounding whitespace is ignored).

    Returns:
        The {"CART", "NSFW", "SFW"} score dict from ``classify_pil_image``.

    Raises:
        gr.Error: If the URL is not http(s), the download fails or times out,
            the payload exceeds 20 MB, or the bytes are not a readable image.
            Gradio shows the message to the user; returning an error dict
            instead would be rejected by the ``gr.Label`` output.
    """
    max_bytes = 20 * 1024 * 1024
    url = (url or "").strip()
    # urlopen also handles file:// and ftp:// URLs, which would let a visitor
    # make the server read its own files; accept web URLs only.
    # NOTE(review): http(s) URLs pointing at internal hosts are not blocked.
    if urllib.parse.urlparse(url).scheme not in ("http", "https"):
        raise gr.Error("URL error: please enter an http:// or https:// image URL.")
    # Some hosts reject urllib's default "Python-urllib" User-Agent.
    request = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
    try:
        # The with-block closes the connection; the timeout stops a stalled
        # host from hanging the request handler indefinitely.
        with urllib.request.urlopen(request, timeout=15) as response:
            # Read one byte past the cap so oversized payloads are detectable
            # without loading them entirely into memory.
            data = response.read(max_bytes + 1)
        if len(data) > max_bytes:
            raise ValueError("image is larger than 20 MB")
        img = Image.open(BytesIO(data)).convert("RGB")
        return classify_pil_image(img)
    except Exception as e:  # UI boundary: surface any failure to the user
        raise gr.Error(f"URL error: {e}") from e
|
|
|
|
|
# Sample images example1.jpg … example8.jpg; each inner list holds the value
# for the upload interface's single input component.
NUM_EXAMPLES = 8
examples = [[f"example{index}.jpg"] for index in range(1, NUM_EXAMPLES + 1)]
|
|
|
|
|
# Tab 1: classify an image uploaded (or dragged in) by the user. gr.Image hands
# a NumPy array to classify_uploaded_image; gr.Label displays the top 3 scores.
_upload_input = gr.Image(label="Upload or drag an image", type="numpy")
_upload_output = gr.Label(label="Prediction", num_top_classes=3)
upload_interface = gr.Interface(
    fn=classify_uploaded_image,
    inputs=_upload_input,
    outputs=_upload_output,
    title="Simple NSFW/SFW/CART Classifier",
    examples=examples,
    cache_examples=False,
    allow_flagging="never",
)
|
|
|
|
|
# Tab 2: classify an image fetched from a pasted URL via classify_from_url;
# gr.Label displays the top 3 scores.
_url_input = gr.Textbox(label="Paste Image URL")
_url_output = gr.Label(label="Prediction", num_top_classes=3)
url_interface = gr.Interface(
    fn=classify_from_url,
    inputs=_url_input,
    outputs=_url_output,
    cache_examples=False,
    allow_flagging="never",
)
|
|
|
|
|
# Combine both interfaces into one tabbed app. Binding it to ``demo`` matches
# the variable name Gradio's reload mode (``gradio app.py``) looks for by
# default, and launching only under __main__ keeps a plain import of this
# module from starting the (blocking) web server.
demo = gr.TabbedInterface(
    [upload_interface, url_interface],
    tab_names=["📤 Upload Image", "🔗 Image URL"],
)

if __name__ == "__main__":
    demo.launch()
|
|