|
|
|
import numpy as np |
|
import torch |
|
import os |
|
import cv2 |
|
from transformers import AutoModelForImageClassification |
|
|
|
def image_preprocessing(image):
    """Convert a batch of single-channel images into model-ready tensors.

    Each image is resized to 224x224 and replicated across three channels,
    so that it has the 3-channel shape the classifier consumes. It is then
    scaled by 1/255 and wrapped in a ``torch.Tensor``.

    Args:
        image: An iterable of images as NumPy arrays. Presumably 2-D
            grayscale/binary images, because three copies are merged into
            one 3-channel image. NOTE(review): the 1/255 scaling assumes
            0-255 pixel values; confirm against the caller.

    Returns:
        A list of float32 tensors, one per input image, in channels-last
        (224, 224, C) layout.
    """
    images = []
    for binary_image in image:
        resized = cv2.resize(binary_image, (224, 224))
        stacked = cv2.merge([resized, resized, resized])
        # Scale in float32: a plain ``/ 255`` on a uint8 array yields float64.
        # torch.from_numpy keeps that dtype, and a float64 batch fails against
        # the model's float32 weights.
        scaled = stacked.astype(np.float32) / np.float32(255)
        images.append(torch.from_numpy(scaled))
    return images
|
|
|
def predict_image(image_path, model):
    """Run the classifier on a batch of images and return the raw logits.

    Args:
        image_path: Despite its name, a sequence of images (not a file path).
            It is passed straight to ``image_preprocessing``. Must be non-empty,
            because ``torch.stack`` rejects an empty list.
        model: A classifier called as ``model(batch)`` whose output exposes a
            ``.logits`` tensor (for example, a Hugging Face image-classification
            model).

    Returns:
        A NumPy array holding the model's logits for the batch, one entry per
        input image along the first axis.
    """
    preprocessed_img = image_preprocessing(image_path)
    # (N, H, W, C) -> (N, C, H, W): the model takes channels-first input.
    # ``.float()`` guards against float64 tensors, which do not match the
    # float32 weights a Hugging Face model is loaded with by default.
    images = torch.stack(preprocessed_img).permute(0, 3, 1, 2).float()
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        logits = model(images).logits
    # .cpu() keeps .numpy() working even if the model is later moved to a GPU.
    return logits.cpu().numpy()
|
|
|
|
|
# Checkpoint directory of the image classifier. It is a relative path, so it is
# resolved against the current working directory, not this file's location.
_MODEL_DIR = "models/vit-base-beans"

# Loaded once, at import time, and used as the global model by struck_images().
model = AutoModelForImageClassification.from_pretrained(_MODEL_DIR)
|
|
|
def struck_images(word__image):
    """Return the word images that the classifier does not flag as struck out.

    Runs the module-level ``model`` over all images in one batch and keeps
    those whose predicted class index is 0.

    Args:
        word__image: A sequence of word images (a list of arrays, or an array
            with one image per leading index), as accepted by
            ``predict_image``.

    Returns:
        A list of the input images predicted as class 0, in their original
        order. An empty input yields an empty list without running the model.
    """
    # predict_image() uses torch.stack(), which fails on an empty batch.
    # Check len() rather than truthiness so NumPy arrays of images also work.
    if len(word__image) == 0:
        return []

    predictions = predict_image(word__image, model)
    # Class index 0 is treated as "not struck", matching the original intent.
    # NOTE(review): confirm against the checkpoint's id2label mapping.
    predicted_labels = predictions.argmax(axis=1)
    return [
        img for img, label in zip(word__image, predicted_labels) if label == 0
    ]
|
|
|
|
|
|