# Gallery embedding creation and Pinecone-backed image search utilities.
import base64
import os
from io import BytesIO
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from img2art_search.models.model import ViTImageSearchModel
from img2art_search.utils import (
get_or_create_pinecone_index,
get_pinecone_client,
inverse_transform_img,
)
def extract_embedding(image_data_batch, fine_tuned_model):
    """Run a batch of preprocessed images through the model.

    Args:
        image_data_batch: Tensor batch of transformed images.
        fine_tuned_model: Model mapping the batch to embeddings.

    Returns:
        NumPy array of embeddings, one row per input image, on the CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = image_data_batch.to(device)
    with torch.no_grad():
        output = fine_tuned_model(batch)
    return output.cpu().numpy()
def load_fine_tuned_model():
    """Load the fine-tuned ViT weights from models/model.pth in eval mode.

    Returns:
        ViTImageSearchModel with the saved state dict loaded, switched to
        evaluation mode (dropout/batch-norm frozen for inference).
    """
    fine_tuned_model = ViTImageSearchModel()
    # map_location keeps loading working on CPU-only hosts even when the
    # checkpoint was saved from a CUDA device; without it torch.load raises.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    state_dict = torch.load("models/model.pth", map_location=device)
    fine_tuned_model.load_state_dict(state_dict)
    fine_tuned_model.eval()
    return fine_tuned_model
def _format_gallery_name(path: str) -> str:
    """Derive a display title from an image path.

    Strips the directory and the file extension (any case — the old
    replace-chain missed variants like ``.PNG``/``.Jpeg`` and could mangle
    extension-like substrings mid-name), then maps ``-`` to a space and
    ``_`` to ``" - "`` before title-casing.
    """
    stem = os.path.splitext(path.split("/")[-1])[0]
    return stem.replace("-", " ").replace("_", " - ").title()


def create_gallery(
    img_dataset: ImageRetrievalDataset,
    fine_tuned_model: ViTImageSearchModel,
    save: bool = True,
) -> list:
    """Embed every image in the dataset and upsert records into Pinecone.

    Each record carries the embedding vector, a base64 JPEG thumbnail, and
    a human-readable title derived from the file name.

    Args:
        img_dataset: Dataset yielding (image_tensor, _, image_path, _).
        fine_tuned_model: Model producing one embedding per image.
        save: When True, also persist the per-batch embeddings to
            models/embeddings.npy.

    Returns:
        List of per-batch embedding arrays (possibly partial if an error
        interrupted the loop).
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 4
    fine_tuned_model.to(DEVICE)
    gallery_embeddings: list = []
    gallery_dataloader = DataLoader(
        img_dataset, batch_size=batch_size, num_workers=1, shuffle=False
    )
    pc = get_pinecone_client()
    gallery_index = get_or_create_pinecone_index(pc)
    try:
        count = 0  # running id across all batches, used as the Pinecone id
        for img_data, _, img_name, _ in tqdm(gallery_dataloader):
            data_objects = []
            batch_embedding = extract_embedding(img_data, fine_tuned_model)
            gallery_embeddings.append(batch_embedding)
            for idx, embedding in enumerate(batch_embedding):
                # Undo the normalization transform to recover a viewable RGB
                # image, then encode it as a base64 JPEG for the metadata.
                image = Image.fromarray(
                    inverse_transform_img(img_data[idx]).numpy().astype("uint8"),
                    "RGB",
                )
                buffered = BytesIO()
                image.save(buffered, format="JPEG")
                img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
                data_objects.append(
                    {
                        "id": str(count),
                        "values": embedding.tolist(),
                        "metadata": {
                            "image": img_base64,
                            "name": _format_gallery_name(img_name[idx]),
                        },
                    }
                )
                count += 1
            gallery_index.upsert(vectors=data_objects)
    except Exception as e:
        # Best-effort by design: keep whatever partial gallery was built.
        print(f"Error creating gallery: {e}")
    if save:
        np.save("models/embeddings", gallery_embeddings)
    return gallery_embeddings
def search_image(query_image_path: str, k: int = 4) -> tuple:
    """Find the k nearest gallery images to a query image via Pinecone.

    Args:
        query_image_path: NOTE(review): despite the name and ``str``
            annotation, this value is passed straight to
            ``extract_embedding``, which calls ``.to(device)`` on it — so it
            must already be a preprocessed image tensor batch. Confirm with
            callers and fix the name/annotation.
        k: Number of nearest neighbours to return.

    Returns:
        Tuple ``(results, distances)``: raw JPEG bytes of each match
        (base64-decoded from metadata) and "score name" display strings.
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    fine_tuned_model = load_fine_tuned_model()
    fine_tuned_model.to(DEVICE)
    query_embedding = extract_embedding(query_image_path, fine_tuned_model)
    pc = get_pinecone_client()
    index = get_or_create_pinecone_index(pc)
    response = index.query(
        vector=[query_embedding.tolist()[0]], top_k=k, include_metadata=True
    )
    distances = []
    results = []
    for obj in response["matches"]:
        results.append(base64.b64decode(obj.metadata["image"]))
        # Scale to a percentage BEFORE rounding: the old
        # round(score, 2) * 100 produced float noise such as
        # "87.00000000000001" in the displayed string.
        score_pct = round(obj["score"] * 100, 2)
        distances.append(f"{score_pct} {obj.metadata['name']}")
    return results, distances
def create_gallery_embeddings(folder: str) -> None:
    """Embed every file in *folder* and push the records to the gallery index.

    The path array is paired with itself because the dataset expects a
    (query, target) pair layout; here each image is its own target.
    """
    paths = np.array([f"{folder}/{entry}" for entry in os.listdir(folder)])
    gallery_data = np.array([paths, paths])
    gallery_dataset = ImageRetrievalDataset(gallery_data, transform=transform)
    create_gallery(gallery_dataset, load_fine_tuned_model())