File size: 2,674 Bytes
231b872 51c49bc 44fb620 dbe0dd0 44fb620 6139662 2ee994f 51c49bc 26f855a 2795ce6 26f855a 44fb620 26f855a 2ee994f 44fb620 51c49bc 6139662 26f855a 6139662 3885e21 26f855a 44fb620 26f855a 2ee994f 3885e21 6139662 51c49bc 3885e21 8434b5d 6139662 3885e21 2795ce6 3885e21 44fb620 3885e21 26f855a 3885e21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
# import pandas as pd
import numpy as np
import torch
import os
import cv2
from transformers import AutoModelForImageClassification, AutoConfig
import logging
from pathlib import Path
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from all_models import models
# Configure the root logger at import time: INFO and above, each record
# prefixed with a timestamp and level name. basicConfig does nothing if the
# root logger already has handlers.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
def image_preprocessing(image, size=(224, 224)):
    """Convert a batch of images into 3-channel, channels-last tensors.

    Each image is resized to ``size`` (skipped when it already has that size,
    where ``cv2.resize`` would only copy it), grayscale images are replicated
    into three identical channels, and pixel values are divided by 255.

    Args:
        image: Iterable of images as array-likes of shape (H, W), (H, W, 1)
            or (H, W, 3). Grayscale inputs are replicated to three channels;
            3-channel inputs are used unchanged.
        size: Target ``(width, height)``, in ``cv2.resize`` argument order.

    Returns:
        A list of ``torch.Tensor`` of shape (height, width, 3). The dtype
        follows NumPy's ``/ 255`` promotion (float64 for uint8 input). Returns
        ``None`` if any image could not be processed; the error is logged.
    """
    try:
        width, height = size
        images = []
        for img in image:
            arr = np.asarray(img)
            # Drop a trailing singleton channel so every grayscale image is 2-D
            # (cv2.resize would drop it as well).
            if arr.ndim == 3 and arr.shape[2] == 1:
                arr = arr[:, :, 0]
            # cv2.resize takes (width, height) and returns (height, width) rows.
            if arr.shape[:2] != (height, width):
                arr = cv2.resize(arr, (width, height))
            if arr.ndim == 2:
                # Grayscale -> three identical channels (same as cv2.merge).
                arr = np.stack([arr, arr, arr], axis=-1)
            elif arr.ndim != 3 or arr.shape[2] != 3:
                raise ValueError(f"unsupported image shape {arr.shape}")
            images.append(torch.from_numpy(arr / 255))
        return images
    except Exception as e:
        logger.exception("Error in image_preprocessing: %s", e)
        return None
def _match_model_input(tensor, model):
    """Return *tensor* moved to the device of *model*'s first parameter.

    When that parameter is floating-point the tensor is also cast to its dtype
    (e.g. float64 pixel data -> float32/float16 weights). Objects without
    parameters are passed through unchanged.
    """
    parameters = getattr(model, "parameters", None)
    if parameters is None:
        return tensor
    first = next(iter(parameters()), None)
    if first is None:
        return tensor
    if first.is_floating_point():
        return tensor.to(device=first.device, dtype=first.dtype)
    return tensor.to(device=first.device)


def predict_image(images):
    """Run the ViT classifier on a batch of images and return its logits.

    Args:
        images: Iterable of images accepted by ``image_preprocessing``.

    Returns:
        A NumPy array built from the model output's ``.logits`` (presumably
        shaped (num_images, num_labels) — confirm against the model), or
        ``None`` if preprocessing, model acquisition or inference failed, or
        if there were no images. Errors are logged, not raised.
    """
    try:
        # Preprocess before acquiring the model so bad or empty input never
        # loads it.
        preprocessed_img = image_preprocessing(images)
        if preprocessed_img is None:
            logger.error("Image preprocessing failed")
            return None
        if not preprocessed_img:
            logger.warning("No images to predict")
            return None
        # Stack to (N, H, W, C), then permute to channels-first (N, C, H, W).
        images_tensor = torch.stack(preprocessed_img).permute(0, 3, 1, 2)
        # Get model instance from singleton
        model, _processor = models.get_vit_model()
    except Exception as e:
        logger.exception("Error in predict_image: %s", e)
        return None

    # A model reference is held from here on. Release it exactly once, and
    # only because acquisition succeeded.
    try:
        # Inputs must live on the model's device and match its weight dtype,
        # otherwise the forward pass raises.
        images_tensor = _match_model_input(images_tensor, model)
        with torch.no_grad():
            predictions = model(images_tensor).logits
        # .cpu() is a no-op for tensors already on the CPU.
        return predictions.cpu().numpy()
    except Exception as e:
        logger.exception("Error in predict_image: %s", e)
        return None
    finally:
        models.release_vit_model()
def struck_images(word_images, *, not_struck_class=0):
    """Drop word images that the classifier predicts as struck through.

    Best-effort: whenever classification cannot be used, the input is returned
    unchanged so downstream processing can continue.

    Args:
        word_images: Sized, ordered collection of word images (e.g. a list of
            arrays) accepted by ``predict_image``.
        not_struck_class: Logit index treated as "not struck". Defaults to 0,
            an assumption carried over from the original code — confirm
            against the model's label mapping.

    Returns:
        A list of the images whose argmax class equals ``not_struck_class``,
        in input order. ``word_images`` itself is returned when it is empty,
        when prediction fails, when no image is predicted as not struck, or on
        any unexpected error.
    """
    try:
        # Nothing to classify; returning early also avoids loading the model.
        # len() instead of truthiness so NumPy arrays are accepted too.
        if len(word_images) == 0:
            return word_images
        predictions = predict_image(word_images)
        if predictions is None:
            logger.warning("Predictions failed, processing without model")
            return word_images
        # One predicted class index per image: argmax over the last axis.
        labels = np.asarray(predictions).argmax(axis=-1)
        not_struck = [
            img for img, label in zip(word_images, labels)
            if label == not_struck_class
        ]
        if not not_struck:
            # Every image was classified as struck: fall back to all of them.
            logger.warning("No non-struck images found, returning all images")
            return word_images
        return not_struck
    except Exception as e:
        logger.exception("Error in struck_images: %s", e)
        return word_images  # Return all images on error
|