|
import cv2 |
|
import os |
|
import subprocess |
|
from PIL import Image |
|
import easyocr |
|
from spellchecker import SpellChecker |
|
import numpy as np |
|
import webcolors |
|
from collections import Counter |
|
import torch |
|
from transformers import AutoProcessor, Blip2ForConditionalGeneration |
|
import tensorflow as tf |
|
import argparse |
|
import json |
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
import time |
|
from utils.json_helpers import NoIndent, CustomEncoder |
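
# ---------------------------------------------------------------------------
# Region-level description of a UI screenshot.
#
# Given an input image and its YOLO-format labels (View / ImageView / Text /
# Line), each region is cropped, optionally upscaled (EDSR super-resolution
# when available), and then described: OCR + spell-check for Text, icon
# detection + captioning (BLIP-2 or llama3.2-vision via ollama) for
# ImageView, and dominant colour / transparency for View and Line. Results
# are written as a captions text file, full JSON, or condensed JSON.
#
# Example invocation (paths are illustrative):
#   python script.py screenshot.png screenshot_labels.txt --json
# ---------------------------------------------------------------------------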
|
|
|
|
|
|
|
BARRIER = "********\n" |
|
|
|
|
|
def is_model_downloaded(model_name, cache_directory):
    """Best-effort check for a locally cached Hugging Face model.

    The hub cache stores snapshots under 'models--<org>--<name>' inside
    `cache_directory`, so the check uses that layout.
    """
    model_path = os.path.join(cache_directory, "models--" + model_name.replace('/', '--'))
    return os.path.isdir(model_path)
|
|
|
|
|
def closest_colour(requested_colour): |
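    """Return the CSS3 colour name with the smallest squared RGB distance
    to the requested (R, G, B) tuple."""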
|
min_colours = {} |
|
css3_names = webcolors.names("css3") |
|
for name in css3_names: |
|
hex_value = webcolors.name_to_hex(name, spec='css3') |
|
r_c, g_c, b_c = webcolors.hex_to_rgb(hex_value) |
|
rd = (r_c - requested_colour[0]) ** 2 |
|
gd = (g_c - requested_colour[1]) ** 2 |
|
bd = (b_c - requested_colour[2]) ** 2 |
|
distance = rd + gd + bd |
|
min_colours[distance] = name |
|
return min_colours[min(min_colours.keys())] |
|
|
|
def get_colour_name(requested_colour): |
|
""" |
|
Returns a tuple: (exact_name, closest_name). |
|
If an exact match fails, 'exact_name' is None, use the 'closest_name' fallback. |
|
""" |
|
try: |
|
actual_name = webcolors.rgb_to_name(requested_colour, spec='css3') |
|
closest_name = actual_name |
|
except ValueError: |
|
closest_name = closest_colour(requested_colour) |
|
actual_name = None |
|
return actual_name, closest_name |
|
|
|
def get_most_frequent_color(pixels, bin_size=10): |
|
""" |
|
Returns the most frequent color among the given pixels, |
|
using a binning approach (default bin size=10). |
|
""" |
|
bins = np.arange(0, 257, bin_size) |
|
r_bins = np.digitize(pixels[:, 0], bins) - 1 |
|
g_bins = np.digitize(pixels[:, 1], bins) - 1 |
|
b_bins = np.digitize(pixels[:, 2], bins) - 1 |
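    # Pack the three bin indices into a single integer (base-100 digits are
    # enough, since indices stay below 26 with the default bin_size=10) so
    # Counter can tally joint (R, G, B) bins.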
|
combined_bins = r_bins * 10000 + g_bins * 100 + b_bins |
|
bin_counts = Counter(combined_bins) |
|
most_common_bin = bin_counts.most_common(1)[0][0] |
|
|
|
r_bin = most_common_bin // 10000 |
|
g_bin = (most_common_bin % 10000) // 100 |
|
b_bin = most_common_bin % 100 |
|
r_value = bins[r_bin] + bin_size // 2 |
|
g_value = bins[g_bin] + bin_size // 2 |
|
b_value = bins[b_bin] + bin_size // 2 |
|
|
|
return (r_value, g_value, b_value) |
|
|
|
def get_most_frequent_alpha(alphas, bin_size=10): |
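    """Return the most frequent alpha value among `alphas`, using the same
    binning approach as get_most_frequent_color."""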
|
bins = np.arange(0, 257, bin_size) |
|
alpha_bins = np.digitize(alphas, bins) - 1 |
|
bin_counts = Counter(alpha_bins) |
|
most_common_bin = bin_counts.most_common(1)[0][0] |
|
alpha_value = bins[most_common_bin] + bin_size // 2 |
|
return alpha_value |
|
|
|
|
|
def downscale_for_ocr(image_cv, max_dim=600): |
|
""" |
|
If either dimension of `image_cv` is bigger than `max_dim`, |
|
scale it down proportionally. This speeds up EasyOCR on large images. |
|
""" |
|
h, w = image_cv.shape[:2] |
|
if w <= max_dim and h <= max_dim: |
|
return image_cv |
|
|
|
scale = min(max_dim / float(w), max_dim / float(h)) |
|
new_w = int(w * scale) |
|
new_h = int(h * scale) |
|
image_resized = cv2.resize(image_cv, (new_w, new_h), interpolation=cv2.INTER_AREA) |
|
return image_resized |
|
|
|
|
|
def process_single_region( |
|
idx, bounding_box, image, sr, reader, spell, icon_model, |
|
processor, model, device, no_captioning, output_json, json_mini, |
|
cropped_imageview_images_dir, base_name, save_images, |
|
model_to_use, log_prefix="", |
|
skip_ocr=False, |
|
skip_spell=False |
|
): |
|
""" |
|
Processes one bounding box (region) |
|
Returns a dict with: |
|
* "region_dict" (for JSON) |
|
* "text_log" (file/captions output) |
|
""" |
|
(x_min, y_min, x_max, y_max, class_id) = bounding_box |
|
class_names = {0: 'View', 1: 'ImageView', 2: 'Text', 3: 'Line'} |
|
class_name = class_names.get(class_id, f'Unknown Class {class_id}') |
|
region_idx = idx + 1 |
|
logs = [] |
|
|
|
x_center = (x_min + x_max) // 2 |
|
y_center = (y_min + y_max) // 2 |
|
width = x_max - x_min |
|
height = y_max - y_min |
|
|
|
def open_and_upscale_image(img_path, cid): |
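        """Load the region crop at `img_path` and upscale it for OCR /
        captioning: EDSR super-resolution when `sr` is available, otherwise
        a bicubic resize. Only small crops (<=30x30 for Text, <=10x10 for
        other classes) are upscaled; larger ones are returned unchanged.
        Class 0 (View) is returned as an RGBA PIL image, other classes as a
        BGR NumPy array (or None if the crop cannot be read)."""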
|
|
|
if cid == 2: |
|
MAX_WIDTH, MAX_HEIGHT = 30, 30 |
|
else: |
|
MAX_WIDTH, MAX_HEIGHT = 10, 10 |
|
|
|
def is_small(w, h): |
|
return w <= MAX_WIDTH and h <= MAX_HEIGHT |
|
|
|
if cid == 0: |
|
pil_image = Image.open(img_path).convert("RGBA") |
|
w, h = pil_image.size |
|
if not is_small(w, h): |
|
logs.append(f"{log_prefix}Skipping upscale for large View (size={w}×{h}).") |
|
return pil_image |
|
|
|
|
|
if sr: |
|
image_cv = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGBA2BGR) |
|
upscaled = sr.upsample(image_cv) |
|
return Image.fromarray(cv2.cvtColor(upscaled, cv2.COLOR_BGR2RGBA)) |
|
else: |
|
return pil_image.resize((w * 4, h * 4), resample=Image.BICUBIC) |
|
else: |
|
|
|
cv_image = cv2.imread(img_path) |
|
if cv_image is None or cv_image.size == 0: |
|
logs.append(f"{log_prefix}Empty image at {img_path}, skipping.") |
|
return None |
|
|
|
h, w = cv_image.shape[:2] |
|
if not is_small(w, h): |
|
logs.append(f"{log_prefix}Skipping upscale for large region (size={w}×{h}).") |
|
return cv_image |
|
|
|
if sr: |
|
return sr.upsample(cv_image) |
|
else: |
|
return cv2.resize(cv_image, (w * 2, h * 2), interpolation=cv2.INTER_CUBIC) |
|
|
|
if json_mini: |
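        # Condensed JSON mode: emit only an id and a centre-based bbox
        # (plus OCR text for Text regions) and return early.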
|
simplified_class_name = class_name.lower().replace('imageview', 'image') |
|
new_id = f"{simplified_class_name}_{region_idx}" |
|
mini_region_dict = { |
|
"id": new_id, |
|
"bbox": NoIndent([x_center, y_center, width, height]) |
|
} |
|
|
|
|
|
if class_name == 'Text' and not skip_ocr: |
|
cropped_image_region = image[y_min:y_max, x_min:x_max] |
|
if cropped_image_region.size > 0: |
|
|
|
cropped_path = os.path.join(cropped_imageview_images_dir, f"region_{region_idx}_class_{class_id}.jpg") |
|
cv2.imwrite(cropped_path, cropped_image_region) |
|
|
|
upscaled = open_and_upscale_image(cropped_path, class_id) |
|
if upscaled is not None: |
|
if isinstance(upscaled, Image.Image): |
|
upscaled_cv = cv2.cvtColor(np.array(upscaled), cv2.COLOR_RGBA2BGR) |
|
else: |
|
upscaled_cv = upscaled |
|
|
|
gray = cv2.cvtColor(downscale_for_ocr(upscaled_cv), cv2.COLOR_BGR2GRAY) |
|
text = ' '.join(reader.readtext(gray, detail=0, batch_size=8)).strip() |
|
|
|
if text: |
|
if not skip_spell and spell: |
|
corrected_words = [] |
|
for w in text.split(): |
|
corrected_words.append(spell.correction(w) or w) |
|
mini_region_dict["text"] = " ".join(corrected_words) |
|
else: |
|
mini_region_dict["text"] = text |
|
|
|
|
|
if os.path.exists(cropped_path) and not save_images: |
|
os.remove(cropped_path) |
|
|
|
return {"mini_region_dict": mini_region_dict, "text_log": ""} |
|
|
|
logs.append(f"\n{log_prefix}Region {region_idx} - Class ID: {class_id} ({class_name})") |
|
x_center = (x_min + x_max) // 2 |
|
y_center = (y_min + y_max) // 2 |
|
logs.append(f"{log_prefix}Coordinates: x_center={x_center}, y_center={y_center}") |
|
width = x_max - x_min |
|
height = y_max - y_min |
|
logs.append(f"{log_prefix}Size: width={width}, height={height}") |
|
|
|
region_dict = { |
|
"id": f"region_{region_idx}_class_{class_name}", |
|
"x_coordinates_center": x_center, |
|
"y_coordinates_center": y_center, |
|
"width": width, |
|
"height": height |
|
} |
|
|
|
|
|
cropped_image_region = image[y_min:y_max, x_min:x_max] |
|
if cropped_image_region.size == 0: |
|
logs.append(f"{log_prefix}Empty crop for Region {region_idx}, skipping...") |
|
return {"region_dict": region_dict, "text_log": "\n".join(logs)} |
|
|
|
|
|
if class_id == 0: |
|
|
|
cropped_path = os.path.join( |
|
cropped_imageview_images_dir, f"region_{region_idx}_class_{class_id}.png" |
|
) |
|
cv2.imwrite(cropped_path, cropped_image_region) |
|
else: |
|
|
|
cropped_path = os.path.join( |
|
cropped_imageview_images_dir, f"region_{region_idx}_class_{class_id}.jpg" |
|
) |
|
cv2.imwrite(cropped_path, cropped_image_region) |
|
|
|
|
|
def call_ollama(prompt_text, rid, task_type): |
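        """Run the prompt through the local ollama CLI and return stdout.

        For vision models the ollama CLI resolves image file paths included
        in the prompt text, which is how the region crop is passed here.
        """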
|
model_name = "llama3.2-vision:11b" |
|
cmd = ["ollama", "run", model_name, prompt_text] |
|
try: |
|
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) |
|
if result.returncode != 0: |
|
logs.append(f"{log_prefix}Error generating {task_type} for Region {rid}: {result.stderr}") |
|
return None |
|
else: |
|
response = result.stdout.strip() |
|
logs.append(f"{log_prefix}Generated {task_type.capitalize()} for Region {rid}: {response}") |
|
return response |
|
except Exception as e: |
|
logs.append(f"{log_prefix}An error occurred while generating {task_type} for Region {rid}: {e}") |
|
return None |
|
|
|
|
|
def generate_caption_blip(img_path): |
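        """Caption the image at `img_path` with the loaded BLIP-2 model."""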
|
pil_image = Image.open(img_path).convert('RGB') |
|
inputs = processor(images=pil_image, return_tensors="pt").to(device, torch.float16) |
|
gen_ids = model.generate(**inputs) |
|
return processor.batch_decode(gen_ids, skip_special_tokens=True)[0].strip() |
|
|
|
|
|
if class_id == 1: |
|
if no_captioning: |
|
logs.append(f"{log_prefix}(Icon-image detection + captioning disabled by --no-captioning.)") |
|
if not output_json: |
|
block = ( |
|
f"Image: region_{region_idx}_class_{class_id} ({class_name})\n" |
|
f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n" |
|
f"Size: width={width}, height={height}\n" |
|
f"{BARRIER}" |
|
) |
|
logs.append(block) |
|
else: |
|
upscaled = open_and_upscale_image(cropped_path, class_id) |
|
if upscaled is None: |
|
return {"region_dict": region_dict, "text_log": "\n".join(logs)} |
|
|
|
|
|
if icon_model: |
|
icon_input_size = (224, 224) |
|
if isinstance(upscaled, Image.Image): |
|
upscaled_cv = cv2.cvtColor(np.array(upscaled), cv2.COLOR_RGBA2BGR) |
|
else: |
|
upscaled_cv = upscaled |
|
resized = cv2.resize(upscaled_cv, icon_input_size) |
|
rgb_img = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) / 255.0 |
|
rgb_img = np.expand_dims(rgb_img, axis=0) |
|
pred = icon_model.predict(rgb_img) |
|
logs.append(f"{log_prefix}Prediction output for Region {region_idx}: {pred}") |
|
if pred.shape == (1, 1): |
|
probability = pred[0][0] |
|
threshold = 0.5 |
|
predicted_class = 1 if probability >= threshold else 0 |
|
logs.append(f"{log_prefix}Probability of class 1: {probability}") |
|
elif pred.shape == (1, 2): |
|
predicted_class = np.argmax(pred[0]) |
|
logs.append(f"{log_prefix}Class probabilities: {pred[0]}") |
|
else: |
|
logs.append(f"{log_prefix}Unexpected prediction shape: {pred.shape}") |
|
return {"region_dict": region_dict, "text_log": "\n".join(logs)} |
|
|
|
pred_text = "Icon/Mobile UI Element" if predicted_class == 1 else "Normal Image" |
|
region_dict["prediction"] = pred_text |
|
if predicted_class == 1: |
|
prompt_text = "Describe the mobile UI element on this image. Keep it short." |
|
else: |
|
prompt_text = "Describe what is in the image briefly. It's not an icon or typical UI element." |
|
else: |
|
logs.append(f"{log_prefix}Icon detection model not provided; skipping icon detection.") |
|
region_dict["prediction"] = "Icon detection skipped" |
|
prompt_text = "Describe what is in this image briefly." |
|
|
|
|
|
temp_image_path = os.path.abspath( |
|
os.path.join(cropped_imageview_images_dir, f"imageview_{region_idx}.jpg") |
|
) |
|
if isinstance(upscaled, Image.Image): |
|
upscaled.save(temp_image_path) |
|
else: |
|
cv2.imwrite(temp_image_path, upscaled) |
|
|
|
response = "" |
|
if model and processor and model_to_use == 'blip': |
|
response = generate_caption_blip(temp_image_path) |
|
else: |
|
resp = call_ollama(prompt_text + " " + temp_image_path, region_idx, "description") |
|
response = resp if resp else "Error generating description" |
|
|
|
region_dict["description"] = response |
|
|
|
if not output_json: |
|
block = ( |
|
f"Image: region_{region_idx}_class_{class_id} ({class_name})\n" |
|
f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n" |
|
f"Size: width={width}, height={height}\n" |
|
f"Prediction: {region_dict['prediction']}\n" |
|
f"{response}\n" |
|
f"{BARRIER}" |
|
) |
|
logs.append(block) |
|
|
|
if os.path.exists(temp_image_path) and not save_images: |
|
os.remove(temp_image_path) |
|
|
|
elif class_id == 2: |
|
if skip_ocr or reader is None: |
|
logs.append(f"{log_prefix}OCR skipped for Region {region_idx}.") |
|
|
|
if not output_json: |
|
block = ( |
|
f"Text: region_{region_idx}_class_{class_id} ({class_name})\n" |
|
f"Coordinates: x_center={(x_min + x_max) // 2}, " |
|
f"y_center={(y_min + y_max) // 2}\n" |
|
f"Size: width={width}, height={height}\n" |
|
f"OCR + spell-check disabled\n" |
|
f"{BARRIER}" |
|
) |
|
logs.append(block) |
|
|
|
return {"region_dict": region_dict, "text_log": "\n".join(logs)} |
|
|
|
upscaled = open_and_upscale_image(cropped_path, class_id) |
|
if upscaled is None: |
|
return {"region_dict": region_dict, "text_log": "\n".join(logs)} |
|
|
|
if isinstance(upscaled, Image.Image): |
|
upscaled_cv = cv2.cvtColor(np.array(upscaled), cv2.COLOR_RGBA2BGR) |
|
else: |
|
upscaled_cv = upscaled |
|
|
|
|
|
|
|
upscaled_cv = downscale_for_ocr(upscaled_cv, max_dim=600) |
|
gray = cv2.cvtColor(upscaled_cv, cv2.COLOR_BGR2GRAY) |
|
result_ocr = reader.readtext(gray, detail=0, batch_size=8) |
|
text = ' '.join(result_ocr).strip() |
|
|
|
|
|
if skip_spell or spell is None: |
|
corrected_text = None |
|
logs.append(f"{log_prefix}Spell-check skipped for Region {region_idx}.") |
|
else: |
|
correction_cache = {} |
|
corrected_words = [] |
|
for w in text.split(): |
|
if w not in correction_cache: |
|
correction_cache[w] = spell.correction(w) or w |
|
corrected_words.append(correction_cache[w]) |
|
corrected_text = " ".join(corrected_words) |
|
|
|
|
|
logs.append(f"{log_prefix}Extracted Text for Region {region_idx}: {text}") |
|
if corrected_text is not None: |
|
logs.append(f"{log_prefix}Corrected Text for Region {region_idx}: {corrected_text}") |
|
|
|
region_dict["extractedText"] = text |
|
if corrected_text is not None: |
|
region_dict["correctedText"] = corrected_text |
|
|
|
if not output_json: |
|
block = ( |
|
f"Text: region_{region_idx}_class_{class_id} ({class_name})\n" |
|
f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n" |
|
f"Size: width={width}, height={height}\n" |
|
f"Extracted Text: {text}\n" |
|
+ (f"Corrected Text: {corrected_text}\n" if corrected_text is not None else "") |
|
+ f"{BARRIER}" |
|
) |
|
logs.append(block) |
|
|
|
elif class_id == 0: |
|
upscaled = open_and_upscale_image(cropped_path, class_id) |
|
if upscaled is None: |
|
return {"region_dict": region_dict, "text_log": "\n".join(logs)} |
|
|
|
data = np.array(upscaled) |
|
if data.ndim == 2: |
|
data = cv2.cvtColor(data, cv2.COLOR_GRAY2BGRA) |
|
elif data.shape[-1] == 3: |
|
b, g, r = cv2.split(data) |
|
a = np.full_like(b, 255) |
|
data = cv2.merge((b, g, r, a)) |
|
|
|
pixels = data.reshape((-1, 4)) |
|
opaque_pixels = pixels[pixels[:, 3] > 0] |
|
|
|
if len(opaque_pixels) == 0: |
|
logs.append(f"{log_prefix}No opaque pixels found in Region {region_idx}, cannot determine background color.") |
|
color_name = "Unknown" |
|
else: |
|
dom_color = get_most_frequent_color(opaque_pixels[:, :3], bin_size=10) |
|
exact_name, closest_name = get_colour_name(dom_color) |
|
color_name = exact_name if exact_name else closest_name |
|
|
|
alphas = pixels[:, 3] |
|
dominant_alpha = get_most_frequent_alpha(alphas, bin_size=10) |
|
transparency = "opaque" if dominant_alpha >= 245 else "transparent" |
|
|
|
response = ( |
|
f"1. The background color of the container is {color_name}.\n" |
|
f"2. The container is {transparency}." |
|
) |
|
logs.append(f"{log_prefix}{response}") |
|
region_dict["view_color"] = f"The background color of the container is {color_name}." |
|
region_dict["view_alpha"] = f"The container is {transparency}." |
|
|
|
if not output_json: |
|
block = ( |
|
f"View: region_{region_idx}_class_{class_id} ({class_name})\n" |
|
f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n" |
|
f"Size: width={width}, height={height}\n" |
|
f"{response}\n" |
|
f"{BARRIER}" |
|
) |
|
logs.append(block) |
|
|
|
elif class_id == 3: |
|
logs.append(f"{log_prefix}Processing Line in Region {region_idx}") |
|
line_img = cv2.imread(cropped_path, cv2.IMREAD_UNCHANGED) |
|
if line_img is None: |
|
logs.append(f"{log_prefix}Failed to read image at {cropped_path}") |
|
return {"region_dict": region_dict, "text_log": "\n".join(logs)} |
|
|
|
hh, ww = line_img.shape[:2] |
|
logs.append(f"{log_prefix}Image dimensions: width={ww}, height={hh}") |
|
|
|
        # cv2 loads the crop in BGR(A) channel order; convert to RGBA so the
        # colour lookup below sees channels in the order get_colour_name expects.
        data = np.array(line_img)
        if data.ndim == 2:
            data = cv2.cvtColor(data, cv2.COLOR_GRAY2RGBA)
        elif data.shape[-1] == 3:
            data = cv2.cvtColor(data, cv2.COLOR_BGR2RGBA)
        else:
            data = cv2.cvtColor(data, cv2.COLOR_BGRA2RGBA)
|
|
|
pixels = data.reshape((-1, 4)) |
|
opaque_pixels = pixels[pixels[:, 3] > 0] |
|
|
|
if len(opaque_pixels) == 0: |
|
logs.append(f"{log_prefix}No opaque pixels found in Region {region_idx}, cannot determine line color.") |
|
color_name = "Unknown" |
|
else: |
|
dom_color = get_most_frequent_color(opaque_pixels[:, :3], bin_size=10) |
|
exact_name, closest_name = get_colour_name(dom_color) |
|
color_name = exact_name if exact_name else closest_name |
|
|
|
alphas = pixels[:, 3] |
|
dom_alpha = get_most_frequent_alpha(alphas, bin_size=10) |
|
transparency = "opaque" if dom_alpha >= 245 else "transparent" |
|
|
|
response = ( |
|
f"1. The color of the line is {color_name}.\n" |
|
f"2. The line is {transparency}." |
|
) |
|
logs.append(f"{log_prefix}{response}") |
|
region_dict["line_color"] = f"The color of the line is {color_name}." |
|
region_dict["line_alpha"] = f"The line is {transparency}." |
|
|
|
if not output_json: |
|
block = ( |
|
f"Line: region_{region_idx}_class_{class_id} ({class_name})\n" |
|
f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n" |
|
f"Size: width={width}, height={height}\n" |
|
f"{response}\n" |
|
f"{BARRIER}" |
|
) |
|
logs.append(block) |
|
|
|
else: |
|
logs.append(f"{log_prefix}Class ID {class_id} not handled.") |
|
|
|
|
|
if os.path.exists(cropped_path) and not save_images: |
|
os.remove(cropped_path) |
|
|
|
return { |
|
"region_dict": region_dict, |
|
"text_log": "\n".join(logs), |
|
} |
|
|
|
|
|
|
|
def process_image( |
|
input_image_path, |
|
yolo_output_path, |
|
    output_dir: str = '.',
|
model_to_use='llama', |
|
save_images=False, |
|
icon_model_path=None, |
|
cache_directory='./models_cache', |
|
huggingface_token='your_token', |
|
no_captioning=False, |
|
output_json=False, |
|
json_mini=False, |
|
sr=None, |
|
reader=None, |
|
spell=None, |
|
skip_ocr=False, |
|
skip_spell=False |
|
): |
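    """
    Run the full region-description pipeline for a single image.

    Reads YOLO-format labels from `yolo_output_path`, processes each region
    via `process_single_region`, and writes either a captions text file or a
    (mini) JSON file into `<output_dir>/result/`.
    """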
|
if json_mini: |
|
json_output = { |
|
"image_size": None, |
|
"bbox_format": "center_x, center_y, width, height", |
|
"elements": [] |
|
} |
|
elif output_json: |
|
json_output = { |
|
"image": {"path": input_image_path, "size": {"width": None, "height": None}}, |
|
"elements": [] |
|
} |
|
else: |
|
json_output = None |
|
|
|
|
|
start_time = time.perf_counter() |
|
print("super-resolution initialization start (in script.py)") |
|
|
|
if sr is None: |
|
print("No sr reference passed; performing local init ...") |
|
        model_path = 'EDSR_x4.pb'
        if hasattr(cv2, 'dnn_superres'):
            print("dnn_superres module is available.")
            try:
                sr = cv2.dnn_superres.DnnSuperResImpl_create()
                print("Using DnnSuperResImpl_create()")
            except AttributeError:
                # Some OpenCV builds expose the class constructor instead of the factory helper.
                sr = cv2.dnn_superres.DnnSuperResImpl()
                print("Using DnnSuperResImpl()")
            if os.path.exists(model_path):
                sr.readModel(model_path)
                sr.setModel('edsr', 4)
            else:
                print(f"Super-resolution model '{model_path}' not found; skipping super-resolution.")
                sr = None
        else:
            print("dnn_superres module is NOT available; skipping super-resolution.")
|
else: |
|
print("Using pre-initialized sr reference.") |
|
|
|
|
|
elapsed = time.perf_counter() - start_time |
|
print(f"super-resoulution init (in script.py) took {elapsed:.3f} seconds.") |
|
|
|
start_time = time.perf_counter() |
|
|
|
if skip_ocr: |
|
print("skip_ocr flag set - skipping EasyOCR and SpellChecker.") |
|
reader = None |
|
spell = None |
|
|
|
else: |
|
print("OCR initialisation start (in script.py)") |
|
if reader is None: |
|
print("No EasyOCR reference passed; performing local init") |
|
reader = easyocr.Reader(['en'], gpu=True) |
|
else: |
|
print("Using pre-initialised EasyOCR object.") |
|
|
|
if skip_spell: |
|
print("skip_spell flag set - not initialising SpellChecker.") |
|
spell = None |
|
else: |
|
if spell is None: |
|
print("No SpellChecker reference passed; performing local init") |
|
spell = SpellChecker() |
|
else: |
|
print("Using pre-initialised SpellChecker object.") |
|
|
|
elapsed = time.perf_counter() - start_time |
|
print(f"OCR init (in script.py) took {elapsed:.3f} seconds.") |
|
|
|
|
|
start_time = time.perf_counter() |
|
print("icon-model init start (in script.py)") |
|
|
|
if icon_model_path: |
|
icon_model = tf.keras.models.load_model(icon_model_path) |
|
print(f"Icon detection model loaded: {icon_model_path}") |
|
else: |
|
icon_model = None |
|
|
|
elapsed = time.perf_counter() - start_time |
|
print(f"icon-model init (in script.py) took {elapsed:.3f} seconds.") |
|
|
|
|
|
|
|
image = cv2.imread(input_image_path, cv2.IMREAD_UNCHANGED) |
|
if image is None: |
|
print(f"Image at {input_image_path} could not be loaded.") |
|
return |
|
|
|
image_height, image_width = image.shape[:2] |
|
|
|
|
|
|
|
with open(yolo_output_path, 'r') as f: |
|
lines = f.readlines() |
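    # Choose the torch device used for BLIP-2 captioning: prefer Apple MPS,
    # then CUDA, then fall back to CPU.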
|
|
|
|
|
if torch.backends.mps.is_available(): |
|
device = torch.device("mps") |
|
print("Using MPS") |
|
elif torch.cuda.is_available(): |
|
device = torch.device("cuda") |
|
print("Using CUDA") |
|
else: |
|
device = torch.device("cpu") |
|
print("Using CPU") |
|
|
|
|
|
processor, model = None, None |
|
if not no_captioning: |
|
if model_to_use == 'blip': |
|
print("Loading BLIP-2 model...") |
|
blip_model_name = "Salesforce/blip2-opt-2.7b" |
|
if not is_model_downloaded(blip_model_name, cache_directory): |
|
print("Model not found in cache. Downloading...") |
|
else: |
|
print("Model found in cache. Loading...") |
|
processor = AutoProcessor.from_pretrained( |
|
blip_model_name, |
|
use_auth_token=huggingface_token, |
|
cache_dir=cache_directory, |
|
resume_download=True |
|
) |
|
            # device_map='auto' lets accelerate place the weights, so the model
            # is not moved again with .to(device) (doing so can fail on
            # dispatched models).
            model = Blip2ForConditionalGeneration.from_pretrained(
                blip_model_name,
                device_map='auto',
                torch_dtype=torch.float16,
                use_auth_token=huggingface_token,
                cache_dir=cache_directory,
                resume_download=True
            )
|
else: |
|
print("Using LLaMA model via external call (ollama).") |
|
else: |
|
print("--no-captioning flag is set; skipping model loading.") |
|
|
|
|
|
bounding_boxes = [] |
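    # Convert normalised YOLO rows (class cx cy w h) into absolute pixel
    # corner coordinates, clamped to the image bounds.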
|
for line in lines: |
|
        parts = line.strip().split()
        if len(parts) < 5:
            continue  # skip empty/malformed label lines
        class_id = int(parts[0])
        # Only the first four floats are used; a trailing confidence column, if present, is ignored.
        x_center_norm, y_center_norm, width_norm, height_norm = map(float, parts[1:5])
|
x_center = x_center_norm * image_width |
|
y_center = y_center_norm * image_height |
|
box_width = width_norm * image_width |
|
box_height = height_norm * image_height |
|
x_min = int(x_center - box_width / 2) |
|
y_min = int(y_center - box_height / 2) |
|
x_max = int(x_center + box_width / 2) |
|
y_max = int(y_center + box_height / 2) |
|
x_min = max(0, x_min) |
|
y_min = max(0, y_min) |
|
x_max = min(image_width - 1, x_max) |
|
y_max = min(image_height - 1, y_max) |
|
bounding_boxes.append((x_min, y_min, x_max, y_max, class_id)) |
|
|
|
|
|
|
|
cropped_dir = os.path.join(output_dir, "cropped_imageview_images") |
|
os.makedirs(cropped_dir, exist_ok=True) |
|
result_dir = os.path.join(output_dir, "result") |
|
os.makedirs(result_dir, exist_ok=True) |
|
|
|
base_name = os.path.splitext(os.path.basename(input_image_path))[0] |
|
captions_file_path = None |
|
if json_mini: |
|
json_output["image_size"] = NoIndent([image_width, image_height]) |
|
elif output_json: |
|
json_output["image"]["size"]["width"] = image_width |
|
json_output["image"]["size"]["height"] = image_height |
|
else: |
|
captions_filename = f"{base_name}_regions_captions.txt" |
|
captions_file_path = os.path.join(result_dir, captions_filename) |
|
with open(captions_file_path, 'w', encoding='utf-8') as f: |
|
f.write(f"Image path: {input_image_path}\n") |
|
f.write(f"Image Size: width={image_width}, height={image_height}\n") |
|
f.write(BARRIER) |
|
|
|
|
|
start_time = time.perf_counter() |
|
print("Process single region start (in script.py)") |
|
|
|
with ThreadPoolExecutor(max_workers=1) as executor: |
|
futures = [ |
|
executor.submit( |
|
process_single_region, |
|
idx, box, image, sr, reader, spell, |
|
                icon_model, processor, model, device,  # device is only used when the BLIP model is loaded
|
no_captioning, output_json, json_mini, |
|
cropped_dir, base_name, save_images, |
|
model_to_use, log_prefix="", |
|
skip_ocr=skip_ocr, |
|
skip_spell=skip_spell |
|
) for idx, box in enumerate(bounding_boxes) |
|
] |
|
|
|
for future in as_completed(futures): |
|
item = future.result() |
|
if json_mini: |
|
if item.get("mini_region_dict"): |
|
json_output["elements"].append(item["mini_region_dict"]) |
|
elif output_json: |
|
if item.get("region_dict"): |
|
json_output["elements"].append(item["region_dict"]) |
|
else: |
|
if item.get("text_log") and captions_file_path: |
|
with open(captions_file_path, 'a', encoding='utf-8') as f: |
|
f.write(item["text_log"]) |
|
|
|
elapsed = time.perf_counter() - start_time |
|
print(f"Processing regions took {elapsed:.3f} seconds.") |
|
|
|
if json_mini or output_json: |
|
json_file = os.path.join(result_dir, f"{base_name}.json") |
|
with open(json_file, 'w', encoding='utf-8') as f: |
|
json.dump(json_output, f, indent=2, ensure_ascii=False, cls=CustomEncoder) |
|
|
|
output_type = "mini JSON" if json_mini else "JSON" |
|
print(f"{output_type} output written to {json_file}") |
|
else: |
|
print(f"Text output written to {captions_file_path}") |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser(description='Process an image and its YOLO labels.') |
|
parser.add_argument('input_image', help='Path to the input YOLO image.') |
|
parser.add_argument('input_labels', help='Path to the input YOLO labels file.') |
|
parser.add_argument('--output_dir', default='.', |
|
help='Directory to save output files. Defaults to the current directory.') |
|
parser.add_argument('--model_to_use', choices=['llama', 'blip'], default='llama', |
|
help='Model for captioning (llama or blip).') |
|
parser.add_argument('--save_images', action='store_true', |
|
help='Flag to save intermediate images.') |
|
parser.add_argument('--icon_detection_path', help='Path to icon detection model.') |
|
parser.add_argument('--cache_directory', default='./models_cache', |
|
help='Cache directory for Hugging Face models.') |
|
parser.add_argument('--huggingface_token', default='your_token', |
|
help='Hugging Face token for model downloads.') |
|
parser.add_argument('--no-captioning', action='store_true', |
|
help='Disable any image captioning.') |
|
parser.add_argument('--json', dest='output_json', action='store_true', |
|
help='Output the image data in JSON format') |
|
parser.add_argument('--json-mini', action='store_true', |
|
help='Output the image data in a condensed JSON format') |
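    # Expose the existing skip_ocr / skip_spell parameters of process_image().
    parser.add_argument('--skip-ocr', dest='skip_ocr', action='store_true',
                        help='Skip OCR (and spell-check) for Text regions.')
    parser.add_argument('--skip-spell', dest='skip_spell', action='store_true',
                        help='Skip spell-checking of extracted OCR text.')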
|
args = parser.parse_args() |
|
|
|
process_image( |
|
input_image_path=args.input_image, |
|
yolo_output_path=args.input_labels, |
|
output_dir=args.output_dir, |
|
model_to_use=args.model_to_use, |
|
save_images=args.save_images, |
|
icon_model_path=args.icon_detection_path, |
|
cache_directory=args.cache_directory, |
|
huggingface_token=args.huggingface_token, |
|
no_captioning=args.no_captioning, |
|
output_json=args.output_json, |
|
        json_mini=args.json_mini,
        skip_ocr=args.skip_ocr,
        skip_spell=args.skip_spell
|
) |
|
|
|
|