# Soumyajit9979's picture
# Upload 28 files
# d17ca98 verified
import xml.etree.ElementTree as ET
from elasticsearch import Elasticsearch, helpers
import torch
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from server.utils.database import get_db
from server.utils.model import get_clip_model
from server.models.database import DocumentModel
from server.crud.patent_data import search_data
# Obtain the CLIP model/processor pair once, at import time; the search
# function below reuses these module-level objects on every call.
clip_model, clip_processor = get_clip_model()
def search_data_from_database(query=None, image_path=None, top_k=5, db=None, index_name="patents"):
    """Embed a text query or an image with CLIP and search the index with it.

    Exactly one input is used: a truthy ``query`` takes precedence over
    ``image_path``. When neither is provided (both falsy) no search is run.

    Args:
        query: Text to embed with ``clip_model.get_text_features``.
        image_path: Path of an image to embed with
            ``clip_model.get_image_features``; only used when ``query`` is
            falsy. The image is converted to RGB before processing.
        top_k: Forwarded unchanged to ``search_data``.
        db: Database handle forwarded to ``search_data``. When ``None`` it is
            obtained by calling ``get_db()`` at call time.
        index_name: Forwarded unchanged to ``search_data``.

    Returns:
        ``[]`` when neither ``query`` nor ``image_path`` is given; otherwise
        whatever ``search_data`` returns for the computed embedding.
    """
    # Guard clause: nothing to embed, so skip model and database work entirely.
    if not query and not image_path:
        return []

    if query:
        inputs = clip_processor(text=[query], return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():  # inference only; no gradients needed
            features = clip_model.get_text_features(**inputs)
    else:
        from PIL import Image

        # Close the underlying file deterministically; ``convert`` returns an
        # independent image, so it remains usable after the file is closed.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():  # inference only; no gradients needed
            features = clip_model.get_image_features(**inputs)

    # Single input -> take the first row and turn it into a plain list of
    # float32-rounded Python floats.
    embedding = features[0].cpu().numpy().astype(np.float32).tolist()

    # Resolve the default lazily: a ``db=get_db()`` default would run once at
    # import time and share that single object across every call.
    # NOTE(review): get_db() is now invoked per call when db is omitted; if it
    # builds a new client each time, consider caching it — confirm with owner.
    if db is None:
        db = get_db()

    return search_data(
        embedding=embedding,
        top_k=top_k,
        db=db,
        index_name=index_name,
    )