import gradio as gr
import torch
import torchvision.transforms.v2 as T
from PIL import Image
from transformers import TimmWrapperModel

MODEL_MAP = {
    "hf_hub:p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": {
        "mean": [0.0, 0.0, 0.0],
        "std": [1.0, 1.0, 1.0],
        "image_size": 384,
        "background": 0,
    }
}


def config_to_processor(config: dict) -> T.Compose:
    return T.Compose(
        [
            # Resize so the longest side equals image_size, preserving aspect ratio.
            T.Resize(
                size=None,
                max_size=config["image_size"],
                interpolation=T.InterpolationMode.NEAREST,
            ),
            # Pad, then center-crop back to a square image_size x image_size canvas.
            T.Pad(
                padding=config["image_size"] // 2,
                fill=config["background"],  # black
            ),
            T.CenterCrop(size=(config["image_size"], config["image_size"])),
            T.PILToTensor(),
            T.ToDtype(dtype=torch.float32, scale=True),  # 0~255 -> 0~1
            T.Normalize(mean=config["mean"], std=config["std"]),
        ]
    )


def load_model(name: str) -> TimmWrapperModel:
    # transformers expects a plain Hub repo id, so drop timm's "hf_hub:" prefix.
    return (
        TimmWrapperModel.from_pretrained(name.removeprefix("hf_hub:"))
        .eval()
        .requires_grad_(False)
    )


MODELS = {
    name: {
        "model": load_model(name),
        "processor": config_to_processor(config),
    }
    for name, config in MODEL_MAP.items()
}


@torch.inference_mode()
def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image):
    model = MODELS[model_name]["model"]
    processor = MODELS[model_name]["processor"]
    # Stack the two preprocessed images into one batch of shape (2, C, H, W).
    pixel_values = torch.stack([processor(image) for image in [image_1, image_2]])
    # Pooled image embeddings from the wrapped timm model.
    embeddings = model(pixel_values).pooler_output
    # L2-normalize so the dot product below is the cosine similarity.
    embeddings /= embeddings.norm(p=2, dim=-1, keepdim=True)
    similarity = (embeddings[0] * embeddings[1]).sum().item()
    return similarity


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_1 = gr.Image(label="Image 1", type="pil")
            image_2 = gr.Image(label="Image 2", type="pil")
            model_name = gr.Dropdown(
                label="Model",
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],  # preselect the first (only) model
            )
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            similarity = gr.Text(label="Similarity")

    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[model_name, image_1, image_2],
        outputs=[similarity],
    )

if __name__ == "__main__":
    demo.launch()