cp524 committed
Commit 3e0672b · 1 Parent(s): 9712fd8

Add finetuned model

Files changed (3)
  1. app.py +55 -10
  2. requirements.txt +2 -1
  3. src/smc/inference.py +65 -0
app.py CHANGED
@@ -7,8 +7,10 @@ import gradio as gr
 from src.smc.inference import (
     infer_pretrained,
     infer_smc_grad,
+    infer_ft,
     PretrainedInferenceConfig,
     SMCGradInferenceConfig,
+    FTInferenceConfig,
 )
 
 def get_device():
@@ -45,11 +47,7 @@ def _format_inference_output(out) -> str:
 
 # --- Per-method runner functions ---
 def run_pretrained_ui(prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps):
-    """Run the pretrained inference method and return (gallery, info).
-
-    This function is designed to be attached directly to a Gradio event so it can
-    execute independently and return only the components it owns.
-    """
+    """Run the pretrained inference method and return (gallery, info)."""
     try:
         pretrained_cfg = PretrainedInferenceConfig(
             prompt=prompt,
@@ -110,6 +108,25 @@ def run_smc_grad_ui(
         traceback.print_exc()
         err_msg = f"SMC-grad inference error: {e}"
         return [err_msg], err_msg
+
+def run_ft_ui(prompt, ft_negative_prompt, ft_CFG, ft_steps):
+    """Run the finetuned model inference and return (gallery, info)."""
+    try:
+        ft_cfg = FTInferenceConfig(
+            prompt=prompt,
+            negative_prompt=ft_negative_prompt or "",
+            CFG=float(ft_CFG),
+            steps=int(ft_steps),
+        )
+        out = infer_ft(ft_cfg, device=get_device())
+        gallery = out.images
+        info = _format_inference_output(out)
+        return gallery, info
+    except Exception as e:
+        traceback.print_exc()
+        err_msg = f"FT inference error: {e}"
+        # Return a simple textual error in the gallery and the info box
+        return [err_msg], err_msg
 
 
 def mark_all_running():
@@ -121,7 +138,7 @@ def mark_all_running():
     running_info = gr.update(value="Running...", interactive=False)
     empty_gallery = gr.update(value=[])
     # Return values must match the components this function is attached to (see below)
-    return empty_gallery, running_info, empty_gallery, running_info
+    return empty_gallery, running_info, empty_gallery, running_info, empty_gallery, running_info
 
 
 with gr.Blocks() as demo:
@@ -136,7 +153,7 @@ with gr.Blocks() as demo:
     # --- Pretrained method row ---
     with gr.Row():
         with gr.Column(scale=1, min_width=280):
-            with gr.Accordion("Pretrained method — settings", open=False):
+            with gr.Accordion("Pretrained model — settings", open=False):
                 pretrained_negative_prompt = gr.Textbox(
                     label="Negative prompt", value=PretrainedInferenceConfig.negative_prompt, lines=1
                 )
@@ -145,7 +162,7 @@ with gr.Blocks() as demo:
 
         with gr.Column(scale=2):
             pretrained_gallery = gr.Gallery(
-                label="Pretrained outputs", show_label=True, elem_id="pretrained_gallery", height="240px", columns=4,
+                label="Pretrained model outputs", show_label=True, elem_id="pretrained_gallery", height="240px", columns=4,
                 object_fit="contain",
             )
             pretrained_info = gr.Textbox(label="Pretrained info", interactive=False)
@@ -192,13 +209,30 @@ with gr.Blocks() as demo:
                 object_fit="contain",
             )
             smc_grad_info = gr.Textbox(label="SMC-grad info", interactive=False)
+
+    # --- FT method row ---
+    with gr.Row():
+        with gr.Column(scale=1, min_width=280):
+            with gr.Accordion("Finetuned model — settings", open=False):
+                ft_negative_prompt = gr.Textbox(
+                    label="Negative prompt", value=FTInferenceConfig.negative_prompt, lines=1
+                )
+                ft_CFG = gr.Slider(0.0, 30.0, step=0.1, value=FTInferenceConfig.CFG, label="CFG")
+                ft_steps = gr.Slider(1, 200, step=1, value=FTInferenceConfig.steps, label="Steps")
+
+        with gr.Column(scale=2):
+            ft_gallery = gr.Gallery(
+                label="Finetuned model outputs", show_label=True, elem_id="ft_gallery", height="240px", columns=4,
+                object_fit="contain",
+            )
+            ft_info = gr.Textbox(label="Finetuned info", interactive=False)
 
     # --- Wiring ---
     # 1) Quick 'running' update attached to the button so the UI shows immediate feedback.
     run_button.click(
         fn=mark_all_running,
         inputs=[],
-        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info],
+        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info, ft_gallery, ft_info],
     )
 
     # 2) Attach the per-method heavy functions separately. Gradio's queue() will allow
@@ -229,12 +263,18 @@ with gr.Blocks() as demo:
         ],
         outputs=[smc_grad_gallery, smc_grad_info],
     )
+
+    run_button.click(
+        fn=run_ft_ui,
+        inputs=[prompt, ft_negative_prompt, ft_CFG, ft_steps],
+        outputs=[ft_gallery, ft_info],
+    )
 
     # Also allow pressing Enter in the prompt to trigger the same set of handlers
     prompt.submit(
         fn=mark_all_running,
         inputs=[],
-        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info],
+        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info, ft_gallery, ft_info],
     )
     prompt.submit(
         fn=run_pretrained_ui,
@@ -261,6 +301,11 @@ with gr.Blocks() as demo:
         ],
         outputs=[smc_grad_gallery, smc_grad_info],
     )
+    prompt.submit(
+        fn=run_ft_ui,
+        inputs=[prompt, ft_negative_prompt, ft_CFG, ft_steps],
+        outputs=[ft_gallery, ft_info],
+    )
 
     # Enable Gradio queue to allow parallel execution of multiple handlers. Set concurrency
     # to 2 (one per method) — increase if you add more methods.
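
The wiring comments rely on Gradio's queue to run the per-method handlers in parallel, and with the finetuned model wired in there are now three heavy handlers per trigger. A minimal sketch of the queue/launch call this implies — the actual call sits outside the diffed hunks, and the use of Gradio 4.x's default_concurrency_limit is an assumption:

# Sketch only: the real queue()/launch() call is not shown in this diff.
# In Gradio 4.x, default_concurrency_limit caps how many events run at once;
# three per-method handlers need a limit of 3 to all execute in parallel.
demo.queue(default_concurrency_limit=3)
demo.launch()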
requirements.txt CHANGED
@@ -7,4 +7,5 @@ xformers
 gradio
 spaces
 image-reward
-openai-clip
+openai-clip
+peft
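
peft is added because diffusers-style load_lora_weights delegates LoRA adapter handling to it at runtime rather than at import time, so a missing install only surfaces on the first request. An illustrative fail-fast guard one could add at startup (not part of the commit):

# Illustrative guard, not part of the commit: fail fast if peft is absent,
# since pipe.load_lora_weights() would otherwise only error at request time.
import importlib.util

if importlib.util.find_spec("peft") is None:
    raise RuntimeError("peft must be installed to load LoRA checkpoints")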
src/smc/inference.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import math
 import threading
 import spaces
@@ -25,6 +26,7 @@ MIN_GPU_DURATION = 60
 pipe_build_lock = threading.Lock()
 pipe_load_lock = threading.Lock()
 reward_model_load_lock = threading.Lock()
+lora_load_lock = threading.Lock()
 
 
 def build_pipe(device):
@@ -43,6 +45,13 @@ def build_pipe(device):
     pipe = Pipeline(vq_model, tokenizer=tokenizer, text_encoder=text_encoder, transformer=model, scheduler=scheduler_new)
     return pipe
 
+def load_lora_weights(pipe, lora_ckpt_uuid):
+    # Load the LoRA checkpoint from the local 'checkpoints' directory
+    ckpt_path = os.path.join('checkpoints', lora_ckpt_uuid)
+    pipe.load_lora_weights(
+        pretrained_model_name_or_path_or_dict=ckpt_path,
+    )
+
 @dataclass
 class InferenceOutput:
     images: List[Image.Image]
@@ -205,3 +214,59 @@ def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, dev
     pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore
     gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
     return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
+
+@dataclass
+class FTInferenceConfig:
+    prompt: str
+    negative_prompt: str = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
+    resolution: int = 512
+    CFG: float = 9.0
+    steps: int = 48
+    num_batches: int = 4
+    ckpt_uuid: str = "a1e906e1-16a9-44a3-abe8-6dd2c17e12a2"
+
+def infer_ft(config: FTInferenceConfig, device='cpu'):
+    with pipe_build_lock:
+        pipe = build_pipe(device)
+    return infer_ft_with_pipe(config, pipe, device=device)
+
+def _get_ft_duration(config: FTInferenceConfig, pipe: Pipeline, device='cpu') -> int:
+    setup_duration = 30.0
+    step_duration = 1.0
+    total_duration = math.ceil(setup_duration + step_duration * config.steps)
+    return max(total_duration, MIN_GPU_DURATION)
+
+@spaces.GPU(duration=_get_ft_duration)
+def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
+    if isinstance(device, str):
+        device = torch.device(device)
+    with pipe_load_lock:
+        pipe = pipe.to(device)
+    with lora_load_lock:
+        load_lora_weights(pipe, config.ckpt_uuid)
+    reward_bias = 5.0
+    with reward_model_load_lock:
+        reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, bias=reward_bias), "image_reward_plus_5"
+    image_reward_fn = lambda images: reward_fn(
+        images,
+        [config.prompt] * len(images)
+    )
+    images = pipe(
+        prompt=config.prompt,
+        reward_fn=image_reward_fn,
+        resample_fn=lambda log_w: resample(log_w),
+        negative_prompt=config.negative_prompt,
+        height=config.resolution,
+        width=config.resolution,
+        guidance_scale=config.CFG,
+        num_inference_steps=config.steps,
+        batches=config.num_batches,
+        num_particles=1,
+        batch_p=config.num_batches,
+        proposal_type="without_SMC",
+        output_type="pt",
+    )
+    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
+    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore
+    gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
+    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
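
For reference, the new entry point can be exercised end to end like the existing methods. A minimal usage sketch — the defaults (CFG=9.0, steps=48, num_batches=4, bundled ckpt_uuid) come from FTInferenceConfig above, and the prompt is a hypothetical example:

# Minimal usage sketch for the new finetuned-model path.
# Defaults come from FTInferenceConfig in the diff; the prompt is illustrative.
from src.smc.inference import infer_ft, FTInferenceConfig

cfg = FTInferenceConfig(prompt="a watercolor painting of a lighthouse at dusk")
out = infer_ft(cfg, device="cuda")
print(f"{len(out.images)} images, rewards={out.image_rewards}, "
      f"GPU peak={out.gpu_mem_used:.2f} GiB")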