File size: 3,484 Bytes
957d93a
c7b1ee6
 
 
 
957d93a
e957893
cdd119a
 
 
 
 
 
 
c7b1ee6
957d93a
c7b1ee6
cdd119a
 
 
 
 
 
 
c7b1ee6
957d93a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7b1ee6
 
957d93a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7b1ee6
957d93a
 
 
c7b1ee6
957d93a
 
c7b1ee6
957d93a
 
 
 
4dd17a5
957d93a
 
 
 
cdd119a
 
957d93a
 
c7b1ee6
957d93a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from PIL import Image
import requests
import torch
import io
import os
from huggingface_hub import login



# Authenticate with the Hugging Face Hub using the token from the environment.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    # Calling login() without a token falls back to an interactive prompt
    # (see the huggingface_hub login documentation), which cannot be answered
    # in a headless deployment. Skip it and rely on any cached credentials.
    print("HF_TOKEN is not set; skipping Hugging Face login.")


# Initialize the model and processor once at import time so every request
# reuses the same loaded weights.
model_id = "google/gemma-3n-e4b-it"
try:
    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)
except Exception as e:
    # Re-raise with context; "from e" keeps the original traceback attached as
    # the explicit cause. RuntimeError is still caught by "except Exception".
    raise RuntimeError(f"Failed to load model or processor: {e}") from e

def _load_image(image_input, image_url, timeout=30.0):
    """Return the request's image converted to RGB, or None when none was given.

    An uploaded file takes priority over a URL.

    Args:
        image_input: Uploaded image (anything ``Image.open`` accepts) or None.
        image_url: URL of an image; None, empty or whitespace-only means "no URL".
        timeout: Seconds to wait for the remote host before giving up.

    Raises:
        Whatever ``Image.open`` or ``requests`` raise for an unreadable image,
        a failed request, or a non-2xx HTTP status.
    """
    if image_input is not None:
        return Image.open(image_input).convert("RGB")

    image_url = (image_url or "").strip()
    if not image_url:
        return None

    # NOTE(review): the URL is user-supplied and fetched server-side; consider
    # restricting allowed hosts if this app is exposed publicly.
    # The timeout keeps an unresponsive host from blocking the request forever.
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content)).convert("RGB")


def process_inputs(image_input, image_url, text_prompt):
    """
    Process image (from file or URL) and text prompt to generate a response using the Gemma model.

    The prompt is validated before any image is opened or downloaded, so a
    missing prompt never triggers network or disk work.

    Args:
        image_input: Uploaded image file (optional)
        image_url: URL of an image (optional; used only when no file is uploaded)
        text_prompt: Text input from the user (required)

    Returns:
        Generated text response from the model, "Please provide a text prompt."
        when the prompt is empty or whitespace-only, or an "Error: ..." string
        if anything fails while loading the image or generating.
    """
    try:
        if not text_prompt or not text_prompt.strip():
            return "Please provide a text prompt."

        # Handle image input: prioritize uploaded image, then URL, then None
        image = _load_image(image_input, image_url)

        # The user turn carries the optional image first, then the text prompt
        user_content = []
        if image is not None:
            user_content.append({"type": "image", "image": image})
        user_content.append({"type": "text", "text": text_prompt})

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            },
            {"role": "user", "content": user_content},
        ]

        # Process inputs using the processor
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        input_len = inputs["input_ids"].shape[-1]

        # Generate response
        with torch.inference_mode():
            generation = model.generate(**inputs, max_new_tokens=500, do_sample=False)

        # Slice off the first input_len tokens (the prompt) and decode the rest
        new_tokens = generation[0][input_len:]
        return processor.decode(new_tokens, skip_special_tokens=True)

    except Exception as e:
        # UI boundary: report the failure as text instead of raising
        return f"Error: {e}"

# Build the Gradio UI: an optional image upload, an optional image URL and a
# text prompt go in; the model's text response comes out.
_image_upload = gr.Image(type="filepath", label="Upload Image (optional)")
_image_url_box = gr.Textbox(label="Image URL (optional)", placeholder="Enter image URL")
_prompt_box = gr.Textbox(label="Text Prompt", placeholder="Enter your prompt here")
_response_box = gr.Textbox(label="Model Response")

_app_description = (
    "Upload an image or provide an image URL, and enter a text prompt to interact with the Gemma-3 model. "
    "Ensure you have authenticated with a valid Hugging Face access token."
)

iface = gr.Interface(
    fn=process_inputs,
    inputs=[_image_upload, _image_url_box, _prompt_box],
    outputs=_response_box,
    title="Gemma-3 Multimodal App (Authenticated)",
    description=_app_description,
    allow_flagging="never",
)

# Start serving the app
iface.launch()