import base64
import os
from io import BytesIO

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from img2art_search.models.model import ViTImageSearchModel
from img2art_search.utils import (
    get_or_create_pinecone_index,
    get_pinecone_client,
    inverse_transform_img,
)


def extract_embedding(image_data_batch, fine_tuned_model):
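    """Run a batch of preprocessed image tensors through the model and return
    the embeddings as a NumPy array."""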
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    image_data_batch = image_data_batch.to(DEVICE)
    with torch.no_grad():
        embeddings = fine_tuned_model(image_data_batch).cpu().numpy()
    return embeddings


def load_fine_tuned_model():
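    """Load the fine-tuned ViT checkpoint from models/model.pth in eval mode."""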
    fine_tuned_model = ViTImageSearchModel()
    # map_location keeps loading working on CPU-only machines; callers move the
    # model to the active device afterwards.
    fine_tuned_model.load_state_dict(
        torch.load("models/model.pth", map_location="cpu")
    )
    fine_tuned_model.eval()
    return fine_tuned_model


def create_gallery(
    img_dataset: ImageRetrievalDataset,
    fine_tuned_model: ViTImageSearchModel,
    save: bool = True,
) -> list:
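    """Embed every gallery image, upsert the vectors into the Pinecone index
    (with a base64 thumbnail and display name as metadata), and optionally
    save the embeddings to disk."""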
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 4
    fine_tuned_model.to(DEVICE)
    gallery_embeddings = []
    gallery_dataloader = DataLoader(
        img_dataset, batch_size=batch_size, num_workers=1, shuffle=False
    )

    pc = get_pinecone_client()
    gallery_index = get_or_create_pinecone_index(pc)
    try:
        count = 0
        for img_data, _, img_name, _ in tqdm(gallery_dataloader):
            data_objects = []
            batch_embedding = extract_embedding(img_data, fine_tuned_model)
            gallery_embeddings.append(batch_embedding)
            for idx, embedding in enumerate(batch_embedding):
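                # Store a base64-encoded JPEG of the de-normalized image so search
                # results can be rendered without access to the original files.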
                image = Image.fromarray(
                    inverse_transform_img(img_data[idx]).numpy().astype("uint8"), "RGB"
                )
                buffered = BytesIO()
                image.save(buffered, format="JPEG")
                img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
                data_objects.append(
                    {
                        "id": str(count),
                        "values": embedding.tolist(),
                        "metadata": {
                            "image": img_base64,
                            "name": img_name[idx]
                            .split("/")[-1]
                            .replace(".jpg", "")
                            .replace(".jpeg", "")
                            .replace(".png", "")
                            .replace(".JPG", "")
                            .replace(".JPEG", "")
                            .replace("-", " ")
                            .replace("_", " - ")
                            .title(),
                        },
                    }
                )
                count += 1
            gallery_index.upsert(vectors=data_objects)
    except Exception as e:
        print(f"Error creating gallery: {e}")

    if save and gallery_embeddings:
        # Stack the per-batch embeddings into a single (N, dim) array; a ragged
        # list of batch arrays cannot be saved directly with np.save.
        np.save("models/embeddings", np.concatenate(gallery_embeddings))
    return gallery_embeddings


def search_image(query_image_path: str, k: int = 4) -> tuple:
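    """Embed the query image and return the top-k gallery matches as decoded
    image bytes plus strings combining the match score and the artwork name."""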
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    fine_tuned_model = load_fine_tuned_model()
    fine_tuned_model.to(DEVICE)
    # Load and preprocess the query image with the same transform used for the
    # gallery images (assumes `transform` maps a PIL image to a tensor).
    query_image = Image.open(query_image_path).convert("RGB")
    query_batch = transform(query_image).unsqueeze(0)
    query_embedding = extract_embedding(query_batch, fine_tuned_model)
    pc = get_pinecone_client()
    index = get_or_create_pinecone_index(pc)
    response = index.query(
        vector=query_embedding[0].tolist(), top_k=k, include_metadata=True
    )
    distances = []
    results = []
    for obj in response["matches"]:
        result = base64.b64decode(obj.metadata["image"])
        results.append(result)
        # Report the match score as a percentage alongside the artwork name.
        distances.append(f"{obj['score'] * 100:.2f} {obj.metadata['name']}")

    return results, distances


def create_gallery_embeddings(folder: str) -> None:
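    """Build a gallery dataset from every file in `folder` and index it."""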
    x = np.array([f"{folder}/{file}" for file in os.listdir(folder)])
    gallery_data = np.array([x, x])
    gallery_dataset = ImageRetrievalDataset(gallery_data, transform=transform)
    fine_tuned_model = load_fine_tuned_model()
    create_gallery(gallery_dataset, fine_tuned_model)
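

if __name__ == "__main__":
    # Minimal usage sketch; the folder and query paths below are illustrative
    # placeholders, not files shipped with the project.
    create_gallery_embeddings("data/gallery")
    matches, labels = search_image("data/queries/example.jpg", k=4)
    for label in labels:
        print(label)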