# Uploaded by yamanavijayavardhan: "Initial upload of answer grading application"
# (commit 51c49bc, 1.19 kB).
import pandas as pd
import numpy as np
import tensorflow as tf
import torch
import os
import cv2
from transformers import AutoModelForImageClassification
def image_preprocessing(image):
    """Convert word-image crops into tensors ready for the classifier.

    Each crop is resized to 224x224 and, when it is single-channel, replicated
    into three identical channels. It is then scaled by 1/255 and wrapped in a
    ``torch`` tensor in (height, width, channel) order.

    Args:
        image: An iterable of images given as NumPy arrays accepted by
            ``cv2.resize``. Despite the singular name, this is a collection.
            Pixel values are assumed to lie in 0-255 (e.g. uint8); TODO
            confirm against the caller.

    Returns:
        A list of ``float32`` tensors, one per input image, in input order.
        Each has shape (224, 224, 3) for single-channel or 3-channel input.
    """
    return [_preprocess_one(img) for img in image]


def _preprocess_one(img):
    """Resize, expand to three channels and scale a single image."""
    resized = cv2.resize(img, (224, 224))
    # Single-channel crops are copied into 3 identical channels so every image
    # has the three channels the classifier is fed. Crops that already carry
    # channels are left as-is; merging them would produce 9 channels.
    if resized.ndim == 2:
        resized = cv2.merge([resized, resized, resized])
    # float32 matches typical model weights and halves memory compared with
    # the float64 result that a plain `/ 255` on an integer image produces.
    scaled = resized.astype(np.float32) / 255
    return torch.from_numpy(scaled)
def predict_image(image_path, model):
    """Run the image classifier on a batch of images and return raw logits.

    Args:
        image_path: Despite the name, a sequence of images (not file paths).
            It is forwarded to ``image_preprocessing``. It must be non-empty,
            because ``torch.stack`` rejects an empty list.
        model: A classifier whose output exposes ``.logits``, e.g. the
            ``AutoModelForImageClassification`` loaded in this module.

    Returns:
        A NumPy array of logits with one row per input image.
    """
    tensors = image_preprocessing(image_path)
    # (N, H, W, C) -> (N, C, H, W): the model is called with channels-first input.
    batch = torch.stack(tensors).permute(0, 3, 1, 2)
    # Inference only: no_grad stops autograd from recording the forward pass,
    # so intermediate activations are not retained for a backward pass that
    # never happens.
    with torch.no_grad():
        logits = model(batch).logits
    return logits.detach().numpy()
# Location of the fine-tuned classifier weights. This is a relative path, so it
# is resolved against the process's current working directory, not against
# this file. NOTE(review): if the directory is missing, transformers may
# interpret the string as a Hub repo id; confirm the deployment layout.
_STRUCK_MODEL_DIR = "models/vit-base-beans"

# Loaded once at import time; struck_images() uses this shared instance.
model = AutoModelForImageClassification.from_pretrained(_STRUCK_MODEL_DIR)
def struck_images(word__image):
    """Keep only the word images the classifier does not flag as struck out.

    Args:
        word__image: A sequence (list or array) of word-image crops. It is
            classified with ``predict_image`` using the module-level ``model``.

    Returns:
        A list of the inputs, in input order, whose highest-scoring class is
        index 0 (the class this module treats as "not struck"). An empty
        input yields an empty list without invoking the model.
    """
    # predict_image -> torch.stack raises on an empty batch, so short-circuit.
    # len() rather than truthiness so NumPy-array inputs are also accepted.
    if len(word__image) == 0:
        return []
    predictions = predict_image(word__image, model)
    # Predicted class index per image (one row of logits per image).
    labels = predictions.argmax(axis=-1)
    return [img for img, label in zip(word__image, labels) if label == 0]
# struck_images()