Add finetuned model
- app.py +55 -10
- requirements.txt +2 -1
- src/smc/inference.py +65 -0
app.py CHANGED

@@ -7,8 +7,10 @@ import gradio as gr
 from src.smc.inference import (
     infer_pretrained,
     infer_smc_grad,
+    infer_ft,
     PretrainedInferenceConfig,
     SMCGradInferenceConfig,
+    FTInferenceConfig,
 )

 def get_device():
@@ -45,11 +47,7 @@ def _format_inference_output(out) -> str:

 # --- Per-method runner functions ---
 def run_pretrained_ui(prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps):
-    """Run the pretrained inference method and return (gallery, info).
-
-    This function is designed to be attached directly to a Gradio event so it can
-    execute independently and return only the components it owns.
-    """
+    """Run the pretrained inference method and return (gallery, info)."""
     try:
         pretrained_cfg = PretrainedInferenceConfig(
             prompt=prompt,
@@ -110,6 +108,25 @@ def run_smc_grad_ui(
         traceback.print_exc()
         err_msg = f"SMC-grad inference error: {e}"
         return [err_msg], err_msg
+
+def run_ft_ui(prompt, ft_negative_prompt, ft_CFG, ft_steps):
+    """Run the finetuned model inference and return (gallery, info)."""
+    try:
+        ft_cfg = FTInferenceConfig(
+            prompt=prompt,
+            negative_prompt=ft_negative_prompt or "",
+            CFG=float(ft_CFG),
+            steps=int(ft_steps),
+        )
+        out = infer_ft(ft_cfg, device=get_device())
+        gallery = out.images
+        info = _format_inference_output(out)
+        return gallery, info
+    except Exception as e:
+        traceback.print_exc()
+        err_msg = f"FT inference error: {e}"
+        # Return a simple textual error in the gallery and the info box
+        return [err_msg], err_msg


 def mark_all_running():
@@ -121,7 +138,7 @@ def mark_all_running():
     running_info = gr.update(value="Running...", interactive=False)
     empty_gallery = gr.update(value=[])
     # Return values must match the components this function is attached to (see below)
-    return empty_gallery, running_info, empty_gallery, running_info
+    return empty_gallery, running_info, empty_gallery, running_info, empty_gallery, running_info


 with gr.Blocks() as demo:
@@ -136,7 +153,7 @@ with gr.Blocks() as demo:
     # --- Pretrained method row ---
     with gr.Row():
         with gr.Column(scale=1, min_width=280):
-            with gr.Accordion("Pretrained
+            with gr.Accordion("Pretrained model — settings", open=False):
                 pretrained_negative_prompt = gr.Textbox(
                     label="Negative prompt", value=PretrainedInferenceConfig.negative_prompt, lines=1
                 )
@@ -145,7 +162,7 @@ with gr.Blocks() as demo:

         with gr.Column(scale=2):
             pretrained_gallery = gr.Gallery(
-                label="Pretrained outputs", show_label=True, elem_id="pretrained_gallery", height="240px", columns=4,
+                label="Pretrained model outputs", show_label=True, elem_id="pretrained_gallery", height="240px", columns=4,
                 object_fit="contain",
             )
             pretrained_info = gr.Textbox(label="Pretrained info", interactive=False)
@@ -192,13 +209,30 @@ with gr.Blocks() as demo:
                 object_fit="contain",
             )
             smc_grad_info = gr.Textbox(label="SMC-grad info", interactive=False)
+
+    # --- FT method row ---
+    with gr.Row():
+        with gr.Column(scale=1, min_width=280):
+            with gr.Accordion("Finetuned model — settings", open=False):
+                ft_negative_prompt = gr.Textbox(
+                    label="Negative prompt", value=FTInferenceConfig.negative_prompt, lines=1
+                )
+                ft_CFG = gr.Slider(0.0, 30.0, step=0.1, value=FTInferenceConfig.CFG, label="CFG")
+                ft_steps = gr.Slider(1, 200, step=1, value=FTInferenceConfig.steps, label="Steps")
+
+        with gr.Column(scale=2):
+            ft_gallery = gr.Gallery(
+                label="Finetuned model outputs", show_label=True, elem_id="ft_gallery", height="240px", columns=4,
+                object_fit="contain",
+            )
+            ft_info = gr.Textbox(label="Finetuned info", interactive=False)

     # --- Wiring ---
     # 1) Quick 'running' update attached to the button so the UI shows immediate feedback.
     run_button.click(
         fn=mark_all_running,
         inputs=[],
-        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info],
+        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info, ft_gallery, ft_info],
     )

     # 2) Attach the per-method heavy functions separately. Gradio's queue() will allow
@@ -229,12 +263,18 @@ with gr.Blocks() as demo:
         ],
         outputs=[smc_grad_gallery, smc_grad_info],
     )
+
+    run_button.click(
+        fn=run_ft_ui,
+        inputs=[prompt, ft_negative_prompt, ft_CFG, ft_steps],
+        outputs=[ft_gallery, ft_info],
+    )

     # Also allow pressing Enter in the prompt to trigger the same set of handlers
     prompt.submit(
         fn=mark_all_running,
         inputs=[],
-        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info],
+        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info, ft_gallery, ft_info],
     )
     prompt.submit(
         fn=run_pretrained_ui,
@@ -261,6 +301,11 @@ with gr.Blocks() as demo:
         ],
         outputs=[smc_grad_gallery, smc_grad_info],
     )
+    prompt.submit(
+        fn=run_ft_ui,
+        inputs=[prompt, ft_negative_prompt, ft_CFG, ft_steps],
+        outputs=[ft_gallery, ft_info],
+    )

     # Enable Gradio queue to allow parallel execution of multiple handlers. Set concurrency
     # to 2 (one per method) — increase if you add more methods.
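The wiring above attaches several handlers to the same click/submit events so each method updates only its own components while Gradio's queue runs them independently. A minimal standalone sketch of that pattern (the component names and sleep placeholders are illustrative, not from this app):

import time
import gradio as gr

def slow_method_a(prompt):
    time.sleep(2)  # stand-in for a heavy model call
    return f"A finished: {prompt}"

def slow_method_b(prompt):
    time.sleep(2)  # stand-in for another heavy model call
    return f"B finished: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    out_a = gr.Textbox(label="Method A")
    out_b = gr.Textbox(label="Method B")
    run = gr.Button("Run")

    # Each handler is attached separately, so it returns only the components
    # it owns and can be scheduled independently of the others.
    run.click(fn=slow_method_a, inputs=[prompt], outputs=[out_a])
    run.click(fn=slow_method_b, inputs=[prompt], outputs=[out_b])

demo.queue()  # concurrency for parallel handlers is configured here
demo.launch()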
requirements.txt CHANGED

@@ -7,4 +7,5 @@ xformers
 gradio
 spaces
 image-reward
-openai-clip
+openai-clip
+peft
src/smc/inference.py CHANGED

@@ -1,3 +1,4 @@
+import os
 import math
 import threading
 import spaces
@@ -25,6 +26,7 @@ MIN_GPU_DURATION = 60
 pipe_build_lock = threading.Lock()
 pipe_load_lock = threading.Lock()
 reward_model_load_lock = threading.Lock()
+lora_load_lock = threading.Lock()


 def build_pipe(device):
@@ -43,6 +45,13 @@ def build_pipe(device):
     pipe = Pipeline(vq_model, tokenizer=tokenizer, text_encoder=text_encoder, transformer=model, scheduler=scheduler_new)
     return pipe

+def load_lora_weights(pipe, lora_ckpt_uuid):
+    # Load the LoRA checkpoint stored under checkpoints/<lora_ckpt_uuid>
+    ckpt_path = os.path.join('checkpoints', lora_ckpt_uuid)
+    pipe.load_lora_weights(
+        pretrained_model_name_or_path_or_dict=ckpt_path,
+    )
+
 @dataclass
 class InferenceOutput:
     images: List[Image.Image]
@@ -205,3 +214,59 @@ def infer_smc_grad_with_pipe(config: SMCGradInferenceConfig, pipe: Pipeline, dev
     pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore
     gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
     return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
+
+@dataclass
+class FTInferenceConfig:
+    prompt: str
+    negative_prompt: str = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
+    resolution: int = 512
+    CFG: float = 9.0
+    steps: int = 48
+    num_batches: int = 4
+    ckpt_uuid: str = "a1e906e1-16a9-44a3-abe8-6dd2c17e12a2"
+
+def infer_ft(config: FTInferenceConfig, device='cpu'):
+    with pipe_build_lock:
+        pipe = build_pipe(device)
+    return infer_ft_with_pipe(config, pipe, device=device)
+
+def _get_ft_duration(config: FTInferenceConfig, pipe: Pipeline, device='cpu') -> int:
+    setup_duration = 30.0
+    step_duration = 1.0
+    total_duration = math.ceil(setup_duration + step_duration * config.steps)
+    return max(total_duration, MIN_GPU_DURATION)
+
+@spaces.GPU(duration=_get_ft_duration)
+def infer_ft_with_pipe(config: FTInferenceConfig, pipe: Pipeline, device='cpu'):
+    if isinstance(device, str):
+        device = torch.device(device)
+    with pipe_load_lock:
+        pipe = pipe.to(device)
+    with lora_load_lock:
+        load_lora_weights(pipe, config.ckpt_uuid)
+    reward_bias = 5.0
+    with reward_model_load_lock:
+        reward_fn, reward_name = rewards.ImageReward_Fk_Steering(device=device, bias=reward_bias), "image_reward_plus_5"
+    image_reward_fn = lambda images: reward_fn(
+        images,
+        [config.prompt] * len(images)
+    )
+    images = pipe(
+        prompt=config.prompt,
+        reward_fn=image_reward_fn,
+        resample_fn=lambda log_w: resample(log_w),
+        negative_prompt=config.negative_prompt,
+        height=config.resolution,
+        width=config.resolution,
+        guidance_scale=config.CFG,
+        num_inference_steps=config.steps,
+        batches=config.num_batches,
+        num_particles=1,
+        batch_p=config.num_batches,
+        proposal_type="without_SMC",
+        output_type="pt",
+    )
+    image_rewards = (image_reward_fn(images) - reward_bias).tolist()
+    pil_images: List[Image.Image] = pipe.image_processor.postprocess(images, "pil")  # type: ignore
+    gpu_mem_used = torch.cuda.max_memory_allocated(device) / 1024**3
+    return InferenceOutput(images=pil_images, image_rewards=image_rewards, gpu_mem_used=gpu_mem_used)
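For reference, a minimal sketch of calling the new finetuned-model entry point directly, mirroring what run_ft_ui in app.py does. It assumes the repository root as working directory, a CUDA device, and the LoRA checkpoint present under checkpoints/; the prompt is illustrative:

import torch
from src.smc.inference import infer_ft, FTInferenceConfig

# Only the prompt is required; the remaining fields fall back to the dataclass
# defaults (resolution=512, CFG=9.0, steps=48, num_batches=4, bundled ckpt_uuid).
cfg = FTInferenceConfig(prompt="a watercolor painting of a lighthouse at dusk")

device = "cuda" if torch.cuda.is_available() else "cpu"

# infer_ft builds the pipeline, loads the LoRA weights named by cfg.ckpt_uuid,
# runs generation, and returns an InferenceOutput with images, image_rewards,
# and gpu_mem_used.
out = infer_ft(cfg, device=device)
for img, reward in zip(out.images, out.image_rewards):
    print(img.size, reward)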