# fairseq2 provides the SONAR pipelines imported below; it is installed at startup
# from the wheel index matching PyTorch 2.6.0 / CUDA 12.4.
import os
os.system("pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.6.0/cu124 -q")

from huggingface_hub import hf_hub_download
import gradio as gr
import torch
import requests
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
import torch.nn as nn
import torch.nn.functional as F
from io import BytesIO

cos = nn.CosineSimilarity()
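# nn.CosineSimilarity defaults to dim=1, so the (1, 1024) image embedding is broadcast
# against the (N, 1024) text embeddings, yielding one similarity score per caption.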

model_path = hf_hub_download(
    repo_id="Sibgat-Ul/SONAR-Image_enc",
    filename="best_sonar.pth",
    repo_type="model"
)

language_mapping = {
    "English": "eng_Latn",
    "Bengali": "ben_Beng",
    "French": "fra_Latn"
}
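# The values above are FLORES-200 language codes, which the SONAR text encoder expects;
# further languages could be added the same way, e.g. "German": "deu_Latn" (illustrative).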

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
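# Note: the SONAR text pipeline below is created with dtype=torch.float16, which assumes
# a CUDA device; when running on CPU, torch.float32 is likely the safer choice.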

# -------- Load Image Encoder --------
class SonarImageEnc(nn.Module):
    """Frozen SigLIP2 vision backbone with a trainable projection head that maps pooled
    image features into SONAR's 1024-dimensional text embedding space."""

    def __init__(self, path="google/siglip2-base-patch16-384", initial_temperature=0.07):
        super().__init__()
        
        self.model = SiglipVisionModel.from_pretrained(path, torch_dtype="auto")
        
        for param in self.model.parameters():
            param.requires_grad = False
        
        self.projection = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, 2048),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 1024), 
            nn.LayerNorm(1024, eps=1e-5),
        )

        for param in self.projection.parameters():
            param.requires_grad = True

        self.temp_s = nn.Parameter(torch.log(torch.tensor(10.0)))
        self.bias = nn.Parameter(torch.tensor(-10.0))
    
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1.0) / initial_temperature))
        
    def forward(self, pixel_values):
        vision_outputs = self.model(pixel_values=pixel_values)
        pooled_output = vision_outputs.pooler_output
        embeddings = self.projection(pooled_output)
        
        # Keep the learned scale exp(logit_scale) within [1/100, 1/0.001].
        self.logit_scale.data.clamp_(
            min=torch.log(torch.tensor(1.0 / 100.0)).item(),
            max=torch.log(torch.tensor(1.0 / 0.001)).item()
        )

        return embeddings, torch.exp(self.logit_scale), torch.exp(self.temp_s), self.bias

# Load processor and models
processor = SiglipImageProcessor.from_pretrained("google/siglip2-base-patch16-384")

t2t_model_emb = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=device,
    dtype=torch.float16,
)

img_encoder = SonarImageEnc().to(device).eval()
img_encoder.load_state_dict(torch.load(model_path, map_location=device))
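
# Quick sanity check (a sketch, not executed here): embed a local image and a caption,
# then compare them with `cos`. "example.jpg" is a hypothetical file name; every other
# name refers to the objects created above.
#
#   pixels = processor(Image.open("example.jpg"), return_tensors="pt").pixel_values.to(device)
#   with torch.no_grad():
#       img_emb, _, _, _ = img_encoder(pixels)
#   txt_emb = t2t_model_emb.predict(["Two cats."], source_lang="eng_Latn").to(device)
#   print(cos(img_emb.to(txt_emb.dtype), txt_emb))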

# -------- Similarity Scoring --------
def compute_similarity(
    image, image_url, 
    option_a, option_b, option_c, option_d,
    lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d
):
    
    # Fall back to the image URL when no image was uploaded.
    if image is None:
        try:
            headers = {
                "User-Agent": "Mozilla/5.0"
            }
            response = requests.get(image_url, headers=headers, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            return None, {"Error": f"Image could not be loaded: {str(e)}"}

    # Preprocess image
    inputs = processor(image, return_tensors="pt").to(device)
    
    with torch.no_grad():
        image_emb, _, _, _ = img_encoder(inputs.pixel_values)
        image_emb = image_emb.to(device, torch.float16)

    # Map languages
    lang_codes = [
        language_mapping[lang_opt_a],
        language_mapping[lang_opt_b],
        language_mapping[lang_opt_c],
        language_mapping[lang_opt_d],
    ]
    texts = [option_a, option_b, option_c, option_d]

    # Get embeddings per option with corresponding language
    text_embeddings = []
    for text, lang in zip(texts, lang_codes):
        emb = t2t_model_emb.predict([text], source_lang=lang)
        text_embeddings.append(emb)

    text_embeddings = torch.cat(text_embeddings, dim=0).to(device)

    scores = cos(image_emb, text_embeddings)

    # Report cosine similarities (range [-1, 1]) as percentages, sorted best-first.
    results = {
        f"Option {chr(65 + i)}": round(score.item(), 3)
        for i, score in enumerate(scores)
    }

    results = {
        k: f"{round(v * 100, 2)}%"
        for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
    }

    return image, results
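
# Example (a sketch): compute_similarity can also be called directly, bypassing the UI.
# The URL is the COCO default used in the textbox below; the language arguments must be
# keys of `language_mapping`.
#
#   _, scores = compute_similarity(
#       None, "http://images.cocodataset.org/val2017/000000039769.jpg",
#       "Two cats in a bed.", "Two cats with two remotes.", "Two remotes.", "Two cats.",
#       "English", "English", "English", "English",
#   )
#   print(scores)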


# -------- Gradio UI --------
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("## πŸ” SONAR: Image-Text Similarity Scorer")
    gr.Markdown("#### Upload an Image or provide an URL.")
    
    with gr.Row():
        with gr.Column():
            image_url = gr.Textbox(label="Image URL", value="http://images.cocodataset.org/val2017/000000039769.jpg")

            with gr.Row():
                option_a = gr.Textbox(label="Option A", value="Two cats in a bed.")
                lang_opt_a = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
                    
            with gr.Row():    
                option_b = gr.Textbox(label="Option B", value="Two cats with two remotes.")
                lang_opt_b = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
                
            with gr.Row():
                option_c = gr.Textbox(label="Option C", value="Two remotes.")
                lang_opt_c = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
                
            with gr.Row():
                option_d = gr.Textbox(label="Option D", value="Two cats.")
                lang_opt_d = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
                

        with gr.Column():
            image_input = gr.Image(label="Upload an image", type="pil")
            btn = gr.Button("Compute Similarity")

            with gr.Row():
                img_output = gr.Image(label="Input Image", type="pil", width=300, height=300)
                result_output = gr.JSON(label="Similarity Scores")

    btn.click(
        fn=compute_similarity,
        inputs=[
            image_input, image_url,
            option_a, option_b, option_c, option_d,
            lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d
        ],
        outputs=[img_output, result_output]
    )


demo.launch()
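# Note: demo.launch(share=True) would additionally expose a temporary public Gradio link
# when the app is run locally.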