# NOTE: `spaces` is kept as the very first import; ZeroGPU Spaces typically
# require it to be imported before torch.
import spaces
import gradio as gr
import os
import sys
from typing import List

import numpy as np
from PIL import Image
import torch

print(f'torch version:{torch.__version__}')

import torch.utils.checkpoint
from pytorch_lightning import seed_everything
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from huggingface_hub import hf_hub_download, snapshot_download
from torchvision import transforms

from pipelines.pipeline_seesr import StableDiffusionControlNetPipeline
from utils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
from ram.models.ram_lora import ram
from ram import inference_ram as inference
from models.controlnet import ControlNetModel
from models.unet_2d_condition import UNet2DConditionModel

# PIL image -> float tensor (C, H, W).
tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
])

# Preprocessing for the RAM tag model: fixed 384x384 input with ImageNet
# mean/std normalisation (3 channels).
ram_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Fetch model weights into ./preset/models (SeeSR + DAPE weights, and sd-turbo).
snapshot_download(repo_id="alexnasa/SEESR", local_dir="preset/models")
snapshot_download(repo_id="stabilityai/sd-turbo", local_dir="preset/models/sd-turbo")
# The app only loads ram_swin_large_14m.pth from this repository, so fetch that
# single file instead of snapshotting the whole repo.
hf_hub_download(
    repo_id="xinyu1205/recognize_anything_model",
    filename="ram_swin_large_14m.pth",
    local_dir="preset/models",
)
# Load scheduler, tokenizer and models.
pretrained_model_path = 'preset/models/sd-turbo'
seesr_model_path = 'preset/models/seesr'

scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained_orig(pretrained_model_path, seesr_model_path, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(seesr_model_path, subfolder="controlnet")

# Inference only: freeze every network (vae, text encoder, unet, controlnet).
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)

# Build the validation pipeline (no feature extractor, no safety checker).
validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=None,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)
validation_pipeline._init_tiled_vae(encoder_tile_size=1024, decoder_tile_size=224)

weight_dtype = torch.float16
device = "cuda"

# Move all four networks to the GPU and cast them to weight_dtype.
text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
controlnet.to(device, dtype=weight_dtype)

# RAM tag model: its tags feed the text prompt and its image embeddings are
# passed to the pipeline as ram_encoder_hidden_states (see process()).
tag_model = ram(pretrained='preset/models/ram_swin_large_14m.pth',
                pretrained_condition='preset/models/DAPE.pth',
                image_size=384,
                vit='swin_l')
tag_model.eval()
tag_model.to(device, dtype=weight_dtype)


@spaces.GPU()
def process(
    input_image: Image.Image,
    user_prompt: str,
    use_KDS: bool,
    bandwidth: float,
    num_particles: int,
    positive_prompt: str,
    negative_prompt: str,
    num_inference_steps: int,
    scale_factor: int,
    cfg_scale: float,
    seed: int,
    latent_tiled_size: int,
    latent_tiled_overlap: int,
    sample_times: int
) -> List[np.ndarray]:
    """Super-resolve ``input_image`` with SeeSR and return ``sample_times`` results.

    The input is tagged by the RAM tag model; the predicted tags, the
    positive prompt and (optionally) the user prompt form the text prompt.
    The image is upscaled by ``scale_factor``, enlarged to at least 512 px on
    its short side if needed, snapped to multiples of 8, run through the
    diffusion pipeline, colour-corrected with a wavelet fix and finally
    resized to ``scale_factor`` times the original size.

    Args:
        input_image: Low-quality input image.
        user_prompt: Extra text prepended to the prompt when non-empty.
        use_KDS: Forwarded to the pipeline (Kernel Density Steering toggle).
        bandwidth: Forwarded to the pipeline (KDS bandwidth).
        num_particles: Forwarded to the pipeline (KDS particle count).
        positive_prompt: Appended after the predicted tags.
        negative_prompt: Negative prompt for classifier-free guidance.
        num_inference_steps: Diffusion steps.
        scale_factor: Upscaling factor; may arrive as a float from ``gr.Number``.
        cfg_scale: Classifier-free guidance scale.
        seed: Random seed (out-of-range values such as -1 are handled by
            ``seed_everything``).
        latent_tiled_size: Latent tile size for tiled diffusion.
        latent_tiled_overlap: Latent tile overlap for tiled diffusion.
        sample_times: Number of samples to generate.

    Returns:
        One RGB ``np.ndarray`` per sample. A sample whose generation raised is
        replaced by a black 512x512 image so one failure does not abort the
        whole request; the traceback is printed to the logs.

    Raises:
        gr.Error: If no input image was provided.
    """
    import traceback

    if input_image is None:
        raise gr.Error("Please upload an input image.")
    # The RAM normalisation and the pipeline work on 3 channels; uploads such as
    # RGBA or greyscale PNGs would otherwise fail.
    input_image = input_image.convert("RGB")

    process_size = 512
    resize_preproc = transforms.Compose([
        transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    # seed_everything returns the seed it actually applied. A freshly created
    # torch.Generator is not affected by the global seed, so seed it explicitly
    # to make the "Seed" control reproducible.
    used_seed = seed_everything(seed)
    generator = torch.Generator(device=device).manual_seed(used_seed)

    # Tag the input and compute the tag model's image embeddings.
    lq = tensor_transforms(input_image).unsqueeze(0).to(device).half()
    lq = ram_transforms(lq)
    res = inference(lq, tag_model)
    ram_encoder_hidden_states = tag_model.generate_image_embeds(lq)

    validation_prompt = f"{res[0]}, {positive_prompt},"
    if user_prompt != '':
        validation_prompt = f"{user_prompt}, {validation_prompt}"

    ori_width, ori_height = input_image.size
    rscale = scale_factor
    # PIL needs integer sizes and gr.Number yields floats, so truncate once and
    # reuse the same target size for the final resize below.
    target_size = (int(ori_width * rscale), int(ori_height * rscale))
    input_image = input_image.resize(target_size)
    if min(input_image.size) < process_size:
        input_image = resize_preproc(input_image)
    # Snap both sides down to multiples of 8 before running the pipeline.
    input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
    width, height = input_image.size

    images = []
    for _ in range(sample_times):
        try:
            with torch.autocast("cuda"):
                image = validation_pipeline(
                    validation_prompt, input_image, negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps, generator=generator,
                    height=height, width=width,
                    guidance_scale=cfg_scale, conditioning_scale=1,
                    start_point='lr', start_steps=999,
                    ram_encoder_hidden_states=ram_encoder_hidden_states,
                    latent_tiled_size=latent_tiled_size, latent_tiled_overlap=latent_tiled_overlap,
                    use_KDS=use_KDS, bandwidth=bandwidth, num_particles=num_particles,
                ).images[0]
            image = wavelet_color_fix(image, input_image)
            # Undo the process-size / multiple-of-8 adjustments.
            image = image.resize(target_size)
        except Exception:
            # Best effort: log the failure and keep going with a placeholder.
            traceback.print_exc()
            image = Image.new(mode="RGB", size=(512, 512))
        images.append(np.array(image))
    return images


MARKDOWN = """
## SeeSR: Towards Semantics-Aware Real-World Image Super-Resolution

[GitHub](https://github.com/cswry/SeeSR) | [Paper](https://arxiv.org/abs/2311.16518)

If SeeSR is helpful for you, please help star the GitHub Repo. Thanks!
"""

# Shared by the textbox defaults and the example rows.
DEFAULT_POSITIVE_PROMPT = "clean, high-resolution, 8k, best quality, masterpiece"
DEFAULT_NEGATIVE_PROMPT = (
    "dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, "
    "missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
)
EXAMPLE_IMAGE = "preset/datasets/test_datasets/woman.png"

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            num_particles = gr.Slider(label="Num of Particles", minimum=1, maximum=16, step=1, value=10)
            bandwidth = gr.Slider(label="Bandwidth", minimum=0.1, maximum=0.8, step=0.1, value=0.1)
            use_KDS = gr.Checkbox(label="Use Kernel Density Steering")
            run_button = gr.Button("Run")
            with gr.Accordion("Options", open=True):
                user_prompt = gr.Textbox(label="User Prompt", value="")
                positive_prompt = gr.Textbox(label="Positive Prompt", value=DEFAULT_POSITIVE_PROMPT)
                negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
                # step=0 is not a usable slider increment; 0.1 allows values such
                # as the 7.5 default and the 1.0 recommended for sd-turbo.
                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set to 1.0 in sd-turbo)",
                                      minimum=1, maximum=10, value=7.5, step=0.1)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=2, maximum=100, value=50, step=1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
                sample_times = gr.Slider(label="Sample Times", minimum=1, maximum=10, step=1, value=1)
                latent_tiled_size = gr.Slider(label="Diffusion Tile Size", minimum=128, maximum=480, value=320, step=1)
                latent_tiled_overlap = gr.Slider(label="Diffusion Tile Overlap", minimum=4, maximum=16, value=4, step=1)
                scale_factor = gr.Number(label="SR Scale", value=4)
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery")

            # Single source of truth for the input order; must match process().
            inputs = [
                input_image,
                user_prompt,
                use_KDS,
                bandwidth,
                num_particles,
                positive_prompt,
                negative_prompt,
                num_inference_steps,
                scale_factor,
                cfg_scale,
                seed,
                latent_tiled_size,
                latent_tiled_overlap,
                sample_times,
            ]

            # Rows differ only in the KDS toggle and particle count.
            examples = gr.Examples(
                examples=[
                    [EXAMPLE_IMAGE, "", kds, 0.1, particles,
                     DEFAULT_POSITIVE_PROMPT, DEFAULT_NEGATIVE_PROMPT,
                     4, 4, 1.0, 123, 320, 4, 1]
                    for kds, particles in [(False, 4), (True, 4), (True, 16)]
                ],
                inputs=inputs,
                outputs=[result_gallery],
                fn=process,
                cache_examples=True,
            )

    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])

block.launch(share=True)