import os

import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
from transformers import AutoModelForImageClassification


def image_preprocessing(image, size=(224, 224)):
    """Convert single-channel images into float32 tensors for the classifier.

    Each image is resized to ``size`` with ``cv2.resize``, replicated into
    three identical channels with ``cv2.merge``, scaled by 1/255 and wrapped
    as a ``torch`` tensor.

    Args:
        image: Iterable of 2-D NumPy image arrays. Assumed to be
            single-channel (grayscale/binary) images -- TODO confirm with
            callers; merging three multi-channel arrays would not give a
            3-channel image.
        size: Target ``(width, height)`` passed to ``cv2.resize``.
            Defaults to ``(224, 224)``.

    Returns:
        List of float32 tensors, one per input image, each laid out as
        ``(height, width, 3)``.
    """
    images = []
    for img in image:
        resized = cv2.resize(img, size)
        rgb = cv2.merge([resized, resized, resized])
        # Cast to float32 before scaling: dividing a uint8 array by 255
        # produces float64, which does not match the model's default
        # float32 weights and makes the forward pass fail.
        # NOTE(review): no mean/std normalization is applied here; check
        # whether the model's image processor expects one.
        scaled = rgb.astype(np.float32) / 255.0
        images.append(torch.from_numpy(scaled))
    return images


def predict_image(image_path, model):
    """Run ``model`` on a batch of images and return the raw logits.

    Args:
        image_path: Despite the name, a sequence of image arrays (not a
            path); it is passed straight to ``image_preprocessing``.
            Must be non-empty, because ``torch.stack`` rejects an empty list.
        model: A Hugging Face image-classification model whose output
            exposes a ``logits`` attribute.

    Returns:
        NumPy array of logits, one row per input image.
    """
    preprocessed_img = image_preprocessing(image_path)
    # (N, H, W, C) -> (N, C, H, W): channels-first layout for the model.
    images = torch.stack(preprocessed_img).permute(0, 3, 1, 2)
    # Inference only: avoid building the autograd graph.
    with torch.no_grad():
        logits = model(images).logits
    return logits.cpu().numpy()


# Loaded once at import time so every call reuses the same weights.
model = AutoModelForImageClassification.from_pretrained("models/vit-base-beans")


def struck_images(word__image):
    """Return the word images that the classifier assigns to class index 0.

    An image is kept when the argmax of its logits is 0.
    NOTE(review): the caller-side meaning appears to be "not struck-through";
    confirm that label 0 means that in the model's ``config.id2label``.

    Args:
        word__image: Sequence of word image arrays, in the format expected
            by ``image_preprocessing``.

    Returns:
        List of the input images predicted as class 0, in input order.
        An empty input yields an empty list.
    """
    if len(word__image) == 0:
        # torch.stack raises on an empty list, so short-circuit here.
        return []
    predictions = predict_image(word__image, model)
    return [
        img
        for img, logits in zip(word__image, predictions)
        if logits.argmax() == 0
    ]