File size: 3,439 Bytes
fa1528d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import base64
import hmac
import logging
import mimetypes
import os
import tempfile

import gradio as gr
import replicate
import requests

# Root logging configuration. The level may be overridden with the LOG_LEVEL
# environment variable (a standard level name such as "INFO" or "WARNING");
# when it is unset or not a recognised level name, DEBUG is used.
_log_level = logging.getLevelName(os.getenv("LOG_LEVEL", "DEBUG").strip().upper())
if not isinstance(_log_level, int):
    # getLevelName returns a "Level <name>" string for unknown names.
    _log_level = logging.DEBUG

logging.basicConfig(
    level=_log_level,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def process_image(password, input_image):
    # ํ™˜๊ฒฝ๋ณ€์ˆ˜์—์„œ ๋น„๋ฐ€๋ฒˆํ˜ธ ๊ฐ€์ ธ์˜ค๊ธฐ
    correct_password = os.getenv("APP_PASSWORD")
    if not correct_password:
        logger.error("APP_PASSWORD ํ™˜๊ฒฝ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
        raise ValueError("์„œ๋ฒ„ ์„ค์ • ์˜ค๋ฅ˜์ž…๋‹ˆ๋‹ค.")
    
    # ๋น„๋ฐ€๋ฒˆํ˜ธ ๊ฒ€์ฆ
    if password != correct_password:
        raise ValueError("์ž˜๋ชป๋œ ๋น„๋ฐ€๋ฒˆํ˜ธ์ž…๋‹ˆ๋‹ค.")
    
    # Replicate API ํ† ํฐ ํ™•์ธ
    if not os.getenv("REPLICATE_API_TOKEN"):
        logger.error("REPLICATE_API_TOKEN์ด ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
        return None, None
    
    # ํ™˜๊ฒฝ๋ณ€์ˆ˜์—์„œ ๋ชจ๋ธ ์ด๋ฆ„ ๊ฐ€์ ธ์˜ค๊ธฐ
    model_name = os.getenv("REPLICATE_MODEL")
    if not model_name:
        logger.error("REPLICATE_MODEL ํ™˜๊ฒฝ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
        return None, None
    
    if input_image is None:
        logger.error("์ž…๋ ฅ ์ด๋ฏธ์ง€๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
        return None, None
    
    try:
        # base64๋กœ ๋ณ€ํ™˜
        with open(input_image, "rb") as f:
            data = base64.b64encode(f.read()).decode()
        
        image_uri = f"data:image/png;base64,{data}"
        
        # ํ™˜๊ฒฝ๋ณ€์ˆ˜์—์„œ ๊ฐ€์ ธ์˜จ ๋ชจ๋ธ๋กœ ์‹คํ–‰
        output = replicate.run(
            model_name,
            input={"image": image_uri}
        )
        
        # ๊ฒฐ๊ณผ ์ €์žฅ
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            if hasattr(output, "read"):
                tmp.write(output.read())
            elif isinstance(output, (list, tuple)) and output:
                resp = requests.get(output[0])
                resp.raise_for_status()
                tmp.write(resp.content)
            
            out_path = tmp.name
        
        return out_path, out_path
        
    except replicate.exceptions.ReplicateError as re:
        logger.error(f"API ์˜ค๋ฅ˜: {re}")
        return None, None
    except Exception as e:
        logger.error(f"์˜ˆ์ƒ์น˜ ๋ชปํ•œ ์˜ค๋ฅ˜: {e}")
        return None, None

# Build the Gradio interface.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # First column: password, input image and the run button.
        with gr.Column():
            password_input = gr.Textbox(
                label="비밀번호",
                type="password",
                placeholder="비밀번호를 입력하세요"
            )
            input_image = gr.Image(
                type="filepath",
                label="입력 이미지",
                interactive=True
            )
            process_btn = gr.Button("실행", variant="primary")

        # Second column: result preview and a download button for the same file.
        with gr.Column():
            output_image = gr.Image(
                type="filepath",
                label="결과 이미지",
                interactive=False
            )
            download_btn = gr.DownloadButton(
                label="이미지 다운로드"
            )

    # On click, process_image returns two values: the result path for the
    # preview image and the same path for the download button.
    process_btn.click(
        fn=process_image,
        inputs=[password_input, input_image],
        outputs=[output_image, download_btn]
    )

    gr.Markdown("---")

if __name__ == "__main__":
    # show_error=True makes Gradio display the message of exceptions raised by
    # event handlers -- here the ValueErrors process_image raises for a wrong
    # password or missing server configuration -- instead of a generic error.
    # process_image catches and logs all other failures itself.
    demo.launch(show_error=True)