File size: 3,257 Bytes
11e836c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import cv2 as cv
import numpy as np
import argparse
from lama import Lama

def get_args_parser(func_args):
    """Build the command-line parser for this demo and parse ``func_args``.

    Args:
        func_args: List of argument strings to parse, or ``None`` to let
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input`` (path to the image to inpaint) and
        ``model`` (path to the LaMa ONNX file).
    """
    parser = argparse.ArgumentParser(add_help=False)
    # cv.imread needs a file path, so there is no sensible default here;
    # requiring the flag turns a crash inside OpenCV into a clear usage error.
    parser.add_argument('--input', help='Path to input image', required=True)
    parser.add_argument('--model', help='Path to lama onnx', default='inpainting_lama_2025jan.onnx', required=False)

    # Wrap the base parser so -h/--help and a description are available.
    parser = argparse.ArgumentParser(parents=[parser],
                                     description='Interactive image inpainting with the LaMa ONNX model.',
                                     formatter_class=argparse.RawTextHelpFormatter)
    return parser.parse_args(func_args)

# Mutable state shared between main() and the draw_mask() mouse callback:
#   drawing    - True while the left mouse button is held down
#   mask_gray  - uint8 mask the user paints on; allocated in main()
#   brush_size - brush radius in pixels, changed with the 'i'/'d' keys
drawing, mask_gray, brush_size = False, None, 15

def draw_mask(event, x, y, flags, param):
    """Mouse callback that paints the inpainting mask with a round brush.

    Registered on the "Draw Mask" window by ``main``. While the left button
    is held, a filled circle of radius ``brush_size`` is drawn with value 255
    into the global ``mask_gray`` at the cursor position.

    Args:
        event: OpenCV mouse event code (``cv.EVENT_*``).
        x, y: Cursor position reported by OpenCV.
        flags: OpenCV event flags (unused).
        param: User data passed to ``cv.setMouseCallback`` (unused).
    """
    # Only ``drawing`` is reassigned; mask_gray/brush_size are read (and the
    # mask array mutated in place), which needs no ``global`` declaration.
    global drawing
    if event == cv.EVENT_LBUTTONDOWN:
        drawing = True
        # Paint on press too, so a single click without dragging marks a spot.
        cv.circle(mask_gray, (x, y), brush_size, 255, thickness=-1)
    elif event == cv.EVENT_MOUSEMOVE:
        if drawing:
            cv.circle(mask_gray, (x, y), brush_size, 255, thickness=-1)
    elif event == cv.EVENT_LBUTTONUP:
        drawing = False

def _render_instructions(image, font_size, font_thickness):
    """Return a copy of ``image`` with the drawing instructions overlaid.

    A white box is alpha-blended (50 %) into the top-left corner and the two
    instruction lines are written on it in black. The result does not depend
    on the mask, so main() renders it once instead of on every frame.
    """
    lines = (
        ("Draw the mask on the image. Press space bar when done.", int(25 * font_size)),
        ("Press 'i' to increase and 'd' to decrease brush size.", int(50 * font_size)),
    )
    sizes = [cv.getTextSize(text, cv.FONT_HERSHEY_SIMPLEX, font_size, font_thickness)[0]
             for text, _ in lines]
    # Size the box to the widest line so neither line overflows it.
    box_width = max(width for width, _ in sizes) + 10
    box_height = sizes[0][1] + int(30 * font_size)

    alpha = 0.5
    overlay = image.copy()
    cv.rectangle(overlay, (0, 0), (box_width, box_height), (255, 255, 255), cv.FILLED)
    rendered = cv.addWeighted(overlay, alpha, image, 1 - alpha, 0)
    for text, baseline_y in lines:
        cv.putText(rendered, text, (10, baseline_y), cv.FONT_HERSHEY_SIMPLEX,
                   font_size, (0, 0, 0), font_thickness)
    return rendered


def main(func_args=None):
    """Run the interactive LaMa inpainting demo.

    Loads the image given by ``--input`` and opens a "Draw Mask" window in
    which the user paints the region to inpaint with the mouse. It then runs
    ``Lama.infer`` on the image and mask and shows the result annotated with
    the inference time.

    Keys in the drawing window:
        i      increase brush size
        d      decrease brush size (minimum 1)
        space  finish drawing and run inpainting
        Esc    close the window and return without inpainting

    Args:
        func_args: Optional argument list forwarded to ``get_args_parser``;
            ``None`` parses the process command line.

    Raises:
        SystemExit: If no input image path was given or it cannot be read.
    """
    global mask_gray, brush_size
    args = get_args_parser(func_args)

    # cv.imread only accepts a path; a non-path value (e.g. a numeric
    # default) would raise inside OpenCV, so report it clearly instead.
    if not isinstance(args.input, str):
        raise SystemExit("Please provide an input image with --input <path>.")
    input_image = cv.imread(args.input)
    if input_image is None:
        # cv.imread signals a missing/unreadable file by returning None.
        raise SystemExit(f"Could not read input image: {args.input}")

    # Load the model only once we know there is an image to work on.
    lama = Lama(modelPath=args.model)
    mask_gray = np.zeros(input_image.shape[:2], dtype=np.uint8)

    # Scale the overlay text relative to a 512 px reference image size.
    std_size = 0.6
    std_weight = 2
    std_img_size = 512
    min_side = min(input_image.shape[:2])
    font_size = min(1.5, (std_size * min_side) / std_img_size)
    font_thickness = max(1, (std_weight * min_side) // std_img_size)

    background = _render_instructions(input_image, font_size, font_thickness)

    cv.namedWindow("Draw Mask")
    cv.setMouseCallback("Draw Mask", draw_mask)

    while True:
        # The painted mask is shown in white on top of the instruction layer.
        display_image = background.copy()
        display_image[mask_gray > 0] = [255, 255, 255]
        cv.imshow("Draw Mask", display_image)

        key = cv.waitKey(1) & 0xFF
        if key == ord('i'):  # Increase brush size
            brush_size += 1
            print(f"Brush size increased to {brush_size}")
        elif key == ord('d'):  # Decrease brush size
            brush_size = max(1, brush_size - 1)
            print(f"Brush size decreased to {brush_size}")
        elif key == ord(' '):  # Press space bar to finish drawing
            break
        elif key == 27:  # Esc: abort without inpainting
            cv.destroyAllWindows()
            return
    cv.destroyAllWindows()

    tm = cv.TickMeter()
    tm.start()
    result = lama.infer(input_image, mask_gray)
    tm.stop()
    label = 'Inference time: {:.2f} ms'.format(tm.getTimeMilli())
    cv.putText(result, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0))

    cv.imshow("Inpainted Output", result)
    cv.waitKey(0)
    cv.destroyAllWindows()

if __name__ == '__main__':
    # Launch the interactive demo only when run as a script, not on import.
    main()