# yamanavijayavardhan's picture
# printing extracted text18
# 6139662
# import pandas as pd
import numpy as np
import torch
import os
import cv2
from transformers import AutoModelForImageClassification, AutoConfig
import logging
from pathlib import Path
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from all_models import models
# Configure logging at import time: INFO level, with each record formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
def image_preprocessing(image):
    """Convert an iterable of word images into model-ready tensors.

    Each image is resized to 224x224, replicated into three channels with
    ``cv2.merge``, divided by 255 and wrapped in a ``torch`` tensor.

    Args:
        image: Iterable of images accepted by ``cv2.resize``.
            NOTE(review): presumably single-channel arrays, since each one is
            merged with itself three times -- confirm against the caller.

    Returns:
        A list with one tensor per input image, or ``None`` if any step
        raised (the error is logged).
    """

    def _to_tensor(word_img):
        # Resize first, then replicate the plane into a 3-channel image.
        resized = cv2.resize(word_img, (224, 224))
        three_channel = cv2.merge([resized, resized, resized])
        return torch.from_numpy(three_channel / 255)

    try:
        return [_to_tensor(word_img) for word_img in image]
    except Exception as e:
        logger.error(f"Error in image_preprocessing: {str(e)}")
        return None
def predict_image(images):
    """Run the shared ViT classifier over a batch of word images.

    The images are preprocessed with ``image_preprocessing``, stacked into a
    single batch, rearranged from channels-last to channels-first, moved to
    the device/dtype of the model weights, and passed through the model
    without gradient tracking.

    Args:
        images: Iterable of images accepted by ``image_preprocessing``.

    Returns:
        A NumPy array holding the model's raw logits for the batch, or
        ``None`` when preprocessing or inference fails (the error is logged).
    """
    try:
        # Get model instance from singleton; the processor is not needed here.
        model, _processor = models.get_vit_model()

        preprocessed_img = image_preprocessing(images)
        if preprocessed_img is None:
            logger.error("Image preprocessing failed")
            return None

        # Stack to one batch and move the channel axis in front of H and W.
        images_tensor = torch.stack(preprocessed_img).permute(0, 3, 1, 2)
        # Without this, a model living on an accelerator (or using float32
        # weights against the float64 produced by preprocessing) can reject
        # the batch with a device/dtype mismatch.
        images_tensor = _match_model_input(images_tensor, model)

        with torch.no_grad():
            predictions = model(images_tensor).logits

        # .cpu() is a no-op for CPU tensors, so call it unconditionally
        # instead of guessing the tensor's location from CUDA availability.
        return predictions.cpu().numpy()
    except Exception as e:
        logger.error(f"Error in predict_image: {str(e)}")
        return None
    finally:
        # Release model reference
        models.release_vit_model()


def _match_model_input(batch, model):
    """Return *batch* on the device (and floating dtype) of the model weights."""
    parameters = getattr(model, "parameters", None)
    first = next(iter(parameters()), None) if callable(parameters) else None
    if first is None:
        # Not a parameterised module: leave the batch untouched.
        return batch
    # Only adopt the weight dtype when it is a float type, so integer or
    # quantized parameters never truncate the normalised pixel values.
    dtype = first.dtype if first.is_floating_point() else batch.dtype
    return batch.to(device=first.device, dtype=dtype)
def struck_images(word_images):
    """Keep only the word images the classifier labels as not struck through.

    Args:
        word_images: Indexable collection of word images, passed to
            ``predict_image``.

    Returns:
        The list of images whose top-scoring class is index 0. The original
        ``word_images`` is returned unchanged when prediction fails, when no
        image is classified as class 0, or when any error is raised.
    """
    try:
        logits = predict_image(word_images)
        if logits is None:
            logger.warning("Predictions failed, processing without model")
            return word_images

        # Assuming 0 is the "not struck" class
        kept = [
            word_images[idx]
            for idx, scores in enumerate(logits)
            if scores.argmax() == 0
        ]
        if kept:
            return kept

        logger.warning("No non-struck images found, returning all images")
        return word_images
    except Exception as e:
        logger.error(f"Error in struck_images: {str(e)}")
        return word_images  # Return all images on error