import os

# Debug-oriented CUDA settings, set before torch is imported.
# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous (easier to debug, slower).
# NOTE(review): PyTorch documents TORCH_USE_CUDA_DSA as a build-time flag; setting it at
# runtime may have no effect — confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"

import torch
import threading
import traceback
import gradio as gr

# Import your inference functions and dataclasses
# Adjust the import path if your file is located elsewhere
from src.smc.inference import (
    infer_pretrained,
    infer_smc_grad,
    infer_ft,
    PretrainedInferenceConfig,
    SMCGradInferenceConfig,
    FTInferenceConfig,
)
from run_examples import get_out_if_exists

GALLERY_HEIGHT = "224px"

# Guards the round-robin counter in get_device(): one click starts all three per-method
# handlers, which Gradio may run concurrently (see the queue() setup at the bottom).
_DEVICE_LOCK = threading.Lock()


def get_device():
    """Return the device string for the next inference call.

    If CUDA is not available yet, "cuda" is returned unconditionally (the GPU is
    expected to be attached later by Hugging Face Spaces ZeroGPU). Otherwise the
    visible CUDA devices are handed out round-robin; the index of the last device
    handed out is stored in the ``last_allocated`` attribute of this function.
    """
    if not torch.cuda.is_available():
        return "cuda"  # GPU will be dynamically allocated later using spaces ZeroGPU

    # Round-robin allocation. The read-modify-write of the counter happens under a
    # lock so that concurrent handlers do not both pick the same device.
    with _DEVICE_LOCK:
        last = getattr(get_device, "last_allocated", -1)
        i = (last + 1) % torch.cuda.device_count()
        get_device.last_allocated = i  # type: ignore[attr-defined]
    return f"cuda:{i}"


examples = [
    "A photo of a yellow bird and a black motorcycle",
    "A stylish dog wearing sunglasses",
    "A cat in the style of Van Gogh’s Starry Night",
]


def _format_inference_output(out) -> str:
    """Return a short one-line summary of an inference result for the UI.

    Reads ``out.image_rewards`` and ``out.gpu_mem_used`` (the latter formatted with
    3 decimals). Returns "No output" for ``None`` and a fixed fallback message when
    those attributes are missing or cannot be formatted.
    """
    if out is None:
        return "No output"
    try:
        rewards = out.image_rewards
        mem = out.gpu_mem_used
        return f"Rewards: {rewards} | GPU mem (GB): {mem:.3f}"
    except (AttributeError, TypeError, ValueError):
        # Missing attribute, or gpu_mem_used is not a float-formattable number (e.g. None).
        return "Could not parse inference output"


def _load_saved_output(method, cfg_cls, prompt):
    """Return ``(gallery, info)`` for a previously saved run of one method.

    Builds the method's default config for ``prompt`` and looks it up with
    ``get_out_if_exists``. Errors are printed and reported in ``info`` instead of
    raised, so a failure for one method does not hide the other methods' outputs.
    """
    try:
        out = get_out_if_exists(method, cfg_cls(prompt=prompt))
        if out is None:
            return [], "No saved output"
        return out.images, _format_inference_output(out)
    except Exception:
        # Don't crash the UI; print the traceback and return an empty placeholder.
        traceback.print_exc()
        return [], "Error checking saved outputs"


def try_load_saved_outputs(prompt):
    """Look up saved outputs for ``prompt`` for each method.

    Returns ``(pretrained_gallery, pretrained_info, smc_gallery, smc_info,
    ft_gallery, ft_info)``. A method with no saved output gets an empty gallery and
    "No saved output"; a method whose lookup fails gets an empty gallery and
    "Error checking saved outputs". Each method is checked independently.
    """
    pre_gallery, pre_info = _load_saved_output("pretrained", PretrainedInferenceConfig, prompt)
    smc_gallery, smc_info = _load_saved_output("smc_grad", SMCGradInferenceConfig, prompt)
    ft_gallery, ft_info = _load_saved_output("ft", FTInferenceConfig, prompt)
    return pre_gallery, pre_info, smc_gallery, smc_info, ft_gallery, ft_info


# --- Per-method runner functions ---


def _run_inference(label, infer_fn, make_cfg):
    """Build a config with ``make_cfg()``, run ``infer_fn`` on it and return ``(gallery, info)``.

    The config is built inside the ``try`` so that invalid UI values (failed
    float/int conversions) are reported like any other inference error. On error,
    the gallery is cleared and the message is returned as ``info`` and shown as a
    Gradio warning. gr.Gallery treats string items as image paths/URLs, so the
    message cannot be displayed in the gallery itself, and the info textboxes are
    created hidden in this UI.
    """
    try:
        out = infer_fn(make_cfg(), device=get_device())
        return out.images, _format_inference_output(out)
    except Exception as e:
        traceback.print_exc()
        err_msg = f"{label} inference error: {e}"
        gr.Warning(err_msg)
        return [], err_msg


def run_pretrained_ui(prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps):
    """Run the pretrained inference method and return (gallery, info)."""
    return _run_inference(
        "Pretrained",
        infer_pretrained,
        lambda: PretrainedInferenceConfig(
            prompt=prompt,
            negative_prompt=pretrained_negative_prompt or "",
            CFG=float(pretrained_CFG),
            steps=int(pretrained_steps),
        ),
    )


def run_smc_grad_ui(
    prompt,
    smc_grad_negative_prompt,
    smc_grad_CFG,
    smc_grad_steps,
    smc_grad_num_particles,
    smc_grad_ess_threshold,
    smc_grad_partial_resampling,
    smc_grad_resample_frequency,
    smc_grad_kl_weight,
    smc_grad_lambda_tempering,
    smc_grad_lambda_one_at,
    smc_grad_use_continuous_formulation,
    smc_grad_phi,
    smc_grad_tau,
):
    """Run the SMC-grad inference method and return (gallery, info)."""
    return _run_inference(
        "SMC-grad",
        infer_smc_grad,
        lambda: SMCGradInferenceConfig(
            prompt=prompt,
            negative_prompt=smc_grad_negative_prompt or "",
            ess_threshold=float(smc_grad_ess_threshold),
            partial_resampling=bool(smc_grad_partial_resampling),
            resample_frequency=int(smc_grad_resample_frequency),
            CFG=float(smc_grad_CFG),
            steps=int(smc_grad_steps),
            kl_weight=float(smc_grad_kl_weight),
            lambda_tempering=bool(smc_grad_lambda_tempering),
            lambda_one_at=float(smc_grad_lambda_one_at),
            num_particles=int(smc_grad_num_particles),
            use_continuous_formulation=bool(smc_grad_use_continuous_formulation),
            phi=int(smc_grad_phi),
            tau=float(smc_grad_tau),
        ),
    )


def run_ft_ui(prompt, ft_negative_prompt, ft_CFG, ft_steps):
    """Run the finetuned model inference and return (gallery, info)."""
    return _run_inference(
        "FT",
        infer_ft,
        lambda: FTInferenceConfig(
            prompt=prompt,
            negative_prompt=ft_negative_prompt or "",
            CFG=float(ft_CFG),
            steps=int(ft_steps),
        ),
    )


def mark_all_running():
    """Quickly reset all method outputs before the heavy handlers run.

    Clears every gallery and sets every info box to "Running..." (non-interactive).
    The info textboxes are created with ``visible=False`` in this UI, so the
    user-visible effect is the cleared galleries. The order of the returned updates
    must match the ``outputs`` list this callback is wired to below.
    """
    running_info = gr.update(value="Running...", interactive=False)
    empty_gallery = gr.update(value=[])
    return empty_gallery, running_info, empty_gallery, running_info, empty_gallery, running_info


with gr.Blocks() as demo:
    gr.Markdown("# Prompt alignment for Meissonic using SMC")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt here", value=examples[0], lines=1)
        run_button = gr.Button("Run", variant="primary")
    examples_widget = gr.Examples(examples=examples, inputs=prompt)

    # --- Pretrained method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Pretrained model — settings", open=False):
                pretrained_negative_prompt = gr.Textbox(
                    label="Negative prompt", value=PretrainedInferenceConfig.negative_prompt, lines=1
                )
                pretrained_CFG = gr.Slider(0.0, 30.0, step=0.1, value=PretrainedInferenceConfig.CFG, label="CFG")
                pretrained_steps = gr.Slider(1, 200, step=1, value=PretrainedInferenceConfig.steps, label="Steps")
        with gr.Column(scale=2):
            pretrained_gallery = gr.Gallery(
                label="Pretrained model outputs",
                show_label=True,
                elem_id="pretrained_gallery",
                height=GALLERY_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            pretrained_info = gr.Textbox(label="Pretrained info", interactive=False, visible=False)

    # --- SMC-grad method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("SMC-grad method — settings", open=False):
                smc_grad_negative_prompt = gr.Textbox(
                    label="Negative prompt", value=SMCGradInferenceConfig.negative_prompt, lines=1
                )
                smc_grad_CFG = gr.Slider(0.0, 30.0, step=0.1, value=SMCGradInferenceConfig.CFG, label="CFG")
                smc_grad_steps = gr.Slider(1, 200, step=1, value=SMCGradInferenceConfig.steps, label="Steps")
                smc_grad_num_particles = gr.Slider(
                    1, 64, step=1, value=SMCGradInferenceConfig.num_particles, label="SMC Num particles"
                )
                smc_grad_ess_threshold = gr.Slider(
                    0.0, 1.0, step=0.01, value=SMCGradInferenceConfig.ess_threshold, label="ESS threshold"
                )
                smc_grad_partial_resampling = gr.Checkbox(
                    label="Partial resampling", value=SMCGradInferenceConfig.partial_resampling
                )
                smc_grad_resample_frequency = gr.Slider(
                    1, 50, step=1, value=SMCGradInferenceConfig.resample_frequency, label="Resample frequency"
                )
                smc_grad_kl_weight = gr.Slider(
                    0.0, 10.0, step=0.01, value=SMCGradInferenceConfig.kl_weight, label="KL weight"
                )
                smc_grad_lambda_tempering = gr.Checkbox(
                    label="Lambda tempering", value=SMCGradInferenceConfig.lambda_tempering
                )
                smc_grad_lambda_one_at = gr.Slider(
                    0.0,
                    1.0,
                    step=0.01,
                    value=SMCGradInferenceConfig.lambda_one_at,
                    label="Lambda one at (fraction of steps)",
                )
                smc_grad_use_continuous_formulation = gr.Checkbox(
                    label="Use continuous formulation", value=SMCGradInferenceConfig.use_continuous_formulation
                )
                smc_grad_phi = gr.Slider(1, 8, step=1, value=SMCGradInferenceConfig.phi, label="Phi")
                smc_grad_tau = gr.Slider(0.0, 1.0, step=0.001, value=SMCGradInferenceConfig.tau, label="Tau")
        with gr.Column(scale=2):
            smc_grad_gallery = gr.Gallery(
                label="SMC-grad outputs",
                show_label=True,
                elem_id="smc_grad_gallery",
                height=GALLERY_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            smc_grad_info = gr.Textbox(label="SMC-grad info", interactive=False, visible=False)

    # --- FT method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Finetuned model — settings", open=False):
                ft_negative_prompt = gr.Textbox(
                    label="Negative prompt", value=FTInferenceConfig.negative_prompt, lines=1
                )
                ft_CFG = gr.Slider(0.0, 30.0, step=0.1, value=FTInferenceConfig.CFG, label="CFG")
                ft_steps = gr.Slider(1, 200, step=1, value=FTInferenceConfig.steps, label="Steps")
        with gr.Column(scale=2):
            ft_gallery = gr.Gallery(
                label="Finetuned model outputs",
                show_label=True,
                elem_id="ft_gallery",
                height=GALLERY_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            ft_info = gr.Textbox(label="Finetuned info", interactive=False, visible=False)

    # --- Wiring ---
    pretrained_inputs = [prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps]
    smc_grad_inputs = [
        prompt,
        smc_grad_negative_prompt,
        smc_grad_CFG,
        smc_grad_steps,
        smc_grad_num_particles,
        smc_grad_ess_threshold,
        smc_grad_partial_resampling,
        smc_grad_resample_frequency,
        smc_grad_kl_weight,
        smc_grad_lambda_tempering,
        smc_grad_lambda_one_at,
        smc_grad_use_continuous_formulation,
        smc_grad_phi,
        smc_grad_tau,
    ]
    ft_inputs = [prompt, ft_negative_prompt, ft_CFG, ft_steps]
    all_outputs = [pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info, ft_gallery, ft_info]

    # Clicking Run and pressing Enter in the prompt trigger the same set of handlers.
    for event_trigger in (run_button.click, prompt.submit):
        # 1) Quick 'running' update so the UI shows immediate feedback.
        event_trigger(fn=mark_all_running, inputs=[], outputs=all_outputs)
        # 2) The per-method heavy functions are separate listeners, so the queue can run
        #    them concurrently and each updates its own outputs as it completes.
        event_trigger(fn=run_pretrained_ui, inputs=pretrained_inputs, outputs=[pretrained_gallery, pretrained_info])
        event_trigger(fn=run_smc_grad_ui, inputs=smc_grad_inputs, outputs=[smc_grad_gallery, smc_grad_info])
        event_trigger(fn=run_ft_ui, inputs=ft_inputs, outputs=[ft_gallery, ft_info])

    # Trigger when an example is selected: show saved outputs for that prompt, if any.
    examples_widget.load_input_event.then(
        fn=try_load_saved_outputs,
        inputs=[prompt],
        outputs=all_outputs,
    )

    # Trigger once on page load for the initial prompt value (so example[0] loads on startup)
    demo.load(
        fn=try_load_saved_outputs,
        inputs=[prompt],
        outputs=all_outputs,
    )

# Enable Gradio's queue. default_concurrency_limit applies per event listener (each
# .click/.submit/.load registration), so the three per-method handlers already run
# concurrently with each other; the value caps how many requests of the *same* handler
# (e.g. from several users) may run at once. Important: call queue() before launch().
demo.queue(default_concurrency_limit=3)

if __name__ == "__main__":
    demo.launch(share=True)