"""Gradio app that tags images fetched from URLs with a Hugging Face image classifier.

For each comma-separated URL, the image is downloaded and classified. The single
highest-scoring tag at or above the chosen threshold is reported.
"""

import logging
import os
from io import BytesIO

import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoConfig, AutoModelForImageClassification

logger = logging.getLogger(__name__)

token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
model_id = "thelabel/240903-image-tagging"

config = AutoConfig.from_pretrained(model_id, token=token)
model = AutoModelForImageClassification.from_pretrained(model_id, token=token)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# NOTE(review): assumes the checkpoint was trained with 224x224 inputs normalized
# to [-1, 1]; confirm against the model's image-processor config.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def load_image_from_url(url):
    """Download *url* and decode it as an RGB PIL image.

    Returns ``None`` on any failure (network error, HTTP error status,
    undecodable data, non-string input). The failure is logged so it is
    not lost silently.
    """
    # NOTE(review): this fetches arbitrary user-supplied URLs server-side with no
    # size cap; consider restricting hosts / limiting response size if exposed.
    try:
        response = requests.get(url.strip(), timeout=10)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            # convert() returns a new, fully loaded image, so closing the source is safe.
            return img.convert("RGB")
    except Exception:  # best-effort boundary: caller reports a generic error
        logger.warning("Failed to load image from %r", url, exc_info=True)
        return None


def _label_for(index):
    """Return the human-readable label for model output *index*."""
    # The checkpoint's config may carry a custom ``idx_to_label`` mapping keyed by
    # strings; otherwise use the standard transformers ``id2label`` (int keys).
    custom = getattr(config, "idx_to_label", None)
    if custom:
        return custom[str(index)]
    return config.id2label[index]


@spaces.GPU
def predict_tags(image_url, threshold=0.5):
    """Classify the image at *image_url* and return tags scoring >= *threshold*.

    Returns:
        ``(results, None)`` on success, where ``results`` is a list of
        ``(label, probability)`` sorted by descending probability (possibly empty).
        ``(None, error_message)`` when the image could not be loaded.
    """
    image = load_image_from_url(image_url)
    if image is None:
        return None, "Could not load image."

    image_tensor = image_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(image_tensor).logits

    # Take batch row 0 explicitly rather than squeeze(): squeeze() would turn a
    # single-label output into a 0-d tensor and break iteration. Transfer to CPU
    # once so the per-label comparisons below don't each force a device sync.
    probs = torch.sigmoid(logits)[0].float().cpu().tolist()

    results = [
        (_label_for(i), score)
        for i, score in enumerate(probs)
        if score >= threshold
    ]
    results.sort(key=lambda x: x[1], reverse=True)
    return results, None


def gradio_predict(urls, threshold):
    """Tag each comma-separated URL in *urls* and return a printable summary.

    Each entry reports either the top tag and its score (rounded to 4 places)
    or an error message. The list is returned via ``str()`` for Textbox display.
    """
    # An empty Textbox may arrive as None; treat it as "no URLs".
    url_list = [u.strip() for u in (urls or "").split(",") if u.strip()]
    output = []
    for url in url_list:
        tags, error = predict_tags(url, threshold)
        if error or not tags:
            output.append({
                "image_url": url,
                "error": error or "No tags above threshold.",
            })
            continue
        top_tag, top_score = tags[0]
        output.append({
            "image_url": url,
            "tag_name": top_tag,
            "tag_score": round(top_score, 4),
        })
    return str(output)  # Return as string for textbox display


demo = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Textbox(label="Image URL(s) (comma-separated)"),
        gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold"),
    ],
    outputs=gr.Textbox(label="Tags"),
    title="Batch Image Tagging with ViT",
    description="Paste one or more image URLs separated by commas to get predicted tags using thelabel/240903-image-tagging model.",
)

if __name__ == "__main__":
    demo.launch()