"""Gradio demo that classifies images as CART, NSFW, or SFW.

Images can be supplied either by upload (first tab) or by URL (second tab);
both paths share the same preprocessing and prediction code.
"""

import logging
import urllib.parse
import urllib.request
from io import BytesIO

import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image

logger = logging.getLogger(__name__)

# Output labels, in the order of the model's output units.
# NOTE(review): order taken from the original hard-coded indices; confirm it
# matches the label order used when the model was trained.
CLASS_NAMES = ("CART", "NSFW", "SFW")

# (width, height) every image is resized to before prediction.
INPUT_SIZE = (224, 224)

# Guards for the URL tab: only fetch web URLs, and bound both the time spent
# waiting and the number of bytes read from the remote host.
ALLOWED_URL_SCHEMES = ("http", "https")
URL_TIMEOUT_SECONDS = 10
MAX_DOWNLOAD_BYTES = 20 * 1024 * 1024  # 20 MiB
# Many image hosts reject urllib's default "Python-urllib/x.y" User-Agent.
URL_HEADERS = {"User-Agent": "Mozilla/5.0 (compatible; image-classifier-demo)"}

# Load the model once at startup; compile=False because it is only used for
# inference (predict), never trained here.
model = load_model("my_model.h5", compile=False)


def classify_pil_image(pil_img):
    """Run the classifier on a PIL image and return per-class scores.

    The image is resized to INPUT_SIZE, scaled from [0, 255] to [0, 1], and
    passed through the model as a batch of one.

    Args:
        pil_img: An RGB ``PIL.Image.Image``.

    Returns:
        dict mapping each name in CLASS_NAMES to its score as a plain float,
        the format ``gr.Label`` expects.
    """
    resized = pil_img.resize(INPUT_SIZE)
    batch = np.expand_dims(image.img_to_array(resized), axis=0) / 255.0
    # verbose=0 keeps Keras from printing a progress bar for every request.
    scores = model.predict(batch, verbose=0)[0]
    return {name: float(score) for name, score in zip(CLASS_NAMES, scores)}


def classify_uploaded_image(file):
    """Classify an image coming from the upload tab (or a clicked example).

    Args:
        file: The image as a numpy array, as produced by
            ``gr.Image(type="numpy")``; None when the user submits without
            choosing an image.

    Returns:
        The per-class score dict from ``classify_pil_image``.

    Raises:
        gr.Error: if no image was given or it could not be processed. Gradio
            shows the message to the user instead of feeding it to the Label.
    """
    if file is None:
        raise gr.Error("Please upload an image first.")
    try:
        pil_img = Image.fromarray(file).convert("RGB")
        return classify_pil_image(pil_img)
    except Exception as exc:  # UI boundary: log, then report to the user.
        logger.exception("Failed to classify uploaded image")
        raise gr.Error(f"Upload error: {exc}") from exc


def _fetch_url_bytes(url):
    """Download *url* with a timeout and size cap; return the raw bytes.

    Raises:
        ValueError: if the response body exceeds MAX_DOWNLOAD_BYTES.
        urllib.error.URLError / OSError: on network or HTTP failures.
    """
    request = urllib.request.Request(url, headers=URL_HEADERS)
    # The context manager closes the connection even if read() fails.
    with urllib.request.urlopen(request, timeout=URL_TIMEOUT_SECONDS) as response:
        # Read one byte past the cap so an oversized body can be detected
        # without pulling the whole thing into memory.
        data = response.read(MAX_DOWNLOAD_BYTES + 1)
    if len(data) > MAX_DOWNLOAD_BYTES:
        raise ValueError(f"image is larger than {MAX_DOWNLOAD_BYTES} bytes")
    return data


def classify_from_url(url):
    """Download an image from an http(s) URL and classify it.

    Args:
        url: The URL typed into the textbox; surrounding whitespace is ignored.

    Returns:
        The per-class score dict from ``classify_pil_image``.

    Raises:
        gr.Error: if the URL is empty, uses a scheme other than http/https
            (which blocks e.g. ``file://`` reads of local files), or the image
            cannot be downloaded or decoded.
    """
    url = (url or "").strip()
    if not url:
        raise gr.Error("Please paste an image URL.")
    if urllib.parse.urlparse(url).scheme.lower() not in ALLOWED_URL_SCHEMES:
        raise gr.Error("URL error: only http:// and https:// URLs are supported.")
    # NOTE(review): internal/private hosts are still reachable (SSRF); add a
    # host allow/deny list if this app is exposed publicly.
    try:
        data = _fetch_url_bytes(url)
        img = Image.open(BytesIO(data)).convert("RGB")
        return classify_pil_image(img)
    except Exception as exc:  # UI boundary: log, then report to the user.
        logger.exception("Failed to classify image from URL %s", url)
        raise gr.Error(f"URL error: {exc}") from exc


# Example images for the upload tab: example1.jpg ... example8.jpg
# (relative file paths).
examples = [[f"example{i}.jpg"] for i in range(1, 9)]

# Upload tab (classic layout with examples).
upload_interface = gr.Interface(
    fn=classify_uploaded_image,
    inputs=gr.Image(type="numpy", label="Upload or drag an image"),
    outputs=gr.Label(num_top_classes=3, label="Prediction"),
    examples=examples,
    title="Simple NSFW/SFW/CART Classifier",
    allow_flagging="never",
    cache_examples=False,
)

# URL tab (simple textbox interface).
url_interface = gr.Interface(
    fn=classify_from_url,
    inputs=gr.Textbox(label="Paste Image URL"),
    outputs=gr.Label(num_top_classes=3, label="Prediction"),
    allow_flagging="never",
    cache_examples=False,
)

# Tabs wrapper combining both interfaces; exposed as `demo` so Gradio tooling
# (e.g. reload mode) can find it.
demo = gr.TabbedInterface(
    [upload_interface, url_interface],
    tab_names=["📤 Upload Image", "🌐 Image URL"],
)

if __name__ == "__main__":
    demo.launch()