"""Vision Chatbot+ — Streamlit app for captioning, object detection, and VQA.

Pipeline: upload an image -> caption it -> render an emoji-substituted
caption -> detect objects (drawing labelled boxes) -> offer the annotated
image for download -> answer free-form questions about the image.
"""

import io
import string

import streamlit as st
from PIL import Image, ImageDraw
from transformers import pipeline

# -------------------------
# Streamlit page configuration
# -------------------------
# NOTE: st.set_page_config must be the FIRST Streamlit command executed.
# The original called load_pipelines() (whose @st.cache_resource wrapper
# renders a spinner element) before it, which raises
# StreamlitAPIException on a cold cache.
st.set_page_config(page_title="Vision Chatbot+", page_icon="🖼️")


# -------------------------
# Load models once (cached)
# -------------------------
@st.cache_resource
def load_pipelines():
    """Build and cache the three Hugging Face pipelines used by the app.

    Returns:
        tuple: (captioner, detector, vqa) — image-to-text,
        object-detection, and visual-question-answering pipelines.
    """
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    detector = pipeline("object-detection", model="facebook/detr-resnet-50")
    vqa = pipeline("visual-question-answering", model="Salesforce/blip-vqa-base")
    return captioner, detector, vqa


# Word -> emoji substitutions used by "Emoji Mode".
EMOJI_MAP = {
    "dog": "🐶", "cat": "🐱", "ball": "⚽", "frisbee": "🥏",
    "man": "👨", "woman": "👩", "child": "🧒", "car": "🚗",
    "bicycle": "🚲", "horse": "🐎", "bird": "🐦", "food": "🍔",
    "drink": "🥤", "tree": "🌳",
}


def emojify(caption: str) -> str:
    """Return *caption* with known words replaced by emojis.

    Punctuation attached to a word (e.g. "dog.") is stripped before the
    lookup — the original `emoji_map.get(word.lower(), word)` silently
    failed to match such words. Unknown words pass through unchanged.
    """
    parts = []
    for word in caption.split():
        key = word.strip(string.punctuation).lower()
        parts.append(EMOJI_MAP.get(key, word))
    return " ".join(parts)


captioner, detector, vqa = load_pipelines()

# -------------------------
# Streamlit UI
# -------------------------
st.title("🖼️ Vision Chatbot+")
st.write("Upload an image to get **captions, emojis, object detection (with boxes!), and Q&A**")

uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # Force RGB so palette/alpha uploads behave uniformly downstream.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # ---- Captioning ----
    with st.spinner("Generating caption..."):
        caption = captioner(image)[0]["generated_text"]
    st.subheader("📝 Caption")
    st.success(caption)

    # ---- Emoji Mode ----
    st.subheader("😎 Emoji Mode")
    st.info(emojify(caption))

    # ---- Object Detection ----
    with st.spinner("Detecting objects..."):
        detections = detector(image)

    st.subheader("🔍 Objects Detected")

    # Annotate a copy so the original upload stays untouched.
    draw_img = image.copy()
    draw = ImageDraw.Draw(draw_img)
    for obj in detections:
        box = obj["box"]
        label = f"{obj['label']} ({obj['score']:.2f})"
        draw.rectangle(
            [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
            outline="red",
            width=3,
        )
        # Label just above the box; clamp to 0 so it stays on-canvas
        # for boxes touching the top edge (original could go negative).
        draw.text((box["xmin"], max(box["ymin"] - 10, 0)), label, fill="red")
        st.write(f"- {label}")

    st.image(draw_img, caption="Objects with bounding boxes", use_container_width=True)

    # ---- Download Button ----
    buf = io.BytesIO()
    draw_img.save(buf, format="PNG")
    st.download_button(
        label="📥 Download Annotated Image",
        data=buf.getvalue(),
        file_name="annotated_image.png",
        mime="image/png",
    )

    # ---- Visual Question Answering ----
    st.subheader("❓ Ask a Question About the Image")
    user_q = st.text_input("Type your question (e.g., 'What is the dog doing?')")
    if user_q:
        with st.spinner("Thinking..."):
            # Keyword-argument form is the documented call signature of
            # VisualQuestionAnsweringPipeline (image=..., question=...).
            answer = vqa(image=image, question=user_q)
        st.success(f"**Answer:** {answer[0]['answer']}")