FrankFacundo committed on
Commit
dd559e1
·
1 Parent(s): 46af7ed
Files changed (1) hide show
  1. app.py +39 -13
app.py CHANGED
@@ -4,6 +4,7 @@ import spaces
4
  import os
5
  import random
6
 
 
7
  import torch
8
  from PIL import Image
9
  import cv2
@@ -32,6 +33,8 @@ Before running, set the `HUGGINGFACE_TOKEN` environment variable **or** call
32
  `login("<YOUR_HF_TOKEN>")` explicitly.
33
  """
34
 
 
 
35
  # --------------------------------------------------
36
  # Model & pipeline setup
37
  # --------------------------------------------------
@@ -76,16 +79,23 @@ MAX_SEED = 100
76
  # --------------------------------------------------
77
 
78
 
79
- def _preview_canny(pil_img: Image.Image) -> Image.Image:
 
 
80
  arr = np.array(pil_img.convert("RGB"))
81
- edges = cv2.Canny(arr, 100, 200)
82
  edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
83
  return Image.fromarray(edges_rgb)
84
 
85
 
86
- def _make_preview(control_image: Image.Image, mode: str) -> Image.Image:
 
 
 
 
 
87
  if mode == "canny":
88
- return _preview_canny(control_image)
89
  # For other modes you can plug in your own visualiser later
90
  return control_image
91
 
@@ -105,6 +115,8 @@ def infer(
105
  randomize_seed: bool,
106
  guidance_scale: float,
107
  num_inference_steps: int,
 
 
108
  ):
109
  if control_image is None:
110
  raise gr.Error("Please upload a control image first.")
@@ -115,8 +127,12 @@ def infer(
115
  gen = torch.Generator(device).manual_seed(seed)
116
  w, h = control_image.size
117
 
 
 
 
 
118
  result = pipe(
119
- prompt=prompt,
120
  control_image=[control_image],
121
  control_mode=[MODE_MAPPING[mode]],
122
  width=w,
@@ -127,8 +143,7 @@ def infer(
127
  generator=gen,
128
  ).images[0]
129
 
130
- preview = _make_preview(control_image, mode)
131
- return result, seed, preview
132
 
133
 
134
  # --------------------------------------------------
@@ -148,23 +163,23 @@ with gr.Blocks(css=css, elem_id="wrapper") as demo:
148
  control_image = gr.Image(
149
  label="Upload a processed control image",
150
  type="pil",
151
- height=512,
152
  )
153
- result_image = gr.Image(label="Result", height=512)
154
- preview_image = gr.Image(label="Pre‑processed Cond", height=512)
155
 
156
  # ------------ Prompt ------------
157
- prompt_txt = gr.Textbox(label="Prompt", value="best quality", lines=1)
158
 
159
  # ------------ ControlNet settings ------------
160
  with gr.Row():
161
  with gr.Column():
162
  gr.Markdown("### ControlNet")
163
  mode_radio = gr.Radio(
164
- choices=list(MODE_MAPPING.keys()), value="gray", label="Mode"
165
  )
166
  strength_slider = gr.Slider(
167
- 0.0, 1.0, value=0.5, step=0.01, label="control strength"
168
  )
169
  with gr.Column():
170
  seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
@@ -174,6 +189,15 @@ with gr.Blocks(css=css, elem_id="wrapper") as demo:
174
  )
175
  steps_slider = gr.Slider(1, 50, step=1, value=24, label="Inference steps")
176
 
 
 
 
 
 
 
 
 
 
177
  submit_btn = gr.Button("Submit")
178
 
179
  submit_btn.click(
@@ -187,6 +211,8 @@ with gr.Blocks(css=css, elem_id="wrapper") as demo:
187
  randomize_chk,
188
  guidance_slider,
189
  steps_slider,
 
 
190
  ],
191
  outputs=[result_image, seed_slider, preview_image],
192
  )
 
4
  import os
5
  import random
6
 
7
+ import subprocess
8
  import torch
9
  from PIL import Image
10
  import cv2
 
33
  `login("<YOUR_HF_TOKEN>")` explicitly.
34
  """
35
 
36
+ subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
37
+
38
  # --------------------------------------------------
39
  # Model & pipeline setup
40
  # --------------------------------------------------
 
79
  # --------------------------------------------------
80
 
81
 
82
def _preview_canny(
    pil_img: Image.Image, canny_threshold_1: int, canny_threshold_2: int
) -> Image.Image:
    """Render a Canny edge map of *pil_img* as an RGB PIL image.

    Parameters
    ----------
    pil_img : Image.Image
        Source image; converted to RGB before edge detection.
    canny_threshold_1, canny_threshold_2 : int
        Hysteresis thresholds forwarded to ``cv2.Canny``.

    Returns
    -------
    Image.Image
        Three-channel edge visualisation, same size as the input.
    """
    rgb_pixels = np.array(pil_img.convert("RGB"))
    edge_map = cv2.Canny(
        rgb_pixels, threshold1=canny_threshold_1, threshold2=canny_threshold_2
    )
    # cv2.Canny yields a single-channel mask; expand to 3 channels for display.
    return Image.fromarray(cv2.cvtColor(edge_map, cv2.COLOR_GRAY2RGB))
89
 
90
 
91
def _make_preview(
    control_image: Image.Image,
    mode: str,
    canny_threshold_1: int,
    canny_threshold_2: int,
) -> Image.Image:
    """Return the conditioning-image visualisation for *mode*.

    Only ``"canny"`` currently has a dedicated visualiser; every other
    mode passes the control image through untouched.
    """
    if mode != "canny":
        # For other modes you can plug in your own visualiser later
        return control_image
    return _preview_canny(control_image, canny_threshold_1, canny_threshold_2)
101
 
 
115
  randomize_seed: bool,
116
  guidance_scale: float,
117
  num_inference_steps: int,
118
+ canny_threshold_1: int,
119
+ canny_threshold_2: int,
120
  ):
121
  if control_image is None:
122
  raise gr.Error("Please upload a control image first.")
 
127
  gen = torch.Generator(device).manual_seed(seed)
128
  w, h = control_image.size
129
 
130
+ preprocessed = _make_preview(
131
+ control_image, mode, canny_threshold_1, canny_threshold_2
132
+ )
133
+
134
  result = pipe(
135
+ prompt=prompt,
136
  control_image=[control_image],
137
  control_mode=[MODE_MAPPING[mode]],
138
  width=w,
 
143
  generator=gen,
144
  ).images[0]
145
 
146
+ return result, seed, preprocessed
 
147
 
148
 
149
  # --------------------------------------------------
 
163
  control_image = gr.Image(
164
  label="Upload a processed control image",
165
  type="pil",
166
+ height=512 + 256,
167
  )
168
+ result_image = gr.Image(label="Result", height=512 + 256)
169
+ preview_image = gr.Image(label="Pre‑processed Cond", height=512 + 256)
170
 
171
  # ------------ Prompt ------------
172
+ prompt_txt = gr.Textbox(label="Prompt", value="A beautiful image", lines=1)
173
 
174
  # ------------ ControlNet settings ------------
175
  with gr.Row():
176
  with gr.Column():
177
  gr.Markdown("### ControlNet")
178
  mode_radio = gr.Radio(
179
+ choices=list(MODE_MAPPING.keys()), value="canny", label="Mode"
180
  )
181
  strength_slider = gr.Slider(
182
+ 0.0, 1.0, value=0.8, step=0.01, label="control strength"
183
  )
184
  with gr.Column():
185
  seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
 
189
  )
190
  steps_slider = gr.Slider(1, 50, step=1, value=24, label="Inference steps")
191
 
192
+ with gr.Row():
193
+ with gr.Column():
194
+ gr.Markdown("### Preprocess")
195
+ canny_threshold_1 = gr.Slider(
196
+ 0, 500, step=1, value=100, label="Canny threshold 1"
197
+ )
198
+ canny_threshold_2 = gr.Slider(
199
+ 0, 500, step=1, value=200, label="Canny threshold 2"
200
+ )
201
  submit_btn = gr.Button("Submit")
202
 
203
  submit_btn.click(
 
211
  randomize_chk,
212
  guidance_slider,
213
  steps_slider,
214
+ canny_threshold_1,
215
+ canny_threshold_2,
216
  ],
217
  outputs=[result_image, seed_slider, preview_image],
218
  )