# vision_chatbot/app.py
import streamlit as st
from transformers import pipeline
from PIL import Image, ImageDraw
import io
# -------------------------
# Load models once (cached)
# -------------------------
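# st.cache_resource keeps the returned pipeline objects in memory across reruns,
# so the three models are downloaded and loaded only once per app process.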
@st.cache_resource
def load_pipelines():
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    detector = pipeline("object-detection", model="facebook/detr-resnet-50")
    vqa = pipeline("visual-question-answering", model="Salesforce/blip-vqa-base")
    return captioner, detector, vqa


captioner, detector, vqa = load_pipelines()

# -------------------------
# Streamlit UI
# -------------------------
st.set_page_config(page_title="Vision Chatbot+", page_icon="🖼️")
st.title("🖼️ Vision Chatbot+")
st.write("Upload an image to get **captions, emojis, object detection (with boxes!), and Q&A**")

uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])

if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # ---- Captioning ----
    with st.spinner("Generating caption..."):
        caption = captioner(image)[0]["generated_text"]
    st.subheader("📝 Caption")
    st.success(caption)
    # ---- Emoji Mode ----
    # Swap caption words that have an emoji mapping; leave all other words as-is.
    emoji_map = {
        "dog": "🐶", "cat": "🐱", "ball": "⚽", "frisbee": "🥏",
        "man": "👨", "woman": "👩", "child": "🧒",
        "car": "🚗", "bicycle": "🚲", "horse": "🐎", "bird": "🐦",
        "food": "🍔", "drink": "🥤", "tree": "🌳"
    }
    emoji_caption = " ".join(
        emoji_map.get(word.lower(), word) for word in caption.split()
    )
    st.subheader("😎 Emoji Mode")
    st.info(emoji_caption)
    # ---- Object Detection ----
    with st.spinner("Detecting objects..."):
        detections = detector(image)
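    # Each detection is a dict with "label", "score", and a "box" of absolute
    # pixel coordinates ("xmin", "ymin", "xmax", "ymax").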
    st.subheader("🔍 Objects Detected")

    # Draw bounding boxes
    draw_img = image.copy()
    draw = ImageDraw.Draw(draw_img)
    for obj in detections:
        box = obj["box"]
        label = f"{obj['label']} ({obj['score']:.2f})"
        # Draw rectangle
        draw.rectangle(
            [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
            outline="red", width=3
        )
        # Add label above box
        draw.text((box["xmin"], box["ymin"] - 10), label, fill="red")
        st.write(f"- {label}")

    st.image(draw_img, caption="Objects with bounding boxes", use_container_width=True)
    # ---- Download Button ----
    buf = io.BytesIO()
    draw_img.save(buf, format="PNG")
    byte_im = buf.getvalue()
    st.download_button(
        label="📥 Download Annotated Image",
        data=byte_im,
        file_name="annotated_image.png",
        mime="image/png"
    )
    # ---- Visual Question Answering ----
    st.subheader("❓ Ask a Question About the Image")
    user_q = st.text_input("Type your question (e.g., 'What is the dog doing?')")
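    # The VQA pipeline accepts a dict with "image" and "question" keys and
    # returns a list of answer dicts; only the top answer is shown below.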
    if user_q:
        with st.spinner("Thinking..."):
            answer = vqa({"question": user_q, "image": image})
        st.success(f"**Answer:** {answer[0]['answer']}")