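# Hugging Face Space app: scores how well four candidate captions (each in its own
# language) describe an image, by projecting SigLIP2 image features into the SONAR
# sentence-embedding space and comparing them to SONAR text embeddings with cosine similarity.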
import os
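# fairseq2 (required by SONAR) is installed at startup from the wheel index matching
# this Space's PyTorch 2.6 / CUDA 12.4 build.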
os.system("pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.6.0/cu124 -q")
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
import requests
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
import torch.nn as nn
import torch.nn.functional as F
from io import BytesIO
from transformers.image_utils import load_image
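# Cosine similarity along the embedding dimension (dim=1 by default).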
cos = nn.CosineSimilarity()
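# Download the trained image-encoder checkpoint from the Hub.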
model_path = hf_hub_download(
repo_id="Sibgat-Ul/SONAR-Image_enc",
filename="best_sonar.pth",
repo_type="model"
)
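# UI language names mapped to the FLORES-200 codes expected by the SONAR text encoder.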
language_mapping = {
"English": "eng_Latn",
"Bengali": "ben_Beng",
"French": "fra_Latn"
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -------- Load Image Encoder --------
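# Frozen SigLIP2 vision backbone with a trainable projection head that maps the pooled
# image features into the 1024-dimensional SONAR embedding space. The temperature, bias,
# and logit-scale parameters are SigLIP-style training parameters; they are returned by
# forward() but not used for scoring below.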
class SonarImageEnc(nn.Module):
def __init__(self, path="google/siglip2-base-patch16-384", initial_temperature=0.07):
super().__init__()
self.model = SiglipVisionModel.from_pretrained(path, torch_dtype="auto")
for param in self.model.parameters():
param.requires_grad = False
self.projection = nn.Sequential(
nn.Linear(self.model.config.hidden_size, 2048),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(2048, 1024),
nn.LayerNorm(1024, eps=1e-5),
)
for param in self.projection.parameters():
param.requires_grad = True
self.temp_s = nn.Parameter(torch.log(torch.tensor(10.0)))
self.bias = nn.Parameter(torch.tensor(-10.0))
self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1.0) / initial_temperature))
def forward(self, pixel_values):
vision_outputs = self.model(pixel_values=pixel_values)
pooled_output = vision_outputs.pooler_output
embeddings = self.projection(pooled_output)
        # Keep the learned temperature within [0.001, 100], i.e. clamp the log-scale
        # between log(1/100) and log(1/0.001).
        self.logit_scale.data.clamp_(
            min=torch.log(torch.tensor(1.0).to(device) / torch.tensor(100.0).to(device)),
            max=torch.log(torch.tensor(1.0).to(device) / torch.tensor(0.001).to(device))
        )
return embeddings, torch.exp(self.logit_scale), torch.exp(self.temp_s), self.bias
# Load processor and models
processor = SiglipImageProcessor.from_pretrained("google/siglip2-base-patch16-384")
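# SONAR text encoder pipeline: produces 1024-dimensional multilingual sentence embeddings,
# the same space the image projection head targets.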
t2t_model_emb = TextToEmbeddingModelPipeline(
encoder="text_sonar_basic_encoder",
tokenizer="text_sonar_basic_encoder",
device=device,
dtype=torch.float16,
)
img_encoder = SonarImageEnc().to(device).eval()
img_encoder.load_state_dict(torch.load(model_path, map_location=device))
# -------- Similarity Scoring --------
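# Uses the uploaded image if present, otherwise fetches the URL; embeds the image and each
# caption in its selected language, and returns the image plus cosine scores sorted best-first.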
def compute_similarity(
image, image_url,
option_a, option_b, option_c, option_d,
lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d
):
if not image:
try:
headers = {
"User-Agent": "Mozilla/5.0"
}
response = requests.get(image_url, headers=headers)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")
except Exception as e:
return None, {"Error": f"Image could not be loaded: {str(e)}"}
# Preprocess image
inputs = processor(image, return_tensors="pt").to(device)
with torch.no_grad():
image_emb, _, _, _ = img_encoder(inputs.pixel_values)
image_emb = image_emb.to(device, torch.float16)
# Map languages
lang_codes = [
language_mapping[lang_opt_a],
language_mapping[lang_opt_b],
language_mapping[lang_opt_c],
language_mapping[lang_opt_d],
]
texts = [option_a, option_b, option_c, option_d]
# Get embeddings per option with corresponding language
text_embeddings = []
for text, lang in zip(texts, lang_codes):
emb = t2t_model_emb.predict([text], source_lang=lang)
text_embeddings.append(emb)
text_embeddings = torch.cat(text_embeddings, dim=0).to(device)
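    # image_emb has shape (1, 1024) and broadcasts against the stacked text embeddings (4, 1024).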
scores = cos(image_emb, text_embeddings)
results = {
f"Option {chr(65+i)}": round(score.item(), 3)
for i, score in enumerate(scores)
}
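    # Format the scores as percentages, sorted from best to worst match.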
results = {
k: f"{round(v * 100, 2)}%"
for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
}
return image, results
# -------- Gradio UI --------
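# Layout: left column holds the image URL and the four caption options with per-option
# language dropdowns; right column holds an optional image upload. Outputs show the
# resolved input image and the JSON similarity scores.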
with gr.Blocks(fill_height=True) as demo:
gr.Markdown("## π SONAR: Image-Text Similarity Scorer")
gr.Markdown("#### Upload an Image or provide an URL.")
with gr.Row():
with gr.Column():
image_url = gr.Textbox(label="Image URL", value="http://images.cocodataset.org/val2017/000000039769.jpg")
with gr.Row():
option_a = gr.Textbox(label="Option A", value="Two cats in a bed.")
lang_opt_a = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
with gr.Row():
option_b = gr.Textbox(label="Option B", value="Two cat with two remotes.")
lang_opt_b = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
with gr.Row():
option_c = gr.Textbox(label="Option C", value="Two remotes.")
lang_opt_c = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
with gr.Row():
option_d = gr.Textbox(label="Option D", value="Two cats.")
lang_opt_d = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
# language = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Select Language")
with gr.Column():
image_input = gr.Image(label="Upload an image", type="pil")
btn = gr.Button("Done")
with gr.Row():
img_output = gr.Image(label="Input Image", type="pil", width=300, height=300)
result_output = gr.JSON(label="Similarity Scores")
btn.click(
fn=compute_similarity,
inputs=[
image_input, image_url,
option_a, option_b, option_c, option_d,
lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d
],
outputs=[img_output, result_output]
)
demo.launch()