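"""Gradio demo for DiT4SR real-world image super-resolution.

The app captions the input image with LLaVA, lets the user edit the resulting prompt,
and then runs the SD3-based DiT4SR pipeline to produce the super-resolved result.
"""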
import gradio as gr
import spaces
from typing import List
import argparse
import sys
import os
import glob
sys.path.append(os.getcwd())
from llava.llm_agent import LLavaAgent
from PIL import Image
# from CKPT_PTH import LLAVA_MODEL_PATH
import re
import numpy as np
import torch
from pytorch_lightning import seed_everything
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
)
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
from pipelines.pipeline_dit4sr import StableDiffusion3ControlNetPipeline
from utils.wavelet_color_fix import adain_color_fix
from torchvision import transforms
from model_dit4sr.transformer_sd3 import SD3Transformer2DModel

parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_name_or_path", type=str, default='stabilityai/stable-diffusion-3.5-medium')
parser.add_argument("--load_both_models", type=str, default='True')  # whether to load both the dit4sr_q and dit4sr_f models
parser.add_argument("--transformer_model_name_or_path", type=str, default='acceptee/DiT4SR')
parser.add_argument("--mixed_precision", type=str, default="fp16")  # no / fp16 / bf16
parser.add_argument("--process_size", type=int, default=512)
parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)  # latent size, for 24G
parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)  # image size, for 13G
parser.add_argument("--latent_tiled_size", type=int, default=64)
parser.add_argument("--latent_tiled_overlap", type=int, default=16)
parser.add_argument("--start_point", type=str, choices=['lr', 'noise'], default='noise')  # LR embedding strategy: use 'lr latent + 999 steps noise' as the diffusion start point.
parser.add_argument(
    "--revision",
    type=str,
    default=None,
    required=False,
    help="Revision of the pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--variant",
    type=str,
    default=None,
    help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16.",
)
args = parser.parse_args()
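# Example launch (assuming this script is saved as app.py; adjust the flags as needed):
#   python app.py --mixed_precision fp16 --load_both_models False
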
# Copied from the dreambooth sd3 example
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")


# Copied from the dreambooth sd3 example
def load_text_encoders(class_one, class_two, class_three, args):
    text_encoder_one = class_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    text_encoder_two = class_two.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
    )
    text_encoder_three = class_three.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
    )
    return text_encoder_one, text_encoder_two, text_encoder_three

def _load_dit4sr_pipeline(args, device, transformer_subfolder):
    """Build a StableDiffusion3ControlNetPipeline using the DiT4SR transformer in `transformer_subfolder`."""
    # Load scheduler, VAE and the DiT4SR transformer.
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae",
    )
    transformer = SD3Transformer2DModel.from_pretrained(
        args.transformer_model_name_or_path, subfolder=transformer_subfolder
    )
    # Load the three tokenizers.
    tokenizer_one = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
    )
    tokenizer_two = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer_2",
        revision=args.revision,
    )
    tokenizer_three = T5TokenizerFast.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer_3",
        revision=args.revision,
    )
    # Import the correct text encoder classes and load the encoders.
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
    )
    text_encoder_cls_three = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
    )
    text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, args
    )
    # Freeze the VAE, text encoders and transformer: they are used for inference only.
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)
    text_encoder_three.requires_grad_(False)
    transformer.requires_grad_(False)
    # Assemble the validation pipeline.
    validation_pipeline = StableDiffusion3ControlNetPipeline(
        vae=vae, text_encoder=text_encoder_one, text_encoder_2=text_encoder_two, text_encoder_3=text_encoder_three,
        tokenizer=tokenizer_one, tokenizer_2=tokenizer_two, tokenizer_3=tokenizer_three,
        transformer=transformer, scheduler=scheduler,
    )
    # These models are only used for inference, so full precision is not required.
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    # Move the text encoders, VAE and transformer to the GPU and cast to weight_dtype.
    text_encoder_one.to(device, dtype=weight_dtype)
    text_encoder_two.to(device, dtype=weight_dtype)
    text_encoder_three.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    transformer.to(device, dtype=weight_dtype)
    return validation_pipeline


def load_dit4sr_q_pipeline(args, device):
    # Quality-first variant.
    return _load_dit4sr_pipeline(args, device, "dit4sr_q/transformer")


def load_dit4sr_f_pipeline(args, device):
    # Fidelity-first variant.
    return _load_dit4sr_pipeline(args, device, "dit4sr_f/transformer")

def remove_focus_sentences(text):
    # Split on . ? ! while keeping the delimiters (a capture group in re.split keeps them in the result list),
    # then drop any sentence that contains one of the prohibited words below.
    prohibited_words = ['focus', 'focal', 'prominent', 'close-up', 'black and white', 'blur', 'depth', 'dense', 'locate', 'position']
    parts = re.split(r'([.?!])', text)
    filtered_sentences = []
    i = 0
    while i < len(parts):
        # `sentence` is the sentence body; `punctuation` is its trailing delimiter (if any).
        sentence = parts[i]
        punctuation = parts[i + 1] if (i + 1 < len(parts)) else ''
        # Re-attach the punctuation so it is not lost.
        full_sentence = sentence + punctuation
        full_sentence_lower = full_sentence.lower()
        skip = False
        for word in prohibited_words:
            if word.lower() in full_sentence_lower:
                skip = True
                break
        # Keep the sentence only if it contains no prohibited word.
        if not skip:
            filtered_sentences.append(full_sentence)
        # Advance past the sentence and its punctuation.
        i += 2
    # Re-join without extra whitespace.
    return "".join(filtered_sentences).strip()
# if torch.cuda.device_count() >= 2:
#     LLaVA_device = 'cuda:0'
#     dit4sr_device = 'cuda:1'
# elif torch.cuda.device_count() == 1:
#     LLaVA_device = 'cuda:0'
#     dit4sr_device = 'cuda:0'
# else:
#     raise ValueError('Currently support CUDA only.')
LLaVA_device = 'cuda:0'
dit4sr_device = 'cuda:0'

llava_agent = LLavaAgent("liuhaotian/llava-v1.5-13b", LLaVA_device, load_8bit=True, load_4bit=False)

# Get the validation pipeline - prioritize dit4sr_f
pipeline_dit4sr_f = load_dit4sr_f_pipeline(args, dit4sr_device)

# Only load dit4sr_q if load_both_models is True
pipeline_dit4sr_q = None
if args.load_both_models == 'True':
    pipeline_dit4sr_q = load_dit4sr_q_pipeline(args, dit4sr_device)
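# Note: the LLaVA agent and the DiT4SR pipeline(s) are loaded once at startup and kept on the GPU;
# pass --load_both_models False to load only the fidelity-first (dit4sr_f) model and save memory.
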
def process_llava(input_image):
    # Caption the input image with LLaVA, then drop sentences containing prohibited words.
    llama_prompt = llava_agent.gen_image_caption([input_image])[0]
    llama_prompt = remove_focus_sentences(llama_prompt)
    return llama_prompt

def process_sr(
    input_image: Image.Image,
    user_prompt: str,
    positive_prompt: str,
    negative_prompt: str,
    num_inference_steps: int,
    scale_factor: int,
    cfg_scale: float,
    seed: int,
    model_choice: str,
) -> Image.Image:
    process_size = 512
    resize_preproc = transforms.Compose([
        transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    # Gradio sliders may deliver floats; cast the integer-valued controls.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    seed_everything(seed)
    generator = torch.Generator(device=dit4sr_device)
    generator.manual_seed(seed)

    validation_prompt = f"{user_prompt} {positive_prompt}"

    ori_width, ori_height = input_image.size
    resize_flag = False
    rscale = scale_factor
    input_image = input_image.resize((int(input_image.size[0] * rscale), int(input_image.size[1] * rscale)))
    if min(input_image.size) < process_size:
        input_image = resize_preproc(input_image)
    input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
    width, height = input_image.size
    resize_flag = True

    # Choose the pipeline based on the model selection - prioritize dit4sr_f
    if model_choice == "dit4sr_q" and pipeline_dit4sr_q is not None:
        pipeline = pipeline_dit4sr_q
    else:
        pipeline = pipeline_dit4sr_f

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    try:
        with torch.autocast(device_type='cuda', dtype=weight_dtype, enabled=(args.mixed_precision != "no")):
            image = pipeline(
                prompt=validation_prompt, control_image=input_image, num_inference_steps=num_inference_steps,
                generator=generator, height=height, width=width,
                guidance_scale=cfg_scale, negative_prompt=negative_prompt, start_point=args.start_point,
                latent_tiled_size=args.latent_tiled_size, latent_tiled_overlap=args.latent_tiled_overlap,
                args=args,
            ).images[0]

        # Match the output colors to the (upscaled) input with AdaIN color correction.
        image = adain_color_fix(image, input_image)

        if resize_flag:
            # Cast in case rscale arrives as a float from gr.Number.
            image = image.resize((int(ori_width * rscale), int(ori_height * rscale)))
    except Exception as e:
        print(f"Error during inference: {e}")
        raise gr.Error(f"Error during inference: {e}", duration=None)
    return image

Intro = """
## DiT4SR: Taming Diffusion Transformer for Real-World Image Super-Resolution
[🕸️ Project Page](https://adam-duan.github.io/projects/dit4sr) • [📄 Paper](https://arxiv.org/abs/2503.23580) • [💻 Code](https://github.com/Adam-duan/DiT4SR) • [📦 Model](https://huggingface.co/acceptee/DiT4SR) • [📊 Dataset](https://huggingface.co/datasets/acceptee/NKUSR8K)
"""

# Generate the instruction text based on model availability
if args.load_both_models == 'True':
    Prompt = """
First, select your preferred model (fidelity first or quality first). \\
Then, click "Run LLAVA" to generate an initial prompt based on the input image. \\
Modify the prompt for higher accuracy if needed. \\
Finally, click "Run DiT4SR" to generate the SR result.
"""
else:
    Prompt = """
Click "Run LLAVA" to generate an initial prompt based on the input image. \\
Modify the prompt for higher accuracy if needed. \\
Finally, click "Run DiT4SR" to generate the SR result using the fidelity first model.
"""
example_images = sorted(glob.glob('examples/*.png'))

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(Intro)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            user_prompt = gr.Textbox(label="User Prompt", value="")

            # Only show the model selection if load_both_models is True
            if args.load_both_models == 'True':
                model_choice = gr.Dropdown(
                    label="Model Selection",
                    choices=[("Quality First", "dit4sr_q"), ("Fidelity First", "dit4sr_f")],
                    value="dit4sr_f",
                    info="Choose between the Quality First and Fidelity First models"
                )
            else:
                # Hidden component with a default value when only one model is available
                model_choice = gr.Dropdown(
                    label="Model Selection",
                    choices=["dit4sr_f"],
                    value="dit4sr_f",
                    visible=False
                )

            with gr.Accordion("Options", open=False):
                positive_prompt = gr.Textbox(label="Positive Prompt", value='Cinematic, perfect without deformations, ultra HD, '
                                                                            'camera, detailed photo, realistic maximum, 32k, Color.')
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value='motion blur, noisy, dotted, pointed, deformed, lowres, chaotic, '
                          'CG Style, 3D render, unreal engine, blurring, dirty, messy, '
                          'worst quality, low quality, watermark, signature, jpeg artifacts.'
                )
                cfg_scale = gr.Slider(label="Classifier-Free Guidance Scale (set a value larger than 1 to enable it)", minimum=0.1, maximum=10.0, value=7.0, step=0.1)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=20, step=1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=0)
                scale_factor = gr.Number(label="SR Scale", value=4)
            gr.Examples(examples=example_images, inputs=[input_image])
        with gr.Column():
            result_gallery = gr.Image(label="Output", show_label=False, elem_id="gallery", type="pil", format="png")
            with gr.Row():
                run_llava_button = gr.Button(value="Run LLAVA")
                run_sr_button = gr.Button(value="Run DiT4SR")
            gr.Markdown(Prompt)

    inputs = [
        input_image,
        user_prompt,
        positive_prompt,
        negative_prompt,
        num_inference_steps,
        scale_factor,
        cfg_scale,
        seed,
        model_choice,
    ]
    run_llava_button.click(fn=process_llava, inputs=[input_image], outputs=[user_prompt])
    run_sr_button.click(fn=process_sr, inputs=inputs, outputs=[result_gallery])

block.launch()