Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import snapshot_download | |
from models import UNet | |
import torch | |
import os | |
import torch.nn.functional as F | |
from PIL import Image | |
import numpy as np | |
# Pick the compute device first: CUDA when torch reports it is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hugging Face Hub repository holding the DTGM checkpoint.
model_repo = "GreyCC99127/DTGM"
# Download a snapshot of that repository; the call returns a local directory.
model_dir = snapshot_download(repo_id=model_repo)
# Absolute path of the checkpoint file inside the downloaded snapshot.
model_path = os.path.join(model_dir, "DTGM_model_167500.pt")
class DTLS(torch.nn.Module):
    """Sampler that wraps a trained network and runs the DTGM transfer loop.

    The wrapped network is stored as ``self.UNet``; the attribute name is part
    of the state-dict key prefix, so it is kept unchanged.

    NOTE(review): ``forward`` reads the module-level ``timestep``,
    ``size_list``, ``transform_func_sample`` and ``transform_func_noise``;
    they must exist before the model is first called.
    """

    def __init__(self, m):
        """Wrap the network ``m``, which is called as ``m(image, step)``."""
        super().__init__()
        self.UNet = m

    def forward(self, x):
        """Run the reverse process starting from ``x``.

        Args:
            x: 4-D image batch passed to ``transform_func_sample``; presumably
                ``(batch, 3, H, W)`` — TODO confirm against the UNet.

        Returns:
            The network output from the last iteration (``t == 1``). When
            ``timestep`` is 0 no iteration runs and the degraded input is
            returned instead of raising ``UnboundLocalError``.
        """
        with torch.no_grad():
            t = timestep
            # Degrade the input to the resolution at the end of the schedule.
            img_t = transform_func_sample(x.clone(), size_list[t])
            # Fallback result when the loop below does not execute.
            R_x = img_t
            batch = img_t.shape[0]
            ####### Domain Transfer
            while t:
                next_step = size_list[t - 1]
                # One timestep index per batch element (was hard-coded to 1),
                # created on the same device as the images.
                step = torch.full((batch,), t, dtype=torch.long, device=img_t.device)
                R_x = self.UNet(img_t, step)
                # Re-degrade the prediction at the next entry of size_list.
                img_t = transform_func_noise(R_x, next_step)
                t -= 1
            return R_x
# Build the sampler and load the weights stored under the checkpoint's 'ema' key.
model = DTLS(UNet())
data = torch.load(model_path, map_location=device)
# strict=False ignores missing/unexpected keys instead of raising; report any
# mismatch so a wrong checkpoint does not silently leave weights uninitialized.
# NOTE(review): confirm that a non-strict load is actually required.
incompatible = model.load_state_dict(data['ema'], strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        f"[DTGM] checkpoint key mismatch: "
        f"{len(incompatible.missing_keys)} missing, "
        f"{len(incompatible.unexpected_keys)} unexpected"
    )
del data  # drop the checkpoint dict once its tensors have been copied
model.to(device)
# Inference only: put any dropout / normalization layers into evaluation mode
# (the model was previously left in training mode).
model.eval()

# Full output resolution, and the resize schedule visited by DTLS.forward from
# the last entry (index `timestep`) back toward the first.
image_size = 256
size_list = [256, 64, 32, 16, 8, 4, 3, 2]
timestep = len(size_list) - 1
def transform_func_sample(img, target_size, *, base_size=None):
    """Blur ``img`` by resizing it to ``target_size`` and back to full size.

    Args:
        img: 4-D tensor ``(batch, channels, H, W)`` (bicubic interpolation
            requires 4-D input).
        target_size: side length of the intermediate low-resolution image.
        base_size: side length of the returned image; ``None`` (default)
            uses the module-level ``image_size``.

    Returns:
        Tensor with the same batch/channel dims and spatial size
        ``base_size`` x ``base_size``.
    """
    n = target_size
    m = image_size if base_size is None else base_size
    if m / n > 16:
        # Large reduction ratio: pass through m/4 and m/8 before reaching n.
        # NOTE(review): when `img` is already smaller than m/4 these stages
        # upsample first — confirm that is intended.
        img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
        img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
        img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
    else:
        img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)
    return F.interpolate(img_1, size=m, mode='bicubic', antialias=True)
def transform_func_noise(img, target_size, *, base_size=None):
    """Resize ``img`` to ``target_size``, add noise, and resize back up.

    Args:
        img: 4-D tensor ``(batch, channels, H, W)``.
        target_size: side length of the intermediate image; larger values
            give a smaller noise scale (``0.9 ** (target_size - 2)``).
        base_size: side length of the returned image; ``None`` (default)
            uses the module-level ``image_size``.

    Returns:
        Tensor with the batch/channel dims, device and dtype of ``img`` and
        spatial size ``base_size`` x ``base_size``.
    """
    n = target_size
    m = image_size if base_size is None else base_size
    # Mean of the coarse noise, uniform in [-0.05, 0.05).
    random_mean = torch.rand(1).mul(0.1).add(-.05).item()
    decreasing_scale = 0.9 ** (n - 2)
    if m / n > 16:
        # Large reduction ratio: pass through m/4 and m/8 before reaching n.
        img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
        img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
        img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
    else:
        img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)
    # 2x2 Gaussian field upsampled to n x n. Its batch/channel dims now follow
    # the input (was hard-coded to 3 channels) and it is moved to the input's
    # device/dtype rather than the global `device`.
    noise = torch.normal(mean=random_mean, std=0.5, size=(*img_1.shape[:2], 2, 2))
    noise = noise.to(device=img_1.device, dtype=img_1.dtype)
    noise = F.interpolate(noise, size=n, mode='bicubic', antialias=True)
    img_1 = img_1 + noise * decreasing_scale
    img_1 = F.interpolate(img_1, size=m, mode='bicubic', antialias=True)
    if n >= 16:
        # Additional per-pixel Gaussian noise at full resolution.
        noise_refinement = torch.normal(mean=0, std=1, size=img_1.shape)
        noise_refinement = noise_refinement.to(device=img_1.device, dtype=img_1.dtype)
        img_1 = img_1 + noise_refinement * decreasing_scale
    return img_1
def tensor_to_pil(image_tensor):
    """Convert a CHW image tensor with values in [-1, 1] to a PIL image.

    Args:
        image_tensor: tensor of shape ``(C, H, W)``; values outside [-1, 1]
            are clamped. A single-channel tensor yields a grayscale image.

    Returns:
        ``PIL.Image.Image`` of size ``(W, H)`` with 8 bits per channel.
    """
    # Detach from autograd and move to host memory before NumPy conversion.
    image_tensor = image_tensor.detach().cpu()
    # Map [-1, 1] -> [0, 1].
    image_tensor = (image_tensor.clamp(-1, 1) + 1) / 2.0
    np_image = image_tensor.numpy().transpose(1, 2, 0)  # CHW -> HWC
    # Round rather than truncate so values are not biased down by one level.
    np_image = np.rint(np_image * 255).astype(np.uint8)
    if np_image.shape[2] == 1:
        # Image.fromarray rejects (H, W, 1); drop the singleton channel axis.
        np_image = np_image[:, :, 0]
    return Image.fromarray(np_image)
def generate_initial_image(x, y):
    """Draw a Gaussian seed image of shape (1, 3, 2, 2).

    Args:
        x: mean of the normal distribution.
        y: standard deviation of the normal distribution.

    Returns:
        The sampled tensor, moved to the module-level ``device``.
    """
    seed_shape = (1, 3, 2, 2)
    seed = torch.normal(x, y, size=seed_shape)
    return seed.to(device)
# Gradio Interface | |
def app(mean, std):
    """Gradio callback: sample a seed image and run the DTGM sampler on it.

    Args:
        mean: mean of the Gaussian seed (value of the mean slider).
        std: standard deviation of the Gaussian seed (value of the std slider).

    Returns:
        ``[initial_img, final_img]`` as PIL images, in the order of the two
        output components (left, right).
    """
    seed = generate_initial_image(mean, std)
    result = model(seed)
    # Enlarge the seed for display; 'nearest-exact' keeps its pixels as flat
    # blocks. Uses image_size (was a hard-coded 256) so the preview matches the
    # model output size.
    initial_img = tensor_to_pil(
        F.interpolate(seed, size=image_size, mode='nearest-exact').squeeze(0)
    )
    final_img = tensor_to_pil(result.squeeze(0))
    return [
        initial_img,  # First output (left)
        final_img,  # Second output (right)
    ]
# Custom stylesheet passed to gr.Blocks. Fixed: `border-radius: 50x` was not a
# valid CSS length and was ignored by browsers; it is now `50px`.
# NOTE(review): `.minimalist-slider label` sets 8px, which contradicts the
# "larger" intent stated in the first rule's comment — confirm the size.
css = """
/* Make all sliders and labels larger */
.slider-container {
margin: 0 auto !important;
width: 50% !important;
}
.slider-container label {
font-size: 24px !important;
font-weight: bold !important;
}
.minimalist-slider input[type=range] {
height: 8px !important;
background: #e0e0e0 !important;
border-radius: 50px !important;
}
.minimalist-slider input[type=range]::-webkit-slider-thumb {
width: 20px !important;
height: 20px !important;
background: #4a90e2 !important;
border: none !important;
border-radius: 50% !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important;
}
.minimalist-slider label {
font-size: 8px !important;
color: #333 !important;
margin-bottom: 8px !important;
}
input[type=range] {
height: 20px !important;
width: 100% !important;
}
input[type=range]::-webkit-slider-thumb {
width: 20px !important;
height: 20px !important;
}
/* Style the generate button */
button {
padding: 12px 24px !important;
font-size: 18px !important;
margin: 20px auto !important;
display: block !important;
min-width: 200px !important;
}
/* Center the input section */
#input-section {
text-align: center !important;
margin: 0 auto !important;
width: 100% !important;
}
/* Style the output images */
#output-images {
display: flex !important;
justify-content: space-around !important;
margin-top: 20px !important;
}
.output-image {
width: 45% !important;
text-align: center !important;
}
.output-image label {
font-size: 18px !important;
font-weight: bold !important;
margin-bottom: 10px !important;
}
"""
# Page layout: two sliders and a button on top, the two result images below.
with gr.Blocks(css=css, title="DTGM Demo") as demo:
    gr.Markdown("# DTGM Demo")
    gr.Markdown("Input two values to generate initial and final images")
    with gr.Column(elem_id="input-section"):
        mean_slider = gr.Slider(-0.75, 0.75, label="Choose the mean value (-0.75 to 0.75)", value=0, elem_classes="minimalist-slider")
        std_slider = gr.Slider(0.01, 0.5, label="Choose the std (0.01 to 0.5)", value=0.25, elem_classes="minimalist-slider")
        generate_btn = gr.Button("Generate", variant="primary")
    with gr.Row(elem_id="output-images"):
        with gr.Column(elem_classes="output-image"):
            initial_out = gr.Image(label="Initial_Image", interactive=False)
        with gr.Column(elem_classes="output-image"):
            final_out = gr.Image(label="Final_Image", interactive=False)
    # Event listeners must be registered inside the Blocks context.
    generate_btn.click(
        fn=app,
        inputs=[mean_slider, std_slider],
        outputs=[initial_out, final_out]
    )

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch(inbrowser=True)