import xml.etree.ElementTree as ET
from elasticsearch import Elasticsearch, helpers
import torch
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from server.utils.database import get_db
from server.utils.model import get_clip_model
from server.models.database import DocumentModel
from server.crud.patent_data import search_data
# Load the CLIP model/processor once at import time and share across calls.
clip_model, clip_processor = get_clip_model()


def search_data_from_database(query=None, image_path=None, top_k=5, db=None, index_name="patents"):
    """Embed a text query or an image with CLIP and run a vector similarity search.

    Exactly one of ``query`` / ``image_path`` should be provided; if both are
    given, ``query`` takes precedence (matches the original if/elif order).

    Args:
        query: Optional text query to embed with CLIP's text encoder.
        image_path: Optional path to an image to embed with CLIP's image encoder.
        top_k: Number of nearest neighbours to return.
        db: Database handle passed through to ``search_data``. If ``None``, a
            session is acquired via ``get_db()`` per call. (Previously
            ``get_db()`` was a *default argument*, so it ran once at import and
            every call shared that single, possibly stale, session — a classic
            call-at-definition-time default bug.)
        index_name: Name of the search index to query.

    Returns:
        The result of ``search_data`` (top-k matches), or ``[]`` when neither
        ``query`` nor ``image_path`` is supplied.
    """
    if db is None:
        # Acquire the session lazily, per call, instead of once at import.
        # NOTE(review): if get_db() is a generator-style dependency (FastAPI
        # convention), next(get_db()) may be needed here — confirm with callers.
        db = get_db()

    if query:
        # Text branch: tokenize and encode the query with CLIP's text tower.
        inputs = clip_processor(text=[query], return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():  # inference only — no autograd graph needed
            embedding = clip_model.get_text_features(**inputs)[0].cpu().numpy().astype(np.float32).tolist()
    elif image_path:
        # Image branch: load, normalize to RGB, and encode with the image tower.
        from PIL import Image
        image = Image.open(image_path).convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            embedding = clip_model.get_image_features(**inputs)[0].cpu().numpy().astype(np.float32).tolist()
    else:
        # Nothing to embed — return an empty result set rather than raising.
        return []

    return search_data(
        embedding=embedding,
        top_k=top_k,
        db=db,
        index_name=index_name,
    )