Spaces:
Sleeping
Sleeping
import hashlib
import json
import os
import time  # To simulate character-by-character display
from io import BytesIO

import requests
import streamlit as st
import wikipedia
from groq import APIError, Groq
from PIL import ExifTags, Image, ImageOps, UnidentifiedImageError

from BharatCaptioner import identify_landmark
# Initialize Groq API client.
# The API key is read from the environment (e.g. a hosting-platform secret)
# instead of being hard-coded: a key committed to source control is public and
# must be revoked/rotated.
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    st.error(
        "Error: GROQ_API_KEY is not set. Add it as an environment variable or secret."
    )
    st.stop()  # Nothing below works without the chatbot client.
client = Groq(api_key=groq_api_key)
# Page header.
st.title("BharatCaptioner with Conversational Chatbot")
st.write(
    "A tool to identify/describe Indian Landmarks in Indic Languages and chat about the image."
)

# Sidebar: author credit followed by links to the code and to the author.
st.sidebar.title("Developed by Harsh Sanga")
for _sidebar_markdown in (
    "**For the Code**: [GitHub Repo](https://github.com/h-sanga)",
    "**Connect with me**: [LinkedIn](https://www.linkedin.com/in/harsh-sanga-2375a9272/)",
):
    st.sidebar.write(_sidebar_markdown)

# Two ways to supply an image: a local upload or a URL (the upload is checked
# first when both are present).
uploaded_file = st.file_uploader(
    "Choose an image...",
    type=["jpg", "jpeg", "png"],
)
url = st.text_input("Or enter a valid image URL...")
# Seed per-session state that must survive Streamlit reruns. Each key is only
# written the first time it is missing, so existing values are preserved.
for _state_key, _make_default in (
    ("image_hash", lambda: None),
    ("chat_history", list),
    ("chatbot_started", lambda: False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Module-level placeholders for this run's results.
image = error_message = landmark = summary = caption = None
def correct_image_orientation(img):
    """Return *img* turned upright according to its EXIF Orientation tag.

    Cameras often store pixels sideways and record the intended orientation
    in EXIF metadata. Without this step, such photos would be displayed (and
    passed to the landmark model) rotated.

    ``ImageOps.exif_transpose`` handles all eight EXIF orientations, including
    the mirrored ones (2, 4, 5, 7) that a rotation-only approach misses. It
    reads metadata through the public ``getexif`` API instead of the private,
    JPEG-only ``_getexif``.

    EXIF handling is best-effort: if the metadata cannot be interpreted, the
    input image is returned unchanged.
    """
    try:
        return ImageOps.exif_transpose(img)
    except (AttributeError, KeyError, IndexError, ValueError):
        # Malformed EXIF must never block the rest of the pipeline.
        return img
def get_image_hash(image):
    """Return a hex MD5 digest identifying *image*'s content.

    Used only as a change-detection key (not for security): when it differs
    from the stored hash, the chat is reset for the new image.

    Raw pixel bytes alone are ambiguous: the same buffer can describe, e.g.,
    a 2x2 or a 4x1 image, or be interpreted under different modes ("L" vs
    "P"). Mode and size are therefore mixed into the digest so such images
    don't collide.
    """
    digest = hashlib.md5()
    digest.update(f"{image.mode}:{image.size}".encode("utf-8"))
    digest.update(image.tobytes())
    return digest.hexdigest()
def reset_chat_if_new_image(timeout=10):
    """Load the image the user supplied and reset the chat if it is a new image.

    The uploaded file is checked before the URL. When the loaded image's hash
    differs from ``st.session_state["image_hash"]``, the stored hash is
    replaced and the chat history and ``chatbot_started`` flag are cleared, so
    the conversation starts over for the new image.

    Args:
        timeout: Seconds to wait for the URL download before giving up
            (``requests`` waits indefinitely when no timeout is given).

    Returns:
        The orientation-corrected PIL image, or ``None`` when no input was
        given or the image could not be loaded (an error is displayed in that
        case).
    """
    global image, error_message
    image = None
    new_image_hash = None

    if uploaded_file or url:
        try:
            if uploaded_file:
                source = uploaded_file
            else:
                # NOTE(review): this fetches an arbitrary user-supplied URL from
                # the server; consider restricting hosts if that matters here.
                response = requests.get(url, timeout=timeout)
                response.raise_for_status()
                source = BytesIO(response.content)
            # Image.open is lazy; pixel data is decoded (and may fail) during
            # the orientation fix / hashing, so both stay inside the try.
            image = correct_image_orientation(Image.open(source))
            new_image_hash = get_image_hash(image)
        except (
            requests.exceptions.RequestException,  # download failed
            UnidentifiedImageError,  # not a recognised image format
            Image.DecompressionBombError,  # absurdly large image
            OSError,  # truncated / corrupt image data
        ):
            image = None
            new_image_hash = None
            error_message = (
                "Error: The uploaded file could not be read as an image."
                if uploaded_file
                else "Error: The provided URL is invalid or the image could not be loaded."
            )
            st.error(error_message)

    # If the image is new, reset the chat and session state.
    if new_image_hash and new_image_hash != st.session_state["image_hash"]:
        st.session_state["image_hash"] = new_image_hash
        st.session_state["chat_history"] = []
        st.session_state["chatbot_started"] = False  # Reset chatbot status
    return image
# Load whichever image input the user supplied; this also clears the chat when
# the image differs from the one seen on the previous run.
image = reset_chat_if_new_image()

# Everything below only runs once an image has been loaded successfully.
if image is not None:
    # Work on a copy so identification/resizing never touch the loaded image.
    original_image = image.copy()

    # Streamlit reruns this whole script on every interaction (including each
    # chat message), so cache the landmark identification and the Wikipedia
    # lookup per image instead of redoing them on every rerun.
    cached = st.session_state.get("landmark_info")
    if cached is not None and cached[0] == st.session_state["image_hash"]:
        _, landmark, summary = cached
    else:
        landmark, _prob = identify_landmark(original_image)
        try:
            summary = wikipedia.summary(landmark, sentences=3)  # Shortened summary
        except (
            wikipedia.exceptions.WikipediaException,  # ambiguous / missing page
            requests.exceptions.RequestException,  # network failure
        ):
            summary = ""
            st.warning(f"Could not fetch a Wikipedia summary for {landmark}.")
        st.session_state["landmark_info"] = (
            st.session_state["image_hash"],
            landmark,
            summary,
        )
    st.write(f"**Landmark Identified:** {landmark}")

    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; confirm the pinned version before switching.
    # Small version of the image in the sidebar.
    with st.sidebar:
        small_image = original_image.resize((128, 128))  # Resize for display
        st.image(small_image, caption=f"Landmark: {landmark}", use_column_width=True)

    # Full image shown above the conversation.
    st.image(original_image, caption=f"Image of {landmark}", use_column_width=True)

    # Chatbot functionality
    st.write("### Chat with the Chatbot about the Image")
    caption = f"The landmark in the image is {landmark}. {summary}"

    # Seed the conversation with an introduction once per image.
    if not st.session_state["chatbot_started"]:
        chatbot_intro = f"Hello! I see the image is of **{landmark}**. {summary} **Would you like to know more** about this landmark?"
        st.session_state["chat_history"].append(
            {"role": "assistant", "content": chatbot_intro}
        )
        st.session_state["chatbot_started"] = True

    # Replay the stored conversation (the page is redrawn on every rerun).
    for message in st.session_state["chat_history"]:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # User input
    user_prompt = st.chat_input("Ask the Chatbot about the image...")
    if user_prompt:
        st.session_state["chat_history"].append({"role": "user", "content": user_prompt})
        st.chat_message("user").markdown(user_prompt)

        # The system prompt grounds the model in the identified landmark; the
        # whole history is sent so the model keeps conversational context.
        messages = [
            {
                "role": "system",
                "content": "You are a helpful image conversational assistant, specialized in explaining about the monuments/landmarks of india. Give answer in points and in detail but dont hallucinate."
                + f" The caption of the image is: {caption}",
            },
            *st.session_state["chat_history"],
        ]
        try:
            response = client.chat.completions.create(
                model="llama-3.1-8b-instant", messages=messages
            )
        except APIError:
            # The user's message stays in the history so it can be re-asked.
            st.error("Error: The chatbot is currently unavailable. Please try again.")
        else:
            assistant_response = response.choices[0].message.content or ""
            # Simulate streaming by revealing the reply one character at a time.
            with st.chat_message("assistant"):
                response_container = st.empty()  # Placeholder for response
                for end in range(1, len(assistant_response) + 1):
                    time.sleep(0.005)  # Adjust speed of character display
                    response_container.markdown(assistant_response[:end])
            # Store the full reply only after it has been displayed.
            st.session_state["chat_history"].append(
                {"role": "assistant", "content": assistant_response}
            )