# Captured from a Hugging Face Spaces page (Space status at capture time: "Sleeping").
from typing import Optional

import gradio as gr
import torch
from datasets import Dataset
from PIL import Image, ImageOps, UnidentifiedImageError
from torch.nn import CosineSimilarity
from transformers import ViTImageProcessor, ViTModel
# Preprocessor for the google/vit-base-patch16-224 checkpoint; used for both the
# candidate images (load_candidates) and the drawn scribble (scribble_matching).
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
# Separate encoders for candidate images and for scribbles, loaded from local
# epoch_29 checkpoints and switched to eval mode for inference.
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval()
# NOTE(review): "scibble" (sic) presumably matches the checkpoint directory on
# disk -- confirm before correcting the spelling.
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval()
# Candidate images plus their embeddings; stays None until the "Load" button
# handler (load_candidates_in_cache) populates it.
candidates: Optional[Dataset] = None
# Default dim=1: one similarity score per embedding row.
cosinesimilarity = CosineSimilarity()
def load_candidates(candidate_dir, progress=gr.Progress()):
    """Decode the uploaded candidate images and embed them with ``image_encoder``.

    Args:
        candidate_dir: Uploaded files from the ``gr.File`` component; each item
            must expose a ``.name`` path. ``None`` (nothing uploaded) is treated
            as an empty upload.
        progress: Gradio progress tracker. It only reports to the UI when the
            calling event handler received a live tracker from gradio and
            passed it in.

    Returns:
        A ``datasets.Dataset`` with an ``image`` column (RGB, resized to
        224x224) and an ``image_embedding`` column holding each image's
        ``pooler_output`` from ``image_encoder``.

    Raises:
        gr.Error: If none of the uploaded files can be opened as an image.
    """
    def preprocess(examples):
        # Batched map: ``examples["image"]`` is a list of images.
        images = examples["image"]
        pixel_values = image_processor(images, return_tensors="pt")["pixel_values"]
        examples["image_embedding"] = image_encoder(pixel_values)["pooler_output"]
        # Advance the tracker registered by ``progress.tqdm(dataset)`` below.
        progress.update(len(images))
        return examples

    records = []
    for uploaded in progress.tqdm(candidate_dir or []):
        try:
            # ``with`` releases the file handle; convert() loads the pixels first.
            with Image.open(uploaded.name) as img:
                records.append(dict(image=img.convert("RGB").resize((224, 224))))
        except UnidentifiedImageError:
            # Directory uploads often include non-image files (e.g. .DS_Store);
            # skip them instead of aborting the whole load.
            continue
    if not records:
        raise gr.Error("No readable images were found in the uploaded files.")

    dataset = Dataset.from_list(records)
    # datasets.map() iterates internally, so register the dataset as the tracked
    # iterable here and let preprocess() advance it with progress.update().
    progress.tqdm(dataset)
    with torch.no_grad():  # inference only; no autograd graph needed
        dataset = dataset.map(preprocess, batched=True, batch_size=1)
    return dataset
def load_candidates_in_cache(candidate_files, progress=gr.Progress()):
    """Gradio "Load" handler: embed the uploaded candidates and cache them globally.

    Args:
        candidate_files: Files from the ``gr.File`` component (each has ``.name``).
        progress: Filled in by gradio because of the ``gr.Progress()`` default;
            forwarded so ``load_candidates`` can report progress to the UI.

    Returns:
        The uploaded file paths, written back to the same ``gr.File`` component.

    Raises:
        gr.Error: If no files were uploaded (errors from ``load_candidates``
            propagate unchanged).
    """
    global candidates
    if not candidate_files:
        raise gr.Error("Upload a directory of candidate images first.")
    # Gradio only connects a live tracker to the *event handler's* gr.Progress()
    # parameter; load_candidates' own default tracker is never connected.
    candidates = load_candidates(candidate_files, progress=progress)
    return [f.name for f in candidate_files]
def scribble_matching(input_img: Image.Image, top_k: int = 15):
    """Rank the cached candidate images by cosine similarity to a drawn scribble.

    The scribble is colour-inverted, embedded with ``scribble_encoder`` and
    compared against the ``image_embedding`` column of the global
    ``candidates`` dataset built by ``load_candidates_in_cache``.

    Args:
        input_img: The sketch from the canvas component.
        top_k: Maximum number of matches to return; capped at the number of
            loaded candidates.

    Returns:
        A list of ``(image, caption)`` pairs for the gallery. The first pair is
        the inverted scribble captioned ``"preview"``. It is followed by the best
        matches in descending similarity, each captioned with its score to
        3 decimals.

    Raises:
        gr.Error: If nothing was drawn or no candidates have been loaded.
    """
    if input_img is None:
        raise gr.Error("Draw a scribble before matching.")
    if candidates is None or len(candidates) == 0:
        raise gr.Error("Load candidate images before matching.")

    # ImageOps.invert only supports L/RGB images, so normalise the mode first.
    input_img = ImageOps.invert(input_img.convert("RGB"))
    with torch.no_grad():  # inference only; avoid building an autograd graph
        pixel_values = image_processor(input_img, return_tensors="pt")["pixel_values"]
        scribble_embedding = scribble_encoder(pixel_values)["pooler_output"].to("cpu")
    image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
    # Similarity along dim 1: the single scribble row is compared with every candidate row.
    sim = cosinesimilarity(scribble_embedding, image_embeddings)
    # torch.topk raises when k exceeds the number of elements.
    predicts = torch.topk(sim, k=min(top_k, sim.numel()))
    output_imgs = candidates[predicts.indices.tolist()]["image"]
    labels = [f"{score:.3f}" for score in predicts.values.tolist()]
    return list(zip([input_img] + output_imgs, ["preview"] + labels))
def main():
    """Build the scribble-matching UI and serve it behind a request queue."""
    demo = gr.Blocks()
    with demo:
        # Top row: drawing canvas next to the results gallery.
        with gr.Row():
            sketch = gr.Image(
                type="pil",
                label="scribble",
                height=512,
                width=512,
                source="canvas",
                tool="color-sketch",
                brush_radius=10,
            )
            gallery = gr.Gallery(min_width=512, columns=4, show_label=True)
        # Second row: candidate directory upload plus its "Load" trigger.
        with gr.Row():
            upload = gr.File(file_count="directory", min_width=300, height=300)
            load_button = gr.Button("Load", variant="secondary", size="sm")
        # NOTE(review): the original's indentation was lost; confirm whether this
        # button sat inside the row above or below it (as here).
        match_button = gr.Button("Scribble Matching", variant="primary")

        # The Load handler echoes the uploaded paths back into the same File component.
        load_button.click(fn=load_candidates_in_cache, inputs=[upload], outputs=upload)
        match_button.click(fn=scribble_matching, inputs=[sketch], outputs=[gallery])

    # Queueing is required for gradio progress tracking.
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()