HaohuaLv's picture
Update app.py
b63fee5
from transformers import ViTModel, ViTImageProcessor
from PIL import Image, ImageOps
import gradio as gr
import torch
from datasets import Dataset
from torch.nn import CosineSimilarity
# Preprocessor shared by both encoders (used for scribbles and candidate images alike).
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Two ViT encoders loaded from local checkpoints and put in eval mode for inference:
# one embeds candidate images, the other embeds user scribbles.
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29")
image_encoder.eval()
# NOTE(review): directory is spelled "scibble_encoder" — verify it matches the checkpoint on disk.
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29")
scribble_encoder.eval()

# Search pool set by load_candidates_in_cache(); stays None until candidates are loaded.
candidates: "Dataset | None" = None

# Compares one scribble embedding against every candidate embedding along the feature axis.
cosinesimilarity = CosineSimilarity(dim=1)
def load_candidates(candidate_dir, progress=gr.Progress()):
    """Embed uploaded candidate images with the image encoder.

    Args:
        candidate_dir: iterable of uploaded file objects (from the
            ``gr.File(file_count="directory")`` component); each exposes ``.name``,
            a path that PIL can open.
        progress: Gradio progress tracker. It is declared as a default so that Gradio
            can supply it (presumably injected by Gradio — confirm).

    Returns:
        A ``datasets.Dataset`` with an ``image`` column (RGB images resized to
        224x224) and an ``image_embedding`` column holding the image encoder's
        ``pooler_output`` for each image.

    Raises:
        gr.Error: if no files were uploaded.
    """
    if not candidate_dir:
        raise gr.Error("Please upload a directory of candidate images first.")

    def preprocess(examples):
        images = list(examples["image"])
        # Inference only: disable autograd where the model actually runs, so the
        # result does not depend on the grad mode of whoever calls Dataset.map.
        with torch.no_grad():
            pixel_values = image_processor(images, return_tensors="pt")["pixel_values"]
            examples["image_embedding"] = image_encoder(pixel_values)["pooler_output"]
        progress.update(len(images))
        return examples

    rows = []
    for uploaded in progress.tqdm(candidate_dir):
        # Use a context manager so the source file handle is released right away;
        # convert()/resize() return new in-memory images that do not need it.
        with Image.open(uploaded.name) as img:
            rows.append({"image": img.convert("RGB").resize((224, 224))})
    dataset = Dataset.from_list(rows)
    # NOTE(review): this presumably registers a progress bar over the dataset rows,
    # which progress.update() in preprocess advances. Confirm against Gradio's Progress API.
    progress.tqdm(dataset)
    return dataset.map(preprocess, batched=True, batch_size=1)
def load_candidates_in_cache(candidate_files):
    """Embed the uploaded files and install them as the module-level search pool.

    Args:
        candidate_files: uploaded file objects from the ``gr.File`` component.

    Returns:
        The uploaded files' paths, echoed back to the same ``gr.File`` component.
    """
    global candidates
    embedded_pool = load_candidates(candidate_files)
    candidates = embedded_pool
    uploaded_paths = [uploaded.name for uploaded in candidate_files]
    return uploaded_paths
def rank_by_similarity(sim, k=15):
    """Return the indices and formatted scores of the ``k`` highest entries of ``sim``.

    ``k`` is clamped to the number of entries, so a candidate pool smaller than
    ``k`` does not make ``torch.topk`` raise.

    Args:
        sim: 1-D tensor of similarity scores, one per candidate.
        k: maximum number of matches to return.

    Returns:
        ``(indices, labels)``: int indices into ``sim`` in descending score order,
        and the corresponding scores formatted to three decimals.
    """
    k = min(k, sim.shape[0])
    top = torch.topk(sim, k=k)
    indices = top.indices.tolist()
    labels = [f"{score:.3f}" for score in top.values.tolist()]
    return indices, labels


def scribble_matching(input_img: "Image.Image"):
    """Retrieve the loaded candidate images most similar to a drawn scribble.

    The scribble is colour-inverted before encoding, and the inverted image is
    what appears as the "preview" entry.

    Args:
        input_img: the PIL image from the sketch canvas.

    Returns:
        A list of ``(image, caption)`` pairs for the gallery. The first pair is the
        inverted scribble captioned "preview". It is followed by up to 15
        candidates in descending cosine similarity, each captioned with its score.

    Raises:
        gr.Error: if no scribble was given or no candidates have been loaded.
    """
    if input_img is None:
        raise gr.Error("Please draw a scribble first.")
    if candidates is None or len(candidates) == 0:
        raise gr.Error("Please load candidate images first.")

    input_img = ImageOps.invert(input_img)
    # Inference only: avoid building an autograd graph for the scribble encoder.
    with torch.no_grad():
        pixel_values = image_processor(input_img, return_tensors="pt")["pixel_values"]
        scribble_embedding = scribble_encoder(pixel_values)["pooler_output"].to("cpu")
    image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
    # Assumes scribble_embedding is (1, D) and image_embeddings is (N, D), so the
    # similarity broadcasts to one score per candidate. TODO: confirm.
    sim = cosinesimilarity(scribble_embedding, image_embeddings)
    indices, labels = rank_by_similarity(sim, k=15)
    output_imgs = candidates[indices]["image"]
    return list(zip([input_img] + output_imgs, ["preview"] + labels))
def main():
    """Build the Gradio interface and serve it.

    Layout:
    - a sketch canvas next to a results gallery;
    - a candidate-upload row with a "Load" button;
    - a "Scribble Matching" button that fills the gallery.
    """
    # NOTE(review): indentation was lost in the original paste. The placement of the
    # "Scribble Matching" button and of launch() is reconstructed, so verify it.
    with gr.Blocks() as demo:
        with gr.Row():
            sketch_canvas = gr.Image(
                type="pil",
                label="scribble",
                height=512,
                width=512,
                source="canvas",
                tool="color-sketch",
                brush_radius=10,
            )
            results_gallery = gr.Gallery(min_width=512, columns=4, show_label=True)
        with gr.Row():
            upload_box = gr.File(file_count="directory", min_width=300, height=300)
            load_button = gr.Button("Load", variant="secondary", size="sm")
        match_button = gr.Button("Scribble Matching", variant="primary")

        # "Load" embeds the uploads and writes their paths back to the upload box.
        load_button.click(fn=load_candidates_in_cache, inputs=[upload_box], outputs=upload_box)
        match_button.click(fn=scribble_matching, inputs=[sketch_canvas], outputs=[results_gallery])

    demo.queue().launch()


if __name__ == "__main__":
    main()