import numpy as np
import torch
from PIL import Image

from server.crud.patent_data import search_data
from server.utils.database import get_db
from server.utils.model import get_clip_model

# Load the shared CLIP model and processor once at import time.
clip_model, clip_processor = get_clip_model()


def search_data_from_database(query=None, image_path=None, top_k=5, db=None, index_name="patents"):
    """Embed a text query or an image with CLIP and return the top_k closest patents."""
    # Resolve the database handle at call time; a default of db=get_db() in the
    # signature would be evaluated only once, when the module is imported.
    if db is None:
        db = get_db()

    if query:
        # Encode the text query into a CLIP embedding.
        inputs = clip_processor(text=[query], return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            embedding = clip_model.get_text_features(**inputs)[0].cpu().numpy().astype(np.float32).tolist()
    elif image_path:
        # Encode the query image into a CLIP embedding.
        image = Image.open(image_path).convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            embedding = clip_model.get_image_features(**inputs)[0].cpu().numpy().astype(np.float32).tolist()
    else:
        # Neither a query nor an image was supplied; nothing to search with.
        return []

    # Delegate the vector similarity search to the CRUD layer.
    return search_data(
        embedding=embedding,
        top_k=top_k,
        db=db,
        index_name=index_name,
    )
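

# Minimal usage sketch (assumptions: run as a script with the `server` package on
# PYTHONPATH and an index named "patents" already populated; the query string and
# image path below are hypothetical examples, not values from the project).
if __name__ == "__main__":
    # Text search: embed the query with CLIP and print the top matches.
    for hit in search_data_from_database(query="solar panel mounting bracket", top_k=3):
        print(hit)

    # Image search: embed a local image instead of a text query.
    # for hit in search_data_from_database(image_path="example_patent_figure.png", top_k=3):
    #     print(hit)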