# Author: yamanavijayavardhan
# Commit cf3f8ad: "Removed tensorflow import"
# File size: 1.17 kB
# import pandas as pd
import numpy as np
import torch
import os
import cv2
from transformers import AutoModelForImageClassification
def image_preprocessing(image):
    """Turn a batch of single-channel images into 3-channel torch tensors.

    Each element is resized to 224x224, its single channel is replicated
    three times with cv2.merge, and the pixel values are divided by 255
    (true division, so integer input comes out as floating point).

    Args:
        image: iterable of images; despite the singular name this is a
            batch. Presumably 2-D grayscale/binary word crops - TODO confirm
            with callers (a 3-channel input would become 9 channels here).

    Returns:
        list with one torch tensor per input image, laid out
        (height, width, channel) as produced by cv2.merge.
    """

    def _to_tensor(single):
        # Resize first, then stack the same plane three times to fake RGB.
        resized = cv2.resize(single, (224, 224))
        stacked = cv2.merge([resized] * 3)
        return torch.from_numpy(stacked / 255)

    return [_to_tensor(single) for single in image]
def predict_image(image_path, model):
    """Run the classifier on a batch of images and return its raw logits.

    Args:
        image_path: despite the name, a sequence of images (not file paths).
            It is passed straight to image_preprocessing, which iterates it
            and resizes each element with cv2.
        model: an image-classification model whose forward output exposes a
            ``.logits`` tensor (e.g. a transformers AutoModelForImageClassification).

    Returns:
        numpy array of logits with one row per input image.

    Raises:
        RuntimeError: from torch.stack when image_path is empty.
    """
    preprocessed_img = image_preprocessing(image_path)
    # (N, H, W, C) -> (N, C, H, W): transformers image models take
    # channels-first pixel_values.
    images = torch.stack(preprocessed_img).permute(0, 3, 1, 2)
    # Inference only: without no_grad every call would record an autograd
    # graph that is immediately thrown away, wasting memory and time.
    with torch.no_grad():
        logits = model(images).logits
    return logits.numpy()
# The classifier is loaded once, at import time. The path is relative, so it
# is resolved against the current working directory when such a directory
# exists (otherwise transformers presumably treats it as a Hub repo id -
# confirm the intended deployment layout).
_MODEL_DIR = "models/vit-base-beans"
model = AutoModelForImageClassification.from_pretrained(_MODEL_DIR)
def struck_images(word__image):
    """Return the word images the classifier does not label as struck-through.

    Args:
        word__image: sequence (list or array) of word images accepted by
            predict_image.

    Returns:
        list of the input images whose predicted class index is 0, in their
        original order. An empty input yields an empty list.
    """
    # predict_image uses torch.stack, which rejects an empty batch, so
    # short-circuit here instead of crashing on "no words found".
    if len(word__image) == 0:
        return []
    predictions = predict_image(word__image, model)
    # One predicted label per image: argmax over each row of logits.
    labels = predictions.argmax(axis=1)
    # Class index 0 is the "not struck" label; keep only those images.
    return [img for img, label in zip(word__image, labels) if label == 0]
# struck_images()