import cv2
import os
import subprocess
from PIL import Image
import easyocr
from spellchecker import SpellChecker
import numpy as np
import webcolors
from collections import Counter
import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import tensorflow as tf
import argparse
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from utils.json_helpers import NoIndent, CustomEncoder


# constants
BARRIER = "********\n"

# Check if a model is in the cache
def is_model_downloaded(model_name, cache_directory):
    model_path = os.path.join(cache_directory, model_name.replace('/', '_'))
    return os.path.exists(model_path)

# Convert color to the closest name
def closest_colour(requested_colour):
    min_colours = {}
    css3_names = webcolors.names("css3")
    for name in css3_names:
        hex_value = webcolors.name_to_hex(name, spec='css3')
        r_c, g_c, b_c = webcolors.hex_to_rgb(hex_value)
        rd = (r_c - requested_colour[0]) ** 2
        gd = (g_c - requested_colour[1]) ** 2
        bd = (b_c - requested_colour[2]) ** 2
        distance = rd + gd + bd
        min_colours[distance] = name
    return min_colours[min(min_colours.keys())]

def get_colour_name(requested_colour):
    """
    Returns a tuple: (exact_name, closest_name).
    If an exact match fails, 'exact_name' is None; use 'closest_name' as the fallback.
    """
    try:
        actual_name = webcolors.rgb_to_name(requested_colour, spec='css3')
        closest_name = actual_name
    except ValueError:
        closest_name = closest_colour(requested_colour)
        actual_name = None
    return actual_name, closest_name
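
# Illustrative usage of get_colour_name (values checked against the CSS3 palette):
# an exact hit such as (0, 0, 255) yields ('blue', 'blue'), while a near miss such
# as (1, 2, 254) yields (None, 'blue') and callers fall back to the closest name.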

def get_most_frequent_color(pixels, bin_size=10):
    """
    Returns the most frequent color among the given pixels,
    using a binning approach (default bin size=10).
    """
    bins = np.arange(0, 257, bin_size)
    r_bins = np.digitize(pixels[:, 0], bins) - 1
    g_bins = np.digitize(pixels[:, 1], bins) - 1
    b_bins = np.digitize(pixels[:, 2], bins) - 1
    combined_bins = r_bins * 10000 + g_bins * 100 + b_bins
    bin_counts = Counter(combined_bins)
    most_common_bin = bin_counts.most_common(1)[0][0]

    r_bin = most_common_bin // 10000
    g_bin = (most_common_bin % 10000) // 100
    b_bin = most_common_bin % 100
    r_value = bins[r_bin] + bin_size // 2
    g_value = bins[g_bin] + bin_size // 2
    b_value = bins[b_bin] + bin_size // 2

    return (r_value, g_value, b_value)
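
# Rough example of the binning above: with bin_size=10, pure-red pixels such as
# [[255, 0, 0], [252, 3, 1]] fall into the same (r, g, b) bins, and the function
# returns the bin centres, roughly (255, 5, 5), rather than an exact pixel value.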

def get_most_frequent_alpha(alphas, bin_size=10):
    bins = np.arange(0, 257, bin_size)
    alpha_bins = np.digitize(alphas, bins) - 1
    bin_counts = Counter(alpha_bins)
    most_common_bin = bin_counts.most_common(1)[0][0]
    alpha_value = bins[most_common_bin] + bin_size // 2
    return alpha_value
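
# Same binning trick for the alpha channel: a region whose alpha values are mostly
# 255 maps to a dominant alpha of 255, which the callers below treat as opaque
# (their threshold is >= 245).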

# Downscale images for OCR. TODO: tune max_dim to a suitable value
def downscale_for_ocr(image_cv, max_dim=600):
    """
    If either dimension of `image_cv` is bigger than `max_dim`,
    scale it down proportionally. This speeds up EasyOCR on large images.
    """
    h, w = image_cv.shape[:2]
    if w <= max_dim and h <= max_dim:
        return image_cv  # No downscale needed

    scale = min(max_dim / float(w), max_dim / float(h))
    new_w = int(w * scale)
    new_h = int(h * scale)
    image_resized = cv2.resize(image_cv, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return image_resized
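
# For example, a 1200x800 crop is scaled by min(600/1200, 600/800) = 0.5 down to
# 600x400, preserving the aspect ratio; anything already within 600x600 is returned
# unchanged.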

# Worker function to process a single bounding box
def process_single_region(
    idx, bounding_box, image, sr, reader, spell, icon_model,
    processor, model, device, no_captioning, output_json, json_mini,
    cropped_imageview_images_dir, base_name, save_images,
    model_to_use, log_prefix="",
    skip_ocr=False,
    skip_spell=False
):
    """
    Processes one bounding box (region).
    Returns a dict with:
      * "region_dict" (for JSON output)
      * "text_log" (for the captions/text file output)
    """
    (x_min, y_min, x_max, y_max, class_id) = bounding_box
    class_names = {0: 'View', 1: 'ImageView', 2: 'Text', 3: 'Line'}
    class_name = class_names.get(class_id, f'Unknown Class {class_id}')
    region_idx = idx + 1
    logs = []

    x_center = (x_min + x_max) // 2
    y_center = (y_min + y_max) // 2
    width = x_max - x_min
    height = y_max - y_min

    def open_and_upscale_image(img_path, cid):

        if cid == 2: # Text 
            MAX_WIDTH, MAX_HEIGHT = 30, 30
        else:
            MAX_WIDTH, MAX_HEIGHT = 10, 10

        def is_small(w, h):
            return w <= MAX_WIDTH and h <= MAX_HEIGHT

        if cid == 0:  # "View" - use PIL to preserve alpha
            pil_image = Image.open(img_path).convert("RGBA")
            w, h = pil_image.size
            if not is_small(w, h):
                logs.append(f"{log_prefix}Skipping upscale for large View (size={w}×{h}).")
                return pil_image

            # If super-resolution is provided, use it
            if sr:
                image_cv = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGBA2BGR)
                upscaled = sr.upsample(image_cv)
                return Image.fromarray(cv2.cvtColor(upscaled, cv2.COLOR_BGR2RGBA))
            else:
                return pil_image.resize((w * 4, h * 4), resample=Image.BICUBIC)
        else:
            # For other classes, load the image with OpenCV (BGR)
            cv_image = cv2.imread(img_path)
            if cv_image is None or cv_image.size == 0:
                logs.append(f"{log_prefix}Empty image at {img_path}, skipping.")
                return None

            h, w = cv_image.shape[:2]
            if not is_small(w, h):
                logs.append(f"{log_prefix}Skipping upscale for large region (size={w}×{h}).")
                return cv_image

            if sr:
                return sr.upsample(cv_image)
            else:
                return cv2.resize(cv_image, (w * 2, h * 2), interpolation=cv2.INTER_CUBIC)

    if json_mini:
        simplified_class_name = class_name.lower().replace('imageview', 'image')
        new_id = f"{simplified_class_name}_{region_idx}"
        mini_region_dict = {
            "id": new_id,
            "bbox": NoIndent([x_center, y_center, width, height])
        }
        
        # For the mini format, only Text regions need further processing (OCR)
        if class_name == 'Text' and not skip_ocr:
            cropped_image_region = image[y_min:y_max, x_min:x_max]
            if cropped_image_region.size > 0:
                # Save the cropped image so open_and_upscale_image can use it
                cropped_path = os.path.join(cropped_imageview_images_dir, f"region_{region_idx}_class_{class_id}.jpg")
                cv2.imwrite(cropped_path, cropped_image_region)

                upscaled = open_and_upscale_image(cropped_path, class_id)
                if upscaled is not None:
                    if isinstance(upscaled, Image.Image):
                        upscaled_cv = cv2.cvtColor(np.array(upscaled), cv2.COLOR_RGBA2BGR)
                    else:
                        upscaled_cv = upscaled
                    
                    gray = cv2.cvtColor(downscale_for_ocr(upscaled_cv), cv2.COLOR_BGR2GRAY)
                    text = ' '.join(reader.readtext(gray, detail=0, batch_size=8)).strip()

                    if text:
                        if not skip_spell and spell:
                            corrected_words = []
                            for w in text.split():
                                corrected_words.append(spell.correction(w) or w)
                            mini_region_dict["text"] = " ".join(corrected_words)
                        else:
                            mini_region_dict["text"] = text
                
                # Clean up the temporary cropped image
                if os.path.exists(cropped_path) and not save_images:
                    os.remove(cropped_path)

        return {"mini_region_dict": mini_region_dict, "text_log": ""}

    logs.append(f"\n{log_prefix}Region {region_idx} - Class ID: {class_id} ({class_name})")
    x_center = (x_min + x_max) // 2
    y_center = (y_min + y_max) // 2
    logs.append(f"{log_prefix}Coordinates: x_center={x_center}, y_center={y_center}")
    width = x_max - x_min
    height = y_max - y_min
    logs.append(f"{log_prefix}Size: width={width}, height={height}")

    region_dict = {
        "id": f"region_{region_idx}_class_{class_name}",
        "x_coordinates_center": x_center,
        "y_coordinates_center": y_center,
        "width": width,
        "height": height
    }

    # Crop region
    cropped_image_region = image[y_min:y_max, x_min:x_max]
    if cropped_image_region.size == 0:
        logs.append(f"{log_prefix}Empty crop for Region {region_idx}, skipping...")
        return {"region_dict": region_dict, "text_log": "\n".join(logs)}

    # Save cropped region
    if class_id == 0:
        # Save as PNG if it's a View
        cropped_path = os.path.join(
            cropped_imageview_images_dir, f"region_{region_idx}_class_{class_id}.png"
        )
        cv2.imwrite(cropped_path, cropped_image_region)
    else:
        # Save as JPG
        cropped_path = os.path.join(
            cropped_imageview_images_dir, f"region_{region_idx}_class_{class_id}.jpg"
        )
        cv2.imwrite(cropped_path, cropped_image_region)

    # for LLaMA (ollama)
    def call_ollama(prompt_text, rid, task_type):
        model_name = "llama3.2-vision:11b"
        cmd = ["ollama", "run", model_name, prompt_text]
        try:
            result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            if result.returncode != 0:
                logs.append(f"{log_prefix}Error generating {task_type} for Region {rid}: {result.stderr}")
                return None
            else:
                response = result.stdout.strip()
                logs.append(f"{log_prefix}Generated {task_type.capitalize()} for Region {rid}: {response}")
                return response
        except Exception as e:
            logs.append(f"{log_prefix}An error occurred while generating {task_type} for Region {rid}: {e}")
            return None

    # for BLIP-2
    def generate_caption_blip(img_path):
        pil_image = Image.open(img_path).convert('RGB')
        inputs = processor(images=pil_image, return_tensors="pt").to(device, torch.float16)
        gen_ids = model.generate(**inputs)
        return processor.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()

    # Handle each class type
    if class_id == 1:  # ImageView
        if no_captioning:
            logs.append(f"{log_prefix}(Icon-image detection + captioning disabled by --no-captioning.)")
            if not output_json:
                block = (
                    f"Image: region_{region_idx}_class_{class_id} ({class_name})\n"
                    f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n"
                    f"Size: width={width}, height={height}\n"
                    f"{BARRIER}"
                )
                logs.append(block)
        else:
            upscaled = open_and_upscale_image(cropped_path, class_id)
            if upscaled is None:
                return {"region_dict": region_dict, "text_log": "\n".join(logs)}

            # Icon detection
            if icon_model:
                icon_input_size = (224, 224)
                if isinstance(upscaled, Image.Image):
                    upscaled_cv = cv2.cvtColor(np.array(upscaled), cv2.COLOR_RGBA2BGR)
                else:
                    upscaled_cv = upscaled
                resized = cv2.resize(upscaled_cv, icon_input_size)
                rgb_img = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) / 255.0
                rgb_img = np.expand_dims(rgb_img, axis=0)
                pred = icon_model.predict(rgb_img)
                logs.append(f"{log_prefix}Prediction output for Region {region_idx}: {pred}")
                if pred.shape == (1, 1):
                    probability = pred[0][0]
                    threshold = 0.5
                    predicted_class = 1 if probability >= threshold else 0
                    logs.append(f"{log_prefix}Probability of class 1: {probability}")
                elif pred.shape == (1, 2):
                    predicted_class = np.argmax(pred[0])
                    logs.append(f"{log_prefix}Class probabilities: {pred[0]}")
                else:
                    logs.append(f"{log_prefix}Unexpected prediction shape: {pred.shape}")
                    return {"region_dict": region_dict, "text_log": "\n".join(logs)}

                pred_text = "Icon/Mobile UI Element" if predicted_class == 1 else "Normal Image"
                region_dict["prediction"] = pred_text
                if predicted_class == 1:
                    prompt_text = "Describe the mobile UI element on this image. Keep it short."
                else:
                    prompt_text = "Describe what is in the image briefly. It's not an icon or typical UI element."
            else:
                logs.append(f"{log_prefix}Icon detection model not provided; skipping icon detection.")
                region_dict["prediction"] = "Icon detection skipped"
                prompt_text = "Describe what is in this image briefly."

            # Caption
            temp_image_path = os.path.abspath(
                os.path.join(cropped_imageview_images_dir, f"imageview_{region_idx}.jpg")
            )
            if isinstance(upscaled, Image.Image): # TODO check optimization
                upscaled.save(temp_image_path)
            else:
                cv2.imwrite(temp_image_path, upscaled)

            response = ""
            if model and processor and model_to_use == 'blip':
                response = generate_caption_blip(temp_image_path)
            else: # TODO check optimization
                resp = call_ollama(prompt_text + " " + temp_image_path, region_idx, "description")
                response = resp if resp else "Error generating description"

            region_dict["description"] = response

            if not output_json:
                block = (
                    f"Image: region_{region_idx}_class_{class_id} ({class_name})\n"
                    f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n"
                    f"Size: width={width}, height={height}\n"
                    f"Prediction: {region_dict['prediction']}\n"
                    f"{response}\n"
                    f"{BARRIER}"
                )
                logs.append(block)

            if os.path.exists(temp_image_path) and not save_images:
                os.remove(temp_image_path)

    elif class_id == 2:  # Text
        if skip_ocr or reader is None:
            logs.append(f"{log_prefix}OCR skipped for Region {region_idx}.")

            if not output_json:
                block = (
                    f"Text: region_{region_idx}_class_{class_id} ({class_name})\n"
                    f"Coordinates: x_center={(x_min + x_max) // 2}, "
                    f"y_center={(y_min + y_max) // 2}\n"
                    f"Size: width={width}, height={height}\n"
                    f"OCR + spell-check disabled\n"
                    f"{BARRIER}"
                )
                logs.append(block)

            return {"region_dict": region_dict, "text_log": "\n".join(logs)}

        upscaled = open_and_upscale_image(cropped_path, class_id)
        if upscaled is None:
            return {"region_dict": region_dict, "text_log": "\n".join(logs)}

        if isinstance(upscaled, Image.Image):
            upscaled_cv = cv2.cvtColor(np.array(upscaled), cv2.COLOR_RGBA2BGR)
        else:
            upscaled_cv = upscaled


        # TODO use other lib to improve the performance
        upscaled_cv = downscale_for_ocr(upscaled_cv, max_dim=600)
        gray = cv2.cvtColor(upscaled_cv, cv2.COLOR_BGR2GRAY)
        result_ocr = reader.readtext(gray, detail=0, batch_size=8)
        text = ' '.join(result_ocr).strip()

        # TODO use other lib to improve performance
        if skip_spell or spell is None:
            corrected_text = None
            logs.append(f"{log_prefix}Spell-check skipped for Region {region_idx}.")
        else:
            correction_cache = {}
            corrected_words = []
            for w in text.split():
                if w not in correction_cache:
                    correction_cache[w] = spell.correction(w) or w
                corrected_words.append(correction_cache[w])
            corrected_text = " ".join(corrected_words)


        logs.append(f"{log_prefix}Extracted Text for Region {region_idx}: {text}")
        if corrected_text is not None:
            logs.append(f"{log_prefix}Corrected Text for Region {region_idx}: {corrected_text}")

        region_dict["extractedText"] = text
        if corrected_text is not None:
            region_dict["correctedText"] = corrected_text

        if not output_json:
            block = (
                f"Text: region_{region_idx}_class_{class_id} ({class_name})\n"
                f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n"
                f"Size: width={width}, height={height}\n"
                f"Extracted Text: {text}\n"
                + (f"Corrected Text: {corrected_text}\n" if corrected_text is not None else "")
                + f"{BARRIER}"
            )
            logs.append(block)

    elif class_id == 0:  # View
        upscaled = open_and_upscale_image(cropped_path, class_id)
        if upscaled is None:
            return {"region_dict": region_dict, "text_log": "\n".join(logs)}

        data = np.array(upscaled)
        if data.ndim == 2:
            data = cv2.cvtColor(data, cv2.COLOR_GRAY2BGRA)
        elif data.shape[-1] == 3:
            b, g, r = cv2.split(data)
            a = np.full_like(b, 255)
            data = cv2.merge((b, g, r, a))

        pixels = data.reshape((-1, 4))
        opaque_pixels = pixels[pixels[:, 3] > 0]

        if len(opaque_pixels) == 0:
            logs.append(f"{log_prefix}No opaque pixels found in Region {region_idx}, cannot determine background color.")
            color_name = "Unknown"
        else:
            dom_color = get_most_frequent_color(opaque_pixels[:, :3], bin_size=10)
            exact_name, closest_name = get_colour_name(dom_color)
            color_name = exact_name if exact_name else closest_name

        alphas = pixels[:, 3]
        dominant_alpha = get_most_frequent_alpha(alphas, bin_size=10)
        transparency = "opaque" if dominant_alpha >= 245 else "transparent"

        response = (
            f"1. The background color of the container is {color_name}.\n"
            f"2. The container is {transparency}."
        )
        logs.append(f"{log_prefix}{response}")
        region_dict["view_color"] = f"The background color of the container is {color_name}."
        region_dict["view_alpha"] = f"The container is {transparency}."

        if not output_json:
            block = (
                f"View: region_{region_idx}_class_{class_id} ({class_name})\n"
                f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n"
                f"Size: width={width}, height={height}\n"
                f"{response}\n"
                f"{BARRIER}"
            )
            logs.append(block)

    elif class_id == 3:  # Line
        logs.append(f"{log_prefix}Processing Line in Region {region_idx}")
        line_img = cv2.imread(cropped_path, cv2.IMREAD_UNCHANGED)
        if line_img is None:
            logs.append(f"{log_prefix}Failed to read image at {cropped_path}")
            return {"region_dict": region_dict, "text_log": "\n".join(logs)}

        hh, ww = line_img.shape[:2]
        logs.append(f"{log_prefix}Image dimensions: width={ww}, height={hh}")

        data = np.array(line_img)
        if data.ndim == 2:
            data = cv2.cvtColor(data, cv2.COLOR_GRAY2BGRA)
        elif data.shape[-1] == 3:
            b, g, r = cv2.split(data)
            a = np.full_like(b, 255)
            data = cv2.merge((b, g, r, a))
        # cv2.imread yields BGR(A) channel order; convert to RGBA so the colour
        # lookup below receives (R, G, B) as webcolors expects
        data = cv2.cvtColor(data, cv2.COLOR_BGRA2RGBA)

        pixels = data.reshape((-1, 4))
        opaque_pixels = pixels[pixels[:, 3] > 0]

        if len(opaque_pixels) == 0:
            logs.append(f"{log_prefix}No opaque pixels found in Region {region_idx}, cannot determine line color.")
            color_name = "Unknown"
        else:
            dom_color = get_most_frequent_color(opaque_pixels[:, :3], bin_size=10)
            exact_name, closest_name = get_colour_name(dom_color)
            color_name = exact_name if exact_name else closest_name

        alphas = pixels[:, 3]
        dom_alpha = get_most_frequent_alpha(alphas, bin_size=10)
        transparency = "opaque" if dom_alpha >= 245 else "transparent"

        response = (
            f"1. The color of the line is {color_name}.\n"
            f"2. The line is {transparency}."
        )
        logs.append(f"{log_prefix}{response}")
        region_dict["line_color"] = f"The color of the line is {color_name}."
        region_dict["line_alpha"] = f"The line is {transparency}."

        if not output_json:
            block = (
                f"Line: region_{region_idx}_class_{class_id} ({class_name})\n"
                f"Coordinates: x_center={(x_min + x_max) // 2}, y_center={(y_min + y_max) // 2}\n"
                f"Size: width={width}, height={height}\n"
                f"{response}\n"
                f"{BARRIER}"
            )
            logs.append(block)

    else:
        logs.append(f"{log_prefix}Class ID {class_id} not handled.")

    # Remove intermediate if not saving
    if os.path.exists(cropped_path) and not save_images:
        os.remove(cropped_path)

    return {
        "region_dict": region_dict,
        "text_log": "\n".join(logs),
    }


# Main function
def process_image(
    input_image_path,
    yolo_output_path,
    output_dir: str = '.',
    model_to_use='llama',
    save_images=False,
    icon_model_path=None,
    cache_directory='./models_cache',
    huggingface_token='your_token', # for blip2
    no_captioning=False,
    output_json=False,
    json_mini=False,
    sr=None,
    reader=None,
    spell=None,
    skip_ocr=False,
    skip_spell=False
):
    if json_mini:
        json_output = {
            "image_size": None, # Will be populated later
            "bbox_format": "center_x, center_y, width, height",
            "elements": []
        }
    elif output_json:
        json_output = {
            "image": {"path": input_image_path, "size": {"width": None, "height": None}},
            "elements": []
        }
    else:
        json_output = None


    start_time = time.perf_counter()
    print("super-resolution initialization start (in script.py)")
    # Super-resolution initialization
    if sr is None:
        print("No sr reference passed; performing local init ...")
        model_path = 'EDSR_x4.pb'
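        # Assumes the pretrained EDSR x4 weights ('EDSR_x4.pb') sit in the current
        # working directory; adjust this relative path if the model lives elsewhere.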
        if hasattr(cv2, 'dnn_superres'):
            print("dnn_superres module is available.")
            import cv2.dnn_superres as dnn_superres
            try:
                sr = cv2.dnn_superres.DnnSuperResImpl_create()
                print("Using DnnSuperResImpl_create()")
            except AttributeError:
                sr = cv2.dnn_superres.DnnSuperResImpl()
                print("Using DnnSuperResImpl()")
            sr.readModel(model_path)
            sr.setModel('edsr', 4)
        else:
            print("dnn_superres module is NOT available; skipping super-resolution.")
    else:
        print("Using pre-initialized sr reference.")


    elapsed = time.perf_counter() - start_time
    print(f"super-resoulution init (in script.py) took {elapsed:.3f} seconds.")

    start_time = time.perf_counter()

    if skip_ocr:
        print("skip_ocr flag set - skipping EasyOCR and SpellChecker.")
        reader = None
        spell  = None

    else:
        print("OCR initialisation start (in script.py)")
        if reader is None:
            print("No EasyOCR reference passed; performing local init")
            reader = easyocr.Reader(['en'], gpu=True)
        else:
            print("Using pre-initialised EasyOCR object.")

        if skip_spell:
            print("skip_spell flag set - not initialising SpellChecker.")
            spell = None
        else:
            if spell is None:
                print("No SpellChecker reference passed; performing local init")
                spell = SpellChecker()
            else:
                print("Using pre-initialised SpellChecker object.")

    elapsed = time.perf_counter() - start_time
    print(f"OCR init (in script.py) took {elapsed:.3f} seconds.")


    start_time = time.perf_counter()
    print("icon-model init start (in script.py)")
    # Load icon detection model (if provided)
    if icon_model_path:
        icon_model = tf.keras.models.load_model(icon_model_path)
        print(f"Icon detection model loaded: {icon_model_path}")
    else:
        icon_model = None

    elapsed = time.perf_counter() - start_time
    print(f"icon-model init (in script.py) took {elapsed:.3f} seconds.")

    
    # Load the original image
    image = cv2.imread(input_image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        print(f"Image at {input_image_path} could not be loaded.")
        return

    image_height, image_width = image.shape[:2]


    # Read YOLO labels
    with open(yolo_output_path, 'r') as f:
        lines = f.readlines()

    # Check torch device
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA")
    else:
        device = torch.device("cpu")
        print("Using CPU")

    # Conditionally load the captioning model
    processor, model = None, None
    if not no_captioning:
        if model_to_use == 'blip':
            print("Loading BLIP-2 model...")
            blip_model_name = "Salesforce/blip2-opt-2.7b"
            if not is_model_downloaded(blip_model_name, cache_directory):
                print("Model not found in cache. Downloading...")
            else:
                print("Model found in cache. Loading...")
            processor = AutoProcessor.from_pretrained(
                blip_model_name,
                use_auth_token=huggingface_token,
                cache_dir=cache_directory,
                resume_download=True
            )
            model = Blip2ForConditionalGeneration.from_pretrained(
                blip_model_name,
                device_map='auto',
                torch_dtype=torch.float16,
                use_auth_token=huggingface_token,
                cache_dir=cache_directory,
                resume_download=True
            ).to(device)
        else:
            print("Using LLaMA model via external call (ollama).")
    else:
        print("--no-captioning flag is set; skipping model loading.")

    # Prepare bounding boxes from YOLO
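    # Each label line is expected in the standard YOLO format
    # "<class_id> <x_center> <y_center> <width> <height>", with coordinates
    # normalised to [0, 1], e.g. "2 0.50 0.25 0.40 0.10" for a Text region.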
    bounding_boxes = []
    for line in lines:
        parts = line.strip().split()
        class_id = int(parts[0])
        x_center_norm, y_center_norm, width_norm, height_norm = map(float, parts[1:])
        x_center = x_center_norm * image_width
        y_center = y_center_norm * image_height
        box_width = width_norm * image_width
        box_height = height_norm * image_height
        x_min = int(x_center - box_width / 2)
        y_min = int(y_center - box_height / 2)
        x_max = int(x_center + box_width / 2)
        y_max = int(y_center + box_height / 2)
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(image_width - 1, x_max)
        y_max = min(image_height - 1, y_max)
        bounding_boxes.append((x_min, y_min, x_max, y_max, class_id))


    # Create output dirs
    cropped_dir = os.path.join(output_dir, "cropped_imageview_images")
    os.makedirs(cropped_dir, exist_ok=True)
    result_dir = os.path.join(output_dir, "result")
    os.makedirs(result_dir, exist_ok=True)

    base_name = os.path.splitext(os.path.basename(input_image_path))[0]
    captions_file_path = None
    if json_mini:
        json_output["image_size"] = NoIndent([image_width, image_height])
    elif output_json:
        json_output["image"]["size"]["width"] = image_width
        json_output["image"]["size"]["height"] = image_height
    else: # Text output
        captions_filename = f"{base_name}_regions_captions.txt"
        captions_file_path = os.path.join(result_dir, captions_filename)
        with open(captions_file_path, 'w', encoding='utf-8') as f:
            f.write(f"Image path: {input_image_path}\n")
            f.write(f"Image Size: width={image_width}, height={image_height}\n")
            f.write(BARRIER)

    # The number of workers can be increased if the hardware allows it, but that needs testing.
    start_time = time.perf_counter()
    print("Process single region start (in script.py)")

    with ThreadPoolExecutor(max_workers=1) as executor:
        futures = [
            executor.submit(
                process_single_region,
                idx, box, image, sr, reader, spell,
                icon_model, processor, model, device,
                no_captioning, output_json, json_mini,
                cropped_dir, base_name, save_images,
                model_to_use, log_prefix="",
                skip_ocr=skip_ocr,
                skip_spell=skip_spell
            ) for idx, box in enumerate(bounding_boxes)
        ]

        for future in as_completed(futures):
            item = future.result()
            if json_mini:
                if item.get("mini_region_dict"):
                    json_output["elements"].append(item["mini_region_dict"])
            elif output_json:
                if item.get("region_dict"):
                    json_output["elements"].append(item["region_dict"])
            else: # Text output
                if item.get("text_log") and captions_file_path:
                    with open(captions_file_path, 'a', encoding='utf-8') as f:
                        f.write(item["text_log"])

    elapsed = time.perf_counter() - start_time
    print(f"Processing regions took {elapsed:.3f} seconds.")

    if json_mini or output_json:
        json_file = os.path.join(result_dir, f"{base_name}.json")
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(json_output, f, indent=2, ensure_ascii=False, cls=CustomEncoder)
        
        output_type = "mini JSON" if json_mini else "JSON"
        print(f"{output_type} output written to {json_file}")
    else:
        print(f"Text output written to {captions_file_path}")
    

# CLI entry point
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process an image and its YOLO labels.')
    parser.add_argument('input_image', help='Path to the input YOLO image.')
    parser.add_argument('input_labels', help='Path to the input YOLO labels file.')
    parser.add_argument('--output_dir', default='.', 
                        help='Directory to save output files. Defaults to the current directory.')
    parser.add_argument('--model_to_use', choices=['llama', 'blip'], default='llama',
                        help='Model for captioning (llama or blip).')
    parser.add_argument('--save_images', action='store_true',
                        help='Flag to save intermediate images.')
    parser.add_argument('--icon_detection_path', help='Path to icon detection model.')
    parser.add_argument('--cache_directory', default='./models_cache',
                        help='Cache directory for Hugging Face models.')
    parser.add_argument('--huggingface_token', default='your_token',
                        help='Hugging Face token for model downloads.')
    parser.add_argument('--no-captioning', action='store_true',
                        help='Disable any image captioning.')
    parser.add_argument('--json', dest='output_json', action='store_true',
                        help='Output the image data in JSON format')
    parser.add_argument('--json-mini', action='store_true',
                        help='Output the image data in a condensed JSON format')
    args = parser.parse_args()

    process_image(
        input_image_path=args.input_image,
        yolo_output_path=args.input_labels,
        output_dir=args.output_dir,
        model_to_use=args.model_to_use,
        save_images=args.save_images,
        icon_model_path=args.icon_detection_path,
        cache_directory=args.cache_directory,
        huggingface_token=args.huggingface_token,
        no_captioning=args.no_captioning,
        output_json=args.output_json,
        json_mini=args.json_mini
    )
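
# Example invocation (file names are placeholders):
#   python script.py screenshot.png screenshot_labels.txt --output_dir ./out --json-mini --no-captioning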