import ast
import io
import logging

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
)

logger = logging.getLogger(__name__)

# Model and preprocessing components are loaded once at import time and shared
# by every request the Gradio app below handles.
model_root = "qihoo360/fg-clip2-base"
model = AutoModelForCausalLM.from_pretrained(model_root, trust_remote_code=True)
device = model.device
tokenizer = AutoTokenizer.from_pretrained(model_root)
image_processor = AutoImageProcessor.from_pretrained(model_root)

# Text-encoder modes accepted by this app (the set named in the UI hint
# "[short, long, box]"); blank/None input falls back to DEFAULT_TEXT_TYPE.
TEXT_TYPES = ("short", "long", "box")
DEFAULT_TEXT_TYPE = "long"
# Tokenizer max_length: "long" captions get 196 tokens, every other mode 64.
LONG_TEXT_MAX_LENGTH = 196
SHORT_TEXT_MAX_LENGTH = 64

# NOTE(review): 16 is assumed to be the vision encoder's patch size and the
# buckets the max_num_patches values the processor supports -- confirm
# against the model's processor config.
_PATCH_SIZE = 16
_PATCH_BUCKETS = (128, 256, 576, 784)
_MAX_PATCH_BUCKET = 1024


def determine_max_value(image):
    """Choose the ``max_num_patches`` budget passed to the image processor.

    The image is measured in 16x16-pixel cells (``(w // 16) * (h // 16)``) and
    the smallest bucket in 128 / 256 / 576 / 784 that can hold that many cells
    is returned; anything larger than 784 cells gets the 1024 cap.

    Args:
        image: a PIL image (only ``image.size`` is read).

    Returns:
        int: one of 128, 256, 576, 784 or 1024.
    """
    width, height = image.size
    num_cells = (width // _PATCH_SIZE) * (height // _PATCH_SIZE)
    for bucket in _PATCH_BUCKETS:
        if num_cells <= bucket:
            return bucket
    return _MAX_PATCH_BUCKET


def postprocess_result(probs, labels):
    """Pair each label with its score as the ``{label: score}`` dict shown by ``gr.Label``.

    ``probs`` and ``labels`` are aligned by position. Duplicate labels collapse
    to a single key (the last score wins), as with any dict.
    """
    return dict(zip(labels, probs))


def _normalize_text_type(text_type):
    """Return a validated text-encoder mode; ``None``/blank means "long".

    Raises:
        gr.Error: if the value is not one of ``TEXT_TYPES`` (case-insensitive).
    """
    if text_type is None or not str(text_type).strip():
        return DEFAULT_TEXT_TYPE
    mode = str(text_type).strip().lower()
    if mode not in TEXT_TYPES:
        raise gr.Error(
            f"Unknown text type {text_type!r}; expected one of: {', '.join(TEXT_TYPES)}."
        )
    return mode


def Retrieval(image, candidate_labels, text_type):
    """Score ``image`` against every caption in ``candidate_labels``.

    Args:
        image: PIL image to score; converted to RGB before preprocessing.
        candidate_labels: non-empty sequence of caption strings (already
            parsed -- see ``infer``). Captions are lowercased before
            tokenization.
        text_type: text-encoder mode ("short", "long" or "box"); ``None`` or
            blank falls back to "long". "long" uses a 196-token budget, the
            other modes 64 tokens.

    Returns:
        list[float]: one score per caption, in input order, forming a softmax
        distribution over the candidates (they sum to 1).

    Raises:
        gr.Error: if no image is given, no labels are given, or ``text_type``
            is not recognised.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not candidate_labels:
        raise gr.Error("Please provide at least one candidate label.")
    text_type = _normalize_text_type(text_type)
    logger.debug("text_type=%s", text_type)

    image = image.convert("RGB")
    image_input = image_processor(
        images=image,
        max_num_patches=determine_max_value(image),
        return_tensors="pt",
    ).to(device)

    captions = [label.lower() for label in candidate_labels]
    max_length = LONG_TEXT_MAX_LENGTH if text_type == "long" else SHORT_TEXT_MAX_LENGTH
    caption_input = tokenizer(
        captions,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # The whole scoring path stays under no_grad: logit_scale / logit_bias are
    # presumably trainable parameters, in which case the final logits would
    # otherwise carry an autograd graph that is never used.
    with torch.no_grad():
        image_feature = model.get_image_features(**image_input)
        text_feature = model.get_text_features(**caption_input, walk_type=text_type)

        # L2-normalise so the matrix product below is a cosine similarity.
        image_feature = image_feature / image_feature.norm(p=2, dim=-1, keepdim=True)
        text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)

        logits_per_image = image_feature @ text_feature.T
        logit_scale = model.logit_scale.to(text_feature.device)
        logit_bias = model.logit_bias.to(text_feature.device)
        logits_per_image = logits_per_image * logit_scale.exp() + logit_bias

        # Softmax over dim=1 normalises the scores across the candidate
        # captions (a ranking). An earlier variant used
        # torch.sigmoid(logits_per_image) for independent per-pair scores.
        probs = logits_per_image.softmax(dim=1)

    logger.debug("logits_per_image=%s", logits_per_image)
    logger.debug("probs=%s", probs)
    return probs[0].tolist()


def _parse_candidate_labels(raw):
    """Parse the textbox value (a Python list literal) into a list of strings.

    A single quoted string is accepted and treated as a one-element list.

    Raises:
        gr.Error: on empty, malformed, or non-string-list input.
    """
    example_hint = "e.g. ['a', 'b', 'c']"
    if not isinstance(raw, str) or not raw.strip():
        raise gr.Error(f"Please enter a list of candidate labels, {example_hint}.")
    try:
        # literal_eval accepts only Python literals (no arbitrary code), but it
        # can still raise on malformed or pathologically nested input.
        parsed = ast.literal_eval(raw)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise gr.Error(
            f"Candidate labels must be a Python list literal, {example_hint}."
        ) from err
    if isinstance(parsed, str):
        parsed = [parsed]
    if not isinstance(parsed, (list, tuple)) or not all(
        isinstance(label, str) for label in parsed
    ):
        raise gr.Error(f"Candidate labels must be a list of strings, {example_hint}.")
    if not parsed:
        raise gr.Error("Please provide at least one candidate label.")
    return list(parsed)


def infer(image, candidate_labels, text_type):
    """Gradio click handler: parse the labels, score the image, return ``{label: score}``.

    Args:
        image: PIL image from the ``gr.Image`` component (``None`` if empty).
        candidate_labels: textbox string holding a Python list literal.
        text_type: textbox string naming the text-encoder mode.
    """
    labels = _parse_candidate_labels(candidate_labels)
    fg_probs = Retrieval(image, labels, text_type)
    return postprocess_result(fg_probs, labels)


with gr.Blocks() as demo:
    gr.Markdown("# FG-CLIP 2 Retrieval")
    gr.Markdown(
        "This app uses the FG-CLIP 2 model (qihoo360/fg-clip2-base) for retrieval on CPU :"
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a list of labels, example:['a','b','c']")
            text_type = gr.Textbox(label="form [short, long, box] select", value="long")
            run_button = gr.Button("Run Retrieval", visible=True)
        with gr.Column():
            fg_output = gr.Label(label="FG-CLIP 2 Output", num_top_classes=11)

    # Each example supplies all three inputs; both caption sets are long-form
    # descriptions, so the text type is pinned to "long" explicitly.
    examples = [
        [
            "./000093.jpg",
            str([
                "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双浅色鞋子,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
                "一个简约风格的卧室角落,黑色金属衣架上挂着多件红色和蓝色的衣物,下方架子放着两双黑色高跟鞋,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
                "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双运动鞋,旁边是一盆仙人掌,左侧可见一张铺有白色床单和灰色枕头的床。",
                "一个繁忙的街头市场,摊位上摆满水果,背景是高楼大厦,人们在喧闹中购物。",
            ]),
            "long",
        ],
        [
            "./000093.jpg",
            str([
                "A minimalist-style bedroom corner with a black metal clothing rack holding several beige and white garments, two pairs of light-colored shoes on the shelf below, a potted green plant nearby, and to the left, a bed made with white sheets and gray pillows.",
                "A minimalist-style bedroom corner with a black metal clothing rack holding several red and blue garments, two pairs of black high heels on the shelf below, a potted green plant nearby, and to the left, a bed made with white sheets and gray pillows.",
                "A minimalist-style bedroom corner with a black metal clothing rack holding several beige and white garments, two pairs of sneakers on the shelf below, a potted cactus nearby, and to the left, a bed made with white sheets and gray pillows.",
                "A bustling street market with fruit-filled stalls, skyscrapers in the background, and people shopping amid the noise and activity.",
            ]),
            "long",
        ],
    ]

    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input, text_type],
    )

    run_button.click(
        fn=infer,
        inputs=[image_input, text_input, text_type],
        outputs=fg_output,
    )

# demo.launch(server_name="0.0.0.0", server_port=7861, share=True)
demo.launch()