File size: 5,350 Bytes
908c66a
22a9ea9
d9f9a8f
 
 
908c66a
 
 
5122672
22a9ea9
d9f9a8f
1334447
fa3d596
6a94503
908c66a
1334447
 
 
d9f9a8f
1334447
33c3ee8
07e9630
908c66a
fa3d596
07e9630
6a94503
 
1a1967b
22a9ea9
5122672
b5bce8a
 
5122672
 
 
b5bce8a
 
6a94503
b5bce8a
5122672
 
 
b5bce8a
6a94503
 
908c66a
5122672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1334447
 
fa3d596
 
 
 
 
 
22a9ea9
 
 
 
 
fa3d596
0cfbef9
 
22a9ea9
a3c4da8
22a9ea9
6a94503
 
 
 
 
e3605f4
a3c4da8
22a9ea9
 
1334447
22a9ea9
 
6a94503
 
 
22a9ea9
 
5122672
 
 
 
908c66a
fa3d596
 
d9f9a8f
22a9ea9
908c66a
fa3d596
 
b5bce8a
908c66a
b5bce8a
 
fa3d596
07e9630
1334447
253c5c2
 
 
 
6a94503
253c5c2
b5bce8a
0cfbef9
b5bce8a
253c5c2
 
 
 
 
0cfbef9
908c66a
07e9630
2b1972b
0cfbef9
d9f9a8f
22a9ea9
908c66a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
from PIL import Image
import requests
import torch
import io
import os
from huggingface_hub import login
import tempfile
from threading import Thread

# Log in to Hugging Face for gated model access.
# The token is taken from the HF_TOKEN environment variable only. When it is
# not set, the explicit login is skipped instead of submitting a placeholder
# string, so credentials already stored on the machine (if any) stay in effect
# and no secret ever needs to be written into this file.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Load model and processor once at import time; both are shared by every request.
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor = AutoProcessor.from_pretrained(model_id)

# Upper bound (seconds) for downloading a user-supplied image URL, so a slow or
# unresponsive host cannot block a request indefinitely.
_IMAGE_DOWNLOAD_TIMEOUT_S = 30


def _load_image(image_file, image_url):
    """Return an RGB PIL image from an uploaded file path or a URL, or None if neither is given."""
    if image_file is not None:
        # Close the underlying file handle as soon as the pixels are converted.
        with Image.open(image_file) as uploaded:
            return uploaded.convert("RGB")
    if image_url and image_url.strip():
        response = requests.get(image_url.strip(), timeout=_IMAGE_DOWNLOAD_TIMEOUT_S)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    return None


def _save_temp_image(image):
    """Write a PIL image to a temporary JPEG file and return its path (caller deletes it)."""
    tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    try:
        with tmp:
            image.save(tmp, format="JPEG")
    except Exception:
        # Do not leave a half-written file behind if encoding fails.
        if os.path.exists(tmp.name):
            os.remove(tmp.name)
        raise
    return tmp.name


def process_input(chat_history, image_file, image_url, text_input):
    """Stream the model's answer about an image into the chat history.

    This is a generator: each yielded value is the full chat history (a list of
    ``(user_message, bot_message)`` tuples) to display, with the last entry
    growing as new text arrives from the model.

    Args:
        chat_history: Existing list of ``(user, bot)`` tuples; ``None`` is
            treated as an empty history.
        image_file: Path of an uploaded image, or ``None``.
        image_url: URL of an image to download; used only when ``image_file``
            is ``None``.
        text_input: The user's prompt.

    Yields:
        The updated chat history. Validation problems and unexpected errors are
        reported as a bot message appended to the history rather than raised.
    """
    history = list(chat_history) if chat_history else []
    image_path = None
    try:
        # Validate text input (whitespace-only prompts count as empty)
        if not text_input or not text_input.strip():
            yield history + [(None, "Error: Please enter a text prompt.")]
            return

        # Handle image input
        image = _load_image(image_file, image_url)
        if image is None:
            yield history + [(text_input, "Error: Please provide an image file or a valid image URL.")]
            return
        image_path = _save_temp_image(image)

        # Prepare chat template
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant capable of analyzing images."}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": text_input}
                ]
            }
        ]

        # Process inputs with chat template
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

        # Guard against the processor silently dropping the image
        if "pixel_values" not in inputs:
            yield history + [(text_input, "Error: Image data not processed correctly (missing pixel_values).")]
            return

        # Initialize streamer
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "pixel_values": inputs["pixel_values"],
            "max_new_tokens": 500,
            "do_sample": True,
            "streamer": streamer,
        }

        worker_errors = []

        def _generate():
            # inference_mode is thread-local, so it has to be entered in the
            # thread that actually runs generate(), not in the consumer.
            try:
                with torch.inference_mode():
                    model.generate(**generation_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                # Signal end-of-stream so the consumer loop below does not
                # block forever waiting for text that will never arrive.
                streamer.end()

        # Update chat history with user message
        user_message = text_input
        bot_message = ""
        new_history = history + [(user_message, bot_message)]
        yield new_history

        # Run generation in a separate thread for streaming
        thread = Thread(target=_generate)
        thread.start()

        for new_text in streamer:
            bot_message += new_text
            new_history[-1] = (user_message, bot_message)
            yield new_history

        thread.join()
        if worker_errors:
            # Surface generation failures through the common error handler.
            raise worker_errors[0]

    except Exception as e:
        import traceback
        yield history + [(text_input, f"Error: {str(e)}\n{traceback.format_exc()}")]
    finally:
        # Clean up the temporary file on every exit path (success, error,
        # early return, or the generator being closed by the caller).
        if image_path and os.path.exists(image_path):
            os.remove(image_path)

# Create Gradio interface: image inputs on the left, chat on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Multimodal Streaming Chat with Gemma-3-4b-it")
    gr.Markdown("Upload an image or provide an image URL, then enter a text prompt (e.g., 'Describe this image in detail').")

    with gr.Row():
        with gr.Column():
            # type="filepath" hands process_input a path on disk rather than pixel data.
            image_file_input = gr.Image(label="Upload Image (Drag-and-Drop)", type="filepath")
            image_url_input = gr.Textbox(label="Or Enter Image URL", placeholder="e.g., https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/bee.jpg")
        with gr.Column():
            chatbot = gr.Chatbot(label="Chat with Gemma", height=400)
            text_input = gr.Textbox(
                label="Enter your prompt here",
                placeholder="e.g., Describe this image in detail",
                lines=2,
                submit_btn=True
            )

    # Submitting the prompt box streams process_input's yielded histories
    # into the chatbot component.
    text_input.submit(
        fn=process_input,
        inputs=[chatbot, image_file_input, image_url_input, text_input],
        outputs=chatbot
    )

# Launch the app only when run as a script, so importing this module does not
# start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()