yamanavijayavardhan committed on
Commit
44fb620
·
1 Parent(s): 8dd6f8c

update_new_new_new_new_new

Browse files
Files changed (1) hide show
  1. HTR/strike.py +120 -25
HTR/strike.py CHANGED
@@ -3,8 +3,9 @@ import numpy as np
3
  import torch
4
  import os
5
  import cv2
6
- from transformers import AutoModelForImageClassification
7
  import logging
 
8
 
9
  logging.basicConfig(
10
  level=logging.INFO,
@@ -21,38 +22,93 @@ def initialize_model():
21
  if model is None:
22
  try:
23
  logger.info("Initializing model...")
24
- model = AutoModelForImageClassification.from_pretrained("models/vit-base-beans")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  if torch.cuda.is_available():
26
  model = model.to('cuda')
27
  logger.info("Model moved to CUDA")
 
 
 
 
28
  logger.info("Model initialized successfully")
 
 
29
  except Exception as e:
30
  logger.error(f"Error initializing model: {str(e)}")
31
- raise
32
 
33
  def image_preprocessing(image):
34
  try:
35
  images = []
36
  for i in image:
37
- binary_image = i
38
- binary_image = cv2.resize(binary_image, (224, 224))
39
- binary_image = cv2.merge([binary_image, binary_image, binary_image])
40
- binary_image = binary_image/255
41
- binary_image = torch.from_numpy(binary_image)
42
- images.append(binary_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  return images
 
44
  except Exception as e:
45
  logger.error(f"Error in image_preprocessing: {str(e)}")
46
- return []
47
 
48
- def predict_image(image_path, model):
49
  try:
50
- preprocessed_img = image_preprocessing(image_path)
51
- if not preprocessed_img:
52
- return None
 
 
 
 
 
53
 
54
- images = torch.stack(preprocessed_img)
55
- images = images.permute(0, 3, 1, 2)
56
 
57
  if torch.cuda.is_available():
58
  images = images.to('cuda')
@@ -60,9 +116,38 @@ def predict_image(image_path, model):
60
  with torch.no_grad():
61
  predictions = model(images).logits.detach().cpu().numpy()
62
  return predictions
 
63
  except Exception as e:
64
  logger.error(f"Error in predict_image: {str(e)}")
65
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def struck_images(image_paths):
68
  try:
@@ -73,6 +158,9 @@ def struck_images(image_paths):
73
  logger.info(f"Processing {len(image_paths)} images")
74
  processed_paths = []
75
 
 
 
 
76
  for i, img_path in enumerate(image_paths):
77
  try:
78
  # Read the image from the path
@@ -81,12 +169,6 @@ def struck_images(image_paths):
81
  logger.error(f"Failed to read image: {img_path}")
82
  continue
83
 
84
- # Resize if image is too small
85
- min_size = 800
86
- if img.shape[0] < min_size or img.shape[1] < min_size:
87
- scale = min_size / min(img.shape[0], img.shape[1])
88
- img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
89
-
90
  # Process the image
91
  processed = process_single_image(img)
92
  if processed is None:
@@ -101,8 +183,21 @@ def struck_images(image_paths):
101
  logger.error(f"Error processing image {img_path}: {str(e)}")
102
  continue
103
 
104
- logger.info(f"Successfully processed {len(processed_paths)} images")
105
- return processed_paths
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  except Exception as e:
108
  logger.error(f"Error in struck_images: {str(e)}")
 
3
  import torch
4
  import os
5
  import cv2
6
+ from transformers import AutoModelForImageClassification, AutoConfig
7
  import logging
8
+ from pathlib import Path
9
 
10
  logging.basicConfig(
11
  level=logging.INFO,
 
22
  if model is None:
23
  try:
24
  logger.info("Initializing model...")
25
+ # Use model directly from Hugging Face hub
26
+ model_name = "microsoft/resnet-50" # Using a more general model for classification
27
+
28
+ try:
29
+ # First try to load from cache
30
+ cache_dir = os.path.join(os.environ.get('TMPDIR', '/tmp'), 'model_cache')
31
+ os.makedirs(cache_dir, exist_ok=True)
32
+
33
+ config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
34
+ model = AutoModelForImageClassification.from_pretrained(
35
+ model_name,
36
+ config=config,
37
+ cache_dir=cache_dir
38
+ )
39
+ logger.info(f"Model loaded from {model_name}")
40
+ except Exception as e:
41
+ logger.error(f"Error loading model from hub: {str(e)}")
42
+ # Fallback to simpler processing if model fails to load
43
+ return None
44
+
45
  if torch.cuda.is_available():
46
  model = model.to('cuda')
47
  logger.info("Model moved to CUDA")
48
+ else:
49
+ logger.info("Running on CPU")
50
+
51
+ model.eval() # Set to evaluation mode
52
  logger.info("Model initialized successfully")
53
+ return model
54
+
55
  except Exception as e:
56
  logger.error(f"Error initializing model: {str(e)}")
57
+ return None
58
 
def image_preprocessing(image):
    """Turn a collection of images into model-ready tensors.

    Each element of ``image`` may be a file path (``str``), which is loaded
    as grayscale with ``cv2.imread``, or an already-loaded image that
    ``cv2.resize`` / ``cv2.cvtColor(..., COLOR_GRAY2RGB)`` can consume
    (presumably a single-channel array -- verify against callers).

    Every usable element is resized to 224x224, expanded to 3 channels,
    cast to float32, divided by 255 and rearranged to CxHxW.

    Returns:
        A list of ``torch.Tensor`` objects, one per element that could be
        processed. Elements that cannot be read or converted are logged and
        dropped, so the list may be shorter than the input. ``None`` is
        returned when nothing could be processed or an unexpected error
        occurs.
    """

    def _to_tensor(item):
        # Convert one element; None signals "unreadable path, already logged".
        if isinstance(item, str):
            item = cv2.imread(item, cv2.IMREAD_GRAYSCALE)
            if item is None:
                logger.error("Failed to read image from path")
                return None

        resized = cv2.resize(item, (224, 224))
        three_channel = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
        scaled = three_channel.astype(np.float32) / 255.0
        # HxWxC -> CxHxW, the layout the model consumes.
        return torch.from_numpy(scaled).permute(2, 0, 1)

    try:
        tensors = []
        for item in image:
            try:
                tensor = _to_tensor(item)
            except Exception as e:
                logger.error(f"Error preprocessing individual image: {str(e)}")
                continue
            if tensor is not None:
                tensors.append(tensor)

        if tensors:
            return tensors

        logger.error("No images were successfully preprocessed")
        return None

    except Exception as e:
        logger.error(f"Error in image_preprocessing: {str(e)}")
        return None
99
 
def predict_image(image_paths, model):
    """Run strike-through prediction on a batch of images.

    Args:
        image_paths: Iterable of image paths (or items accepted by
            ``image_preprocessing``).
        model: Classification model whose call result exposes ``.logits``,
            or ``None`` to use the OpenCV heuristic fallback.

    Returns:
        When the model path succeeds, the model's logits as a NumPy array
        (one row per successfully preprocessed image). Otherwise the result
        of ``process_without_model(image_paths)``, a 1-D array of 0/1
        labels. Callers must handle both shapes.

    NOTE(review): ``image_preprocessing`` drops images it cannot process, so
    the logits may contain fewer rows than ``image_paths`` has entries and
    positions may no longer line up with the input.
    """
    try:
        if model is None:
            logger.warning("Model not initialized, using basic processing")
            return process_without_model(image_paths)

        preprocessed_imgs = image_preprocessing(image_paths)
        if not preprocessed_imgs:
            logger.warning("No preprocessed images, using basic processing")
            return process_without_model(image_paths)

        images = torch.stack(preprocessed_imgs)

        # Send the batch to wherever the model's weights actually live.
        # Keying this only on torch.cuda.is_available() breaks when a CPU
        # model is used on a CUDA-capable host (device mismatch error).
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            # Model exposes no parameters; keep the previous heuristic.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        images = images.to(device)

        with torch.no_grad():
            predictions = model(images).logits.detach().cpu().numpy()
        return predictions

    except Exception as e:
        logger.error(f"Error in predict_image: {str(e)}")
        return process_without_model(image_paths)
123
+
def process_without_model(image_paths, *, kernel_width=20, min_line_fraction=0.1):
    """Fallback strike-through detection that needs no neural model.

    Each image is loaded as grayscale, binarised with Otsu's threshold and
    opened with a ``1 x kernel_width`` structuring element so that only long
    horizontal runs survive. The image is labelled "struck" when the
    surviving pixels cover more than ``min_line_fraction`` of its area.

    Args:
        image_paths: Iterable of image file paths.
        kernel_width: Width in pixels of the horizontal structuring element.
        min_line_fraction: Fraction of image pixels that must belong to
            horizontal runs for the image to count as struck.

    Returns:
        1-D integer ``numpy.ndarray`` with one entry per input path, in
        input order: 1 = struck, 0 = not struck. Unreadable images are
        reported as 0 so positions stay aligned with ``image_paths``. On an
        unexpected error every entry is 0.
    """
    image_paths = list(image_paths)
    try:
        # Loop-invariant structuring element, built once.
        kernel = np.ones((1, kernel_width), np.uint8)

        results = []
        for path in image_paths:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # Keep a placeholder instead of skipping; callers index the
                # result by position in image_paths.
                logger.warning(f"Could not read image, treating as not struck: {path}")
                results.append(0)
                continue

            # NOTE(review): THRESH_BINARY marks bright pixels as foreground.
            # For dark ink on light paper THRESH_BINARY_INV may be what is
            # intended -- confirm against real inputs.
            thresh = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
            horizontal_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)

            # The mask holds 0/255 values, so count pixels rather than
            # summing intensities; summing inflated the measure 255-fold
            # relative to the area-based threshold.
            line_pixels = np.count_nonzero(horizontal_lines)
            is_struck = line_pixels > img.shape[0] * img.shape[1] * min_line_fraction
            results.append(1 if is_struck else 0)

        return np.array(results, dtype=int)

    except Exception as e:
        logger.error(f"Error in process_without_model: {str(e)}")
        return np.zeros(len(image_paths), dtype=int)  # Return all as not struck
151
 
152
  def struck_images(image_paths):
153
  try:
 
158
  logger.info(f"Processing {len(image_paths)} images")
159
  processed_paths = []
160
 
161
+ # Initialize model
162
+ model = initialize_model()
163
+
164
  for i, img_path in enumerate(image_paths):
165
  try:
166
  # Read the image from the path
 
169
  logger.error(f"Failed to read image: {img_path}")
170
  continue
171
 
 
 
 
 
 
 
172
  # Process the image
173
  processed = process_single_image(img)
174
  if processed is None:
 
183
  logger.error(f"Error processing image {img_path}: {str(e)}")
184
  continue
185
 
186
+ # Get predictions
187
+ predictions = predict_image(processed_paths, model)
188
+
189
+ # Filter based on predictions
190
+ not_struck = []
191
+ for i, pred in enumerate(predictions):
192
+ if isinstance(pred, np.ndarray):
193
+ if pred.argmax() == 0: # Not struck
194
+ not_struck.append(processed_paths[i])
195
+ else:
196
+ if pred == 0: # Not struck
197
+ not_struck.append(processed_paths[i])
198
+
199
+ logger.info(f"Found {len(not_struck)} non-struck images")
200
+ return not_struck
201
 
202
  except Exception as e:
203
  logger.error(f"Error in struck_images: {str(e)}")