"""Gradio demo that samples images from the DTGM checkpoint.

Importing/running this file:

1. downloads the ``GreyCC99127/DTGM`` repository from the Hugging Face Hub,
2. loads the ``'ema'`` weights of ``DTGM_model_167500.pt`` into a ``UNet``
   wrapped by :class:`DTLS`,
3. builds a Gradio UI with two sliders (mean / std of the 2x2 seed noise) and
   two image outputs (the up-scaled seed and the generated image), and
4. launches the UI.
"""

import os
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image

from models import UNet

# Hugging Face repository that hosts the checkpoint.
model_repo = "GreyCC99127/DTGM"

# Download the model files and locate the checkpoint inside the snapshot.
model_dir = snapshot_download(repo_id=model_repo)
model_path = os.path.join(model_dir, "DTGM_model_167500.pt")

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Working resolution of the model and the resolution used at each step.
# Sampling starts at the last (smallest) entry and walks towards index 0.
image_size = 256
size_list = [256, 64, 32, 16, 8, 4, 3, 2]
timestep = len(size_list) - 1


def _bicubic(img: torch.Tensor, size: int) -> torch.Tensor:
    """Resize *img* to ``size`` x ``size`` with anti-aliased bicubic interpolation."""
    return F.interpolate(img, size=size, mode='bicubic', antialias=True)


def _downscale(img: torch.Tensor, target_size: int) -> torch.Tensor:
    """Resize *img* to *target_size*, in stages when the reduction is large.

    When ``image_size / target_size`` exceeds 16 the image is first resized
    to ``image_size // 4`` and ``image_size // 8`` before the final resize;
    otherwise it is resized to *target_size* in one step.
    """
    if image_size / target_size > 16:
        img = _bicubic(img, image_size // 4)
        img = _bicubic(img, image_size // 8)
    return _bicubic(img, target_size)


def transform_func_sample(img: torch.Tensor, target_size: int) -> torch.Tensor:
    """Resize *img* to *target_size* and then to ``image_size``.

    Args:
        img: Image batch accepted by ``F.interpolate`` (4-D tensor).
        target_size: Intermediate side length.

    Returns:
        A new tensor with spatial size ``image_size`` x ``image_size``.
    """
    return _bicubic(_downscale(img, target_size), image_size)


def transform_func_noise(img: torch.Tensor, target_size: int) -> torch.Tensor:
    """Resize *img* to *target_size*, perturb it with noise, resize to ``image_size``.

    Two noise terms are used, both scaled by ``0.9 ** (target_size - 2)`` so
    that their amplitude shrinks as the target resolution grows:

    * a 2x2, 3-channel Gaussian field (std 0.5, mean drawn uniformly from
      [-0.05, 0.05)) resized to *target_size* and added at that resolution;
    * for ``target_size >= 16`` only, per-pixel standard Gaussian noise added
      at ``image_size`` resolution.

    Args:
        img: Image batch accepted by ``F.interpolate``; the 2x2 noise field
            has 3 channels, so *img* is expected to have 3 channels as well.
        target_size: Intermediate side length.

    Returns:
        A new tensor with spatial size ``image_size`` x ``image_size``.
    """
    n = target_size
    random_mean = torch.rand(1).mul(0.1).add(-.05).item()
    decreasing_scale = 0.9 ** (n - 2)

    small = _downscale(img, n)

    # Low-frequency noise: sampled at 2x2 and resized to the working size.
    noise = torch.normal(
        mean=random_mean, std=0.5, size=(small.shape[0], 3, 2, 2)
    ).to(device)
    noise = _bicubic(noise, n)
    small = small + noise * decreasing_scale

    result = _bicubic(small, image_size)

    if n >= 16:
        # High-frequency refinement noise at full resolution.
        noise_refinement = torch.normal(mean=0, std=1, size=result.shape).to(device)
        result = result + noise_refinement * decreasing_scale
    return result


class DTLS(torch.nn.Module):
    """Wrap a network and run it once per entry of ``size_list`` (except index 0).

    The wrapped network is stored as ``self.UNet``; that attribute name is
    part of the state-dict key prefix, so it must not be renamed.
    """

    def __init__(self, m):
        super().__init__()
        self.UNet = m

    def forward(self, x):
        """Generate an image from the seed tensor *x*.

        *x* is first passed through :func:`transform_func_sample` at the
        smallest size in ``size_list``. Then, for ``t = timestep .. 1``, the
        network is called with the current image and the step index, and its
        output is re-noised with :func:`transform_func_noise` at the next
        (larger) size to form the input of the following step.

        Returns:
            The network output of the final step (``t == 1``). Gradients are
            disabled for the whole computation.
        """
        with torch.no_grad():
            blur_img = transform_func_sample(x.clone(), size_list[timestep])
            img_t = blur_img.clone()
            # Keeps the return value defined even if ``timestep`` were 0.
            R_x = img_t

            ####### Domain Transfer
            for t in range(timestep, 0, -1):
                step = torch.full((1,), t, dtype=torch.long, device=device)
                R_x = self.UNet(img_t, step)
                if t > 1:
                    # The re-noised image is only needed as input to the next
                    # step, so it is not computed after the final one.
                    img_t = transform_func_noise(R_x, size_list[t - 1])
            return R_x


# Load the model.
model = DTLS(UNet())
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
data = torch.load(model_path, map_location=device)
# strict=False tolerates key mismatches; report them instead of hiding them,
# because a mismatch would leave the affected layers randomly initialised.
_load_result = model.load_state_dict(data['ema'], strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint did not match the model exactly: "
        f"{len(_load_result.missing_keys)} missing key(s), "
        f"{len(_load_result.unexpected_keys)} unexpected key(s)."
    )
del data
model.to(device)
# Inference only: put dropout / normalisation layers (if any) in eval mode.
model.eval()


def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
    """Convert a PyTorch tensor to a PIL image.

    The tensor is detached, moved to the CPU, clamped to [-1, 1] and rescaled
    to [0, 255] ``uint8``. It is assumed to be in CHW layout (e.g.
    ``(3, 256, 256)``); axes are transposed to HWC for PIL.
    """
    image_tensor = image_tensor.detach().cpu()
    image_tensor = (image_tensor.clamp(-1, 1) + 1) / 2.0  # [-1, 1] -> [0, 1]
    np_image = image_tensor.numpy().transpose(1, 2, 0)  # CHW -> HWC
    np_image = (np_image * 255).astype(np.uint8)  # [0, 1] -> [0, 255]
    return Image.fromarray(np_image)


def generate_initial_image(x, y):
    """Sample a ``(1, 3, 2, 2)`` Gaussian tensor with mean *x* and std *y* on ``device``."""
    return torch.normal(x, y, size=(1, 3, 2, 2)).to(device)


# Gradio Interface
def app(mean, std):
    """Gradio callback: draw a seed from N(mean, std) and run the model on it.

    Returns:
        ``[initial_img, final_img]`` as PIL images: the 2x2 seed enlarged to
        256x256 with nearest-neighbour interpolation, and the model output.
    """
    seed = generate_initial_image(mean, std)
    generated = model(seed)

    seed_preview = F.interpolate(seed, size=256, mode='nearest-exact')
    initial_img = tensor_to_pil(seed_preview.squeeze(0))
    final_img = tensor_to_pil(generated.squeeze(0))

    return [
        initial_img,  # First output (left)
        final_img,  # Second output (right)
    ]


css = """
/* Make all sliders and labels larger */
.slider-container {
    margin: 0 auto !important;
    width: 50% !important;
}
.slider-container label {
    font-size: 24px !important;
    font-weight: bold !important;
}
.minimalist-slider input[type=range] {
    height: 8px !important;
    background: #e0e0e0 !important;
    border-radius: 50px !important;
}
.minimalist-slider input[type=range]::-webkit-slider-thumb {
    width: 20px !important;
    height: 20px !important;
    background: #4a90e2 !important;
    border: none !important;
    border-radius: 50% !important;
    box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important;
}
.minimalist-slider label {
    font-size: 8px !important;
    color: #333 !important;
    margin-bottom: 8px !important;
}
input[type=range] {
    height: 20px !important;
    width: 100% !important;
}
input[type=range]::-webkit-slider-thumb {
    width: 20px !important;
    height: 20px !important;
}
/* Style the generate button */
button {
    padding: 12px 24px !important;
    font-size: 18px !important;
    margin: 20px auto !important;
    display: block !important;
    min-width: 200px !important;
}
/* Center the input section */
#input-section {
    text-align: center !important;
    margin: 0 auto !important;
    width: 100% !important;
}
/* Style the output images */
#output-images {
    display: flex !important;
    justify-content: space-around !important;
    margin-top: 20px !important;
}
.output-image {
    width: 45% !important;
    text-align: center !important;
}
.output-image label {
    font-size: 18px !important;
    font-weight: bold !important;
    margin-bottom: 10px !important;
}
"""

with gr.Blocks(css=css, title="DTGM Demo") as demo:
    gr.Markdown("# DTGM Demo")
    gr.Markdown("Input two values to generate initial and final images")

    with gr.Column(elem_id="input-section"):
        mean_slider = gr.Slider(
            -0.75,
            0.75,
            label="Choose the mean value (-0.75 to 0.75)",
            value=0,
            elem_classes="minimalist-slider",
        )
        std_slider = gr.Slider(
            0.01,
            0.5,
            label="Choose the std (0.01 to 0.5)",
            value=0.25,
            elem_classes="minimalist-slider",
        )
        generate_btn = gr.Button("Generate", variant="primary")

    with gr.Row(elem_id="output-images"):
        with gr.Column(elem_classes="output-image"):
            initial_out = gr.Image(label="Initial_Image", interactive=False)
        with gr.Column(elem_classes="output-image"):
            final_out = gr.Image(label="Final_Image", interactive=False)

    # Connect the button click to the generation callback.
    generate_btn.click(
        fn=app,
        inputs=[mean_slider, std_slider],
        outputs=[initial_out, final_out],
    )

demo.launch(inbrowser=True)