File size: 2,674 Bytes
231b872
51c49bc
 
 
 
44fb620
dbe0dd0
44fb620
6139662
 
 
2ee994f
 
 
 
 
 
51c49bc
26f855a
 
 
 
2795ce6
 
 
 
 
 
26f855a
44fb620
26f855a
2ee994f
44fb620
51c49bc
6139662
26f855a
6139662
 
 
3885e21
 
 
 
 
 
 
 
 
 
 
 
 
 
26f855a
44fb620
26f855a
2ee994f
3885e21
6139662
 
 
51c49bc
3885e21
8434b5d
6139662
3885e21
 
 
 
 
2795ce6
3885e21
 
 
 
 
 
 
44fb620
3885e21
26f855a
3885e21
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# import pandas as pd
import numpy as np
import torch
import os
import cv2
from transformers import AutoModelForImageClassification, AutoConfig
import logging
from pathlib import Path
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from all_models import models

# Configure the root logger at import time: INFO and above, each record
# rendered as "<timestamp> - <level> - <message>".
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers, so an application that configures logging before importing
# this module keeps its own settings.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)

def image_preprocessing(image):
    """Turn a batch of images into 224x224, 3-channel float32 tensors.

    Each input image is resized to 224x224, replicated into three identical
    channels with ``cv2.merge``, divided by 255, and wrapped in a
    ``torch.Tensor`` with channels last (H, W, C).

    Args:
        image: Iterable of images accepted by ``cv2.resize``.
            NOTE(review): presumably 2-D (single-channel) uint8 arrays, since
            each one is replicated into three channels -- confirm with callers.

    Returns:
        A list with one float32 tensor per input image, or ``None`` if any
        step raised (the error is logged rather than propagated).
    """
    try:
        images = []
        for single_image in image:
            resized = cv2.resize(single_image, (224, 224))
            # Replicate the single plane so the result has three channels.
            three_channel = cv2.merge([resized, resized, resized])
            # Cast before dividing: true division of an integer array yields
            # float64, and torch.from_numpy keeps that dtype, whereas PyTorch
            # model weights default to float32. Producing float32 here avoids
            # a dtype mismatch (and halves the memory of the batch).
            scaled = three_channel.astype(np.float32) / 255.0
            images.append(torch.from_numpy(scaled))
        return images

    except Exception as e:
        # Best-effort contract: callers treat None as "preprocessing failed".
        logger.error(f"Error in image_preprocessing: {str(e)}")
        return None

def predict_image(images):
    """Run the shared ViT classifier over a batch of images.

    The model is obtained from the ``models`` singleton and its reference is
    always released in the ``finally`` clause, whether or not inference
    succeeded.

    Args:
        images: Iterable of images, passed unchanged to
            ``image_preprocessing``.

    Returns:
        A NumPy array holding the model's ``logits`` for the batch
        (presumably one row per image -- confirm against the model config),
        or ``None`` if preprocessing or inference failed. Failures are
        logged, not raised.
    """
    try:
        # Get model instance from singleton; the processor is not needed
        # because preprocessing is done by image_preprocessing().
        model, _processor = models.get_vit_model()

        preprocessed_img = image_preprocessing(images)
        if preprocessed_img is None:
            logger.error("Image preprocessing failed")
            return None

        # Stack to (N, H, W, C), then reorder to channels-first (N, C, H, W).
        images_tensor = torch.stack(preprocessed_img).permute(0, 3, 1, 2)

        # Align the batch with the model's device and floating dtype. Without
        # this, a model living on the GPU (or one whose weights are not
        # float64) rejects the CPU batch and the error is swallowed below,
        # silently disabling the classifier.
        reference = None
        if hasattr(model, "parameters"):
            reference = next(model.parameters(), None)
        if reference is not None:
            if reference.is_floating_point():
                images_tensor = images_tensor.to(
                    device=reference.device, dtype=reference.dtype
                )
            else:
                images_tensor = images_tensor.to(device=reference.device)

        with torch.no_grad():
            logits = model(images_tensor).logits

        # .cpu() is a no-op for tensors already on the CPU, so no CUDA check
        # is needed before converting to NumPy.
        return logits.cpu().numpy()

    except Exception as e:
        logger.error(f"Error in predict_image: {str(e)}")
        return None
    finally:
        # Release model reference
        models.release_vit_model()

def struck_images(word_images):
    """Drop word images that the classifier labels as struck through.

    Each image is classified via ``predict_image``; images whose highest
    scoring class index is 0 are kept. The function never raises: on any
    failure it falls back to returning the input unchanged.

    Args:
        word_images: Indexable sequence of word images.

    Returns:
        A new list of the images predicted as class 0, or ``word_images``
        itself when the input is empty, prediction failed, no image was
        predicted as class 0, or an unexpected error occurred.
    """
    try:
        # Nothing to classify: skip acquiring the model. Without this guard
        # an empty batch fails inside torch.stack and produces a misleading
        # error/warning pair before the same value is returned.
        if len(word_images) == 0:
            return word_images

        predictions = predict_image(word_images)
        if predictions is None:
            logger.warning("Predictions failed, processing without model")
            return word_images

        not_struck = [
            word_images[index]
            for index, scores in enumerate(predictions)
            if scores.argmax() == 0  # Assuming 0 is the "not struck" class
        ]

        if not not_struck:
            # Returning nothing would discard the whole input, so keep it all.
            logger.warning("No non-struck images found, returning all images")
            return word_images

        return not_struck

    except Exception as e:
        logger.error(f"Error in struck_images: {str(e)}")
        return word_images  # Return all images on error