# Author: NeelTA
# Commit 22ecc00: "url feature added"
import urllib.parse
import urllib.request
from io import BytesIO

import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# Keras model file, resolved relative to the current working directory.
MODEL_PATH = "my_model.h5"
# compile=False: the app only ever calls model.predict, so the model does not
# need to be compiled after loading.
model = load_model(MODEL_PATH, compile=False)
# Common prediction function.
# Class names in the order of the model's output units
# (index 0 -> "CART", 1 -> "NSFW", 2 -> "SFW").
CLASS_NAMES = ("CART", "NSFW", "SFW")


def classify_pil_image(pil_img):
    """Score a PIL image with the loaded Keras model.

    The image is resized to 224x224, converted to a float array, wrapped in a
    batch of one and divided by 255 before being passed to ``model.predict``.

    Args:
        pil_img: A ``PIL.Image.Image``. Both callers in this file convert it
            to RGB before calling.

    Returns:
        dict mapping each name in ``CLASS_NAMES`` to its score as a Python
        ``float`` -- the ``{label: confidence}`` shape that ``gr.Label`` shows.
    """
    pixels = image.img_to_array(pil_img.resize((224, 224)))
    batch = np.expand_dims(pixels, axis=0) / 255.0
    # verbose=0 stops Keras from writing progress output to the server log on
    # every request; the prediction itself is unchanged.
    scores = model.predict(batch, verbose=0)[0]
    return {name: float(score) for name, score in zip(CLASS_NAMES, scores)}
# From file input (or a clicked example).
def classify_uploaded_image(file):
    """Classify the image coming from the upload tab.

    Args:
        file: Value of the ``gr.Image(type="numpy")`` input -- a NumPy pixel
            array, or ``None`` when nothing was uploaded.

    Returns:
        On success, the ``{class_name: score}`` dict from
        ``classify_pil_image``. On failure, a plain message string:
        ``gr.Label`` renders a string as its heading, whereas a dict whose
        value is not a number (the old ``{"error": "..."}``) is not a valid
        confidence mapping for the Label output.
    """
    if file is None:
        return "Upload error: no image was provided"
    try:
        pil_img = Image.fromarray(file).convert("RGB")
        return classify_pil_image(pil_img)
    except Exception as e:  # UI boundary: report any failure to the user
        return f"Upload error: {e}"
# From URL input.
# Limits applied when fetching a user-supplied image URL.
_URL_TIMEOUT_SECONDS = 10
_MAX_IMAGE_BYTES = 20 * 1024 * 1024
_ALLOWED_URL_SCHEMES = ("http", "https")
# Some image hosts refuse urllib's default "Python-urllib/x.y" User-Agent.
_URL_HEADERS = {"User-Agent": "Mozilla/5.0 (compatible; NSFW-SFW-CART-classifier)"}


def classify_from_url(url):
    """Download an image from a user-supplied URL and classify it.

    Only http/https URLs are accepted. Without that check, ``urlopen`` would
    also serve ``file://`` URLs and read files from the server. The download
    has a timeout and a size cap. The response is closed before decoding.

    NOTE(review): http(s) URLs pointing at internal/private hosts are still
    fetched; add host filtering if the app is publicly exposed.

    Args:
        url: Text from the URL textbox; may be empty, ``None``, or padded
            with whitespace.

    Returns:
        On success, the ``{class_name: score}`` dict from
        ``classify_pil_image``. On failure, a plain message string starting
        with ``"URL error:"`` (``gr.Label`` displays a string as its heading).
    """
    try:
        url = (url or "").strip()
        if not url:
            raise ValueError("no URL was provided")
        scheme = urllib.parse.urlparse(url).scheme.lower()
        if scheme not in _ALLOWED_URL_SCHEMES:
            raise ValueError(f"unsupported URL scheme {scheme!r}; use http or https")
        request = urllib.request.Request(url, headers=_URL_HEADERS)
        with urllib.request.urlopen(request, timeout=_URL_TIMEOUT_SECONDS) as response:
            # Read one byte past the cap so an oversized body is detectable.
            data = response.read(_MAX_IMAGE_BYTES + 1)
        if len(data) > _MAX_IMAGE_BYTES:
            raise ValueError(
                f"image is larger than {_MAX_IMAGE_BYTES // (1024 * 1024)} MB"
            )
        img = Image.open(BytesIO(data)).convert("RGB")
        return classify_pil_image(img)
    except Exception as e:  # UI boundary: report any failure to the user
        return f"URL error: {e}"
# Sample images offered under the upload tab. Gradio expects one inner list
# per example (one entry per input component), so each file name is wrapped.
_EXAMPLE_COUNT = 8
examples = [[f"example{n}.jpg"] for n in range(1, _EXAMPLE_COUNT + 1)]
# Upload tab: numpy image input plus clickable examples, results shown as the
# top-3 class scores. Examples are not pre-computed and flagging is disabled.
upload_interface = gr.Interface(
    classify_uploaded_image,
    gr.Image(label="Upload or drag an image", type="numpy"),
    gr.Label(label="Prediction", num_top_classes=3),
    title="Simple NSFW/SFW/CART Classifier",
    examples=examples,
    cache_examples=False,
    allow_flagging="never",
)
# URL tab: a single textbox whose URL is fetched server-side by
# classify_from_url; output is the same top-3 score display as the upload tab.
url_interface = gr.Interface(
    classify_from_url,
    gr.Textbox(label="Paste Image URL"),
    gr.Label(label="Prediction", num_top_classes=3),
    cache_examples=False,
    allow_flagging="never",
)
# Combine both interfaces into one tabbed app and start the server.
# Tab icons are written as escapes -- U+1F4E4 OUTBOX TRAY and U+1F310 GLOBE
# WITH MERIDIANS -- because the previous literals had been corrupted into
# mojibake (emoji UTF-8 bytes decoded with a legacy codepage).
# The app is bound to `demo`, the name Gradio's reload mode looks for.
demo = gr.TabbedInterface(
    [upload_interface, url_interface],
    tab_names=["\U0001F4E4 Upload Image", "\U0001F310 Image URL"],
)
demo.launch()