|
|
|
import numpy as np |
|
import torch |
|
import os |
|
import cv2 |
|
from transformers import AutoModelForImageClassification, AutoConfig |
|
import logging |
|
from pathlib import Path |
|
import sys |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
from all_models import models |
|
|
|
# Root logging setup: INFO level, each record prefixed with its timestamp and
# severity. The module-level logger below is used by every function here.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
def image_preprocessing(image):
    """Turn a collection of images into 224x224 three-channel tensors.

    Every element of ``image`` is resized to 224x224 with ``cv2.resize``,
    replicated into three identical channels with ``cv2.merge``, divided by
    255 and wrapped with ``torch.from_numpy``.

    Args:
        image: Iterable of image arrays accepted by ``cv2.resize``.
            NOTE(review): presumably single-channel 8-bit word crops, so the
            division maps values into [0, 1] -- confirm against callers.

    Returns:
        A list holding one tensor per input element, in input order, or
        ``None`` when any step raised (the error is logged, not re-raised).
    """

    def _to_tensor(raw):
        # Resize first, then copy the same plane into all three channels.
        resized = cv2.resize(raw, (224, 224))
        three_channel = cv2.merge([resized, resized, resized])
        return torch.from_numpy(three_channel / 255)

    try:
        return [_to_tensor(raw) for raw in image]
    except Exception as e:
        logger.error(f"Error in image_preprocessing: {str(e)}")
        return None
|
|
|
def predict_image(images):
    """Run the ViT classifier on a batch of images and return its logits.

    The model is obtained from ``models.get_vit_model()`` and is always handed
    back through ``models.release_vit_model()`` when this function exits,
    whether it succeeded or not.

    Args:
        images: Iterable of image arrays; passed unchanged to
            ``image_preprocessing``.

    Returns:
        A NumPy array with the model's logits (one row per input image), or
        ``None`` if preprocessing or inference failed. Failures are logged
        and never raised from the ``try`` body.
    """
    try:
        # Only the model is needed here; the second element of the pair
        # returned by get_vit_model() is not used for inference.
        model, _ = models.get_vit_model()

        preprocessed_img = image_preprocessing(images)
        if preprocessed_img is None:
            logger.error("Image preprocessing failed")
            return None

        # Stack into one batch and move the trailing (channel) axis to
        # position 1, i.e. channel-last -> channel-first.
        images_tensor = torch.stack(preprocessed_img).permute(0, 3, 1, 2)

        # Put the batch on the same device as the model weights; a CPU batch
        # fed to a CUDA model raises. Also match the weights' floating dtype
        # (the preprocessed tensors come from a NumPy true division and are
        # float64). Non-floating parameter dtypes are never used as a cast
        # target, so the pixel values are not truncated.
        reference = next(model.parameters(), None)
        if reference is not None:
            target_dtype = (
                reference.dtype
                if reference.dtype.is_floating_point
                else images_tensor.dtype
            )
            images_tensor = images_tensor.to(
                device=reference.device, dtype=target_dtype
            )

        with torch.no_grad():
            logits = model(images_tensor).logits

        # .cpu() is a no-op for tensors already on the CPU, so no CUDA
        # availability check is required before converting to NumPy.
        return logits.detach().cpu().numpy()

    except Exception as e:
        logger.error(f"Error in predict_image: {str(e)}")
        return None
    finally:
        models.release_vit_model()
|
|
|
def struck_images(word_images):
    """Filter out the word images that the classifier marks as struck.

    An image is kept when the arg-max of its prediction row is class 0. The
    function is best-effort: whenever predictions are unavailable, nothing is
    classified as class 0, or an unexpected error occurs, the original
    ``word_images`` is returned unchanged.

    Args:
        word_images: Indexable, sized collection of word images; forwarded to
            ``predict_image``.

    Returns:
        A list with the images predicted as class 0, or ``word_images``
        itself in any of the fallback cases above (including empty input).
    """
    try:
        # Nothing to classify: skip acquiring the model altogether. Without
        # this guard the empty batch fails inside predict_image and the same
        # value is returned only after spurious error/warning logs.
        # len() is used rather than truthiness so NumPy arrays work as well.
        if len(word_images) == 0:
            return word_images

        predictions = predict_image(word_images)
        if predictions is None:
            logger.warning("Predictions failed, processing without model")
            return word_images

        not_struck = [
            word_images[index]
            for index, scores in enumerate(predictions)
            if scores.argmax() == 0
        ]

        if not not_struck:
            logger.warning("No non-struck images found, returning all images")
            return word_images

        return not_struck

    except Exception as e:
        logger.error(f"Error in struck_images: {str(e)}")
        return word_images
|
|