import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
from groq import Groq
from config import GROQ_API_KEY, MODEL_NAME
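
# NOTE: config.py is expected to define GROQ_API_KEY and MODEL_NAME.
# A minimal sketch (placeholder values, not shipped with this file):
#   GROQ_API_KEY = "gsk_..."           # your Groq API key
#   MODEL_NAME = "llama3-8b-8192"      # any chat-completion model available via the Groq API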

# === Load BLIP image captioning model ===
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# === Initialize Groq client ===
client = Groq(api_key=GROQ_API_KEY)

# === Function to describe image using BLIP ===
def describe_image(image: Image.Image):
    """Generate a short caption for the uploaded image with BLIP."""
    if image is None:
        return "Please upload an image first."
    try:
        inputs = processor(images=image, return_tensors="pt")
        out = model.generate(**inputs)
        description = processor.decode(out[0], skip_special_tokens=True)
        return description
    except Exception as e:
        return f"Error describing the image: {e}"

# === Generate response from Groq ===
def generate_response(user_input):
    try:
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=MODEL_NAME,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error from Groq: {e}"

# === Chatbot logic ===
def chat(user_input, chat_history, image):
    try:
        ai_reply = generate_response(user_input)

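        # Note: the BLIP caption is appended to the text reply; the Groq model itself never sees the image.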
        if image is not None:
            image_description = describe_image(image)
            ai_reply += f"\n\n[Image Description]: {image_description}"

        chat_history.append(("User", user_input))
        chat_history.append(("AI", ai_reply))

        formatted = "\n".join([f"{role}: {msg}" for role, msg in chat_history])
        return formatted, chat_history
    except Exception as e:
        return f"Error: {e}", chat_history

# === Gradio Interface ===
with gr.Blocks(css="""
    .gradio-container {
        font-family: 'Segoe UI', sans-serif;
        background-color: #f5f5f5;
        padding: 20px;
    }
    #chatbox {
        height: 300px;
        overflow-y: auto;
        background-color: #ffffff;
        border: 1px solid #ccc;
        border-radius: 10px;
        padding: 15px;
        font-size: 14px;
        line-height: 1.5;
    }
""") as demo:

    gr.Markdown("## 🤖 **Groq-powered Chatbot with Image Understanding (BLIP)**")
    gr.Markdown("Chat with the bot or upload an image to get a caption.")

    with gr.Column():
        user_input = gr.Textbox(label="Your Message", placeholder="Ask something...", lines=2)
        submit_button = gr.Button("Send")
        clear_button = gr.Button("Clear Chat")

        chatbot_output = gr.Textbox(label="Chat History", lines=12, interactive=False, elem_id="chatbox")

        image_input = gr.Image(label="Upload an Image", type="pil", elem_id="image-upload")
        upload_button = gr.Button("Describe Image")
        image_caption = gr.Textbox(label="Image Description", interactive=False)

    chat_history = gr.State([])

    submit_button.click(fn=chat, inputs=[user_input, chat_history, image_input], outputs=[chatbot_output, chat_history])
    clear_button.click(fn=lambda: ("", []), inputs=[], outputs=[chatbot_output, chat_history])
    upload_button.click(fn=describe_image, inputs=[image_input], outputs=[image_caption])

# === Launch the app ===
if __name__ == "__main__":
    demo.launch()
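
# To run locally (requires gradio, transformers, torch, pillow, and groq to be installed):
#   python <this_file>.py
# Gradio serves the UI at http://127.0.0.1:7860 by default.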