# DTGM_demo / app.py
# GreyCC99127's picture
# Update app.py
# d36e7fb verified
import gradio as gr
from huggingface_hub import snapshot_download
from models import UNet
import torch
import os
import torch.nn.functional as F
from PIL import Image
import numpy as np
# Hugging Face Hub repository that hosts the pretrained checkpoint
model_repo = "GreyCC99127/DTGM"

# Fetch a local snapshot of the repository and point at the checkpoint file in it
model_dir = snapshot_download(repo_id=model_repo)
model_path = os.path.join(model_dir, "DTGM_model_167500.pt")

# Pick the compute device: CUDA when PyTorch reports it as available, else CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class DTLS(torch.nn.Module):
    """Sampler that wraps a step-conditioned network in an iterative refinement loop.

    The module-level names ``timestep``, ``size_list``, ``device``,
    ``transform_func_sample`` and ``transform_func_noise`` are looked up when
    ``forward`` is called, so they only need to exist by then.
    """

    def __init__(self, m):
        """Store the wrapped network ``m``.

        The attribute is deliberately still called ``UNet``: renaming it would
        change the key prefix of this module's ``state_dict``.
        """
        super().__init__()
        self.UNet = m

    def forward(self, x):
        """Run the full sampling schedule on ``x`` and return the last network output.

        ``x`` is first resampled with ``transform_func_sample`` at the resolution
        ``size_list[timestep]``. Then, for ``t = timestep, ..., 1``, the network
        predicts ``R_x`` from the current image and ``t``; for every step but the
        last, that prediction is re-degraded with ``transform_func_noise`` at
        ``size_list[t - 1]`` to form the next input.

        If the schedule has no steps (``timestep == 0``) the resampled input is
        returned unchanged instead of raising ``UnboundLocalError``.

        Gradients are disabled for the whole loop.
        """
        with torch.no_grad():
            img_t = transform_func_sample(x.clone(), size_list[timestep])
            # Fallback result for an empty schedule; overwritten on the first step.
            R_x = img_t

            ####### Domain Transfer
            for t in range(timestep, 0, -1):
                # Created directly on the target device (no CPU allocation + copy).
                # Shape (1,) matches the single-image demo input;
                # NOTE(review): confirm the network broadcasts it before using larger batches.
                step = torch.full((1,), t, dtype=torch.long, device=device)
                R_x = self.UNet(img_t, step)

                # The re-noised image is only needed as input to the next step.
                # After the final step (t == 1) it would be discarded, so skip it.
                if t > 1:
                    img_t = transform_func_noise(R_x, size_list[t - 1])

        return R_x
# Wrap a freshly constructed UNet in the DTLS sampler
model = DTLS(UNet())

# Load the checkpoint straight onto the target device and restore the weights
# stored under its 'ema' entry.
# NOTE(review): strict=False silently ignores missing/unexpected keys, so a
# mismatched checkpoint will not raise here.
# NOTE(review): the model is never switched to eval() mode — confirm that is intended.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['ema'], strict=False)

# Drop the checkpoint reference once the weights have been copied into the model
del checkpoint

model.to(device)
# Side length, in pixels, that the transform functions resize images back to
image_size = 256

# Resolution schedule indexed by step: entry 0 is the full image size and the
# last entry is the coarsest resolution used to start sampling
size_list = [image_size, 64, 32, 16, 8, 4, 3, 2]

# Index of the last (coarsest) entry of the schedule; sampling starts here
timestep = len(size_list) - 1
def transform_func_sample(img, target_size):
    """Resize ``img`` to ``target_size`` and then to ``image_size``.

    Every resize uses antialiased bicubic interpolation. When
    ``image_size / target_size`` exceeds 16 the image is first resized to
    ``image_size // 4`` and ``image_size // 8`` before reaching ``target_size``;
    otherwise it goes to ``target_size`` in a single step.

    Returns the image at ``image_size`` resolution.
    """
    def resize(tensor, side):
        return F.interpolate(tensor, size=side, mode='bicubic', antialias=True)

    full = image_size

    # Choose the chain of intermediate sizes leading to the target resolution
    if full / target_size > 16:
        stages = (full // 4, full // 8, target_size)
    else:
        stages = (target_size,)

    out = img
    for side in stages:
        out = resize(out, side)

    # Bring the low-resolution result back to the working resolution
    return resize(out, full)
def transform_func_noise(img, target_size):
    """Resize ``img`` to ``target_size``, add random noise, and resize to ``image_size``.

    Steps:
      1. Resize to ``target_size`` (via ``image_size // 4`` and
         ``image_size // 8`` when ``image_size / target_size`` exceeds 16).
      2. Add coarse noise: a 2x2 Gaussian field (std 0.5, randomly shifted mean)
         resized to ``target_size`` and scaled by ``0.9 ** (target_size - 2)``.
      3. Resize to ``image_size``.
      4. For ``target_size >= 16`` only, add per-pixel standard Gaussian noise
         scaled by the same factor.

    All resizes use antialiased bicubic interpolation. The function draws from
    the global torch RNG, so the result is not deterministic.
    """
    def resize(tensor, side):
        return F.interpolate(tensor, size=side, mode='bicubic', antialias=True)

    full = image_size

    # Mean of the coarse noise, uniform in [-0.05, 0.05)
    noise_mean = torch.rand(1).mul(0.1).add(-.05).item()
    # Noise amplitude: 1.0 at size 2, shrinking geometrically as the size grows
    noise_gain = 0.9 ** (target_size - 2)

    # Choose the chain of intermediate sizes leading to the target resolution
    if full / target_size > 16:
        stages = (full // 4, full // 8, target_size)
    else:
        stages = (target_size,)

    small = img
    for side in stages:
        small = resize(small, side)

    # Coarse noise is sampled on a 2x2 grid and resized to the target resolution
    coarse = torch.normal(mean=noise_mean, std=0.5, size=(small.shape[0], 3, 2, 2)).to(device)
    small += resize(coarse, target_size) * noise_gain

    out = resize(small, full)

    # Fine-grained noise is only added for the larger target resolutions
    if target_size >= 16:
        fine = torch.normal(mean=0, std=1, size=out.shape).to(device)
        out = out + fine * noise_gain

    return out
def tensor_to_pil(image_tensor):
    """Turn a channel-first image tensor into a PIL image.

    Values are clamped to [-1, 1], mapped linearly to [0, 1], scaled by 255 and
    cast to ``uint8`` (the cast truncates, it does not round). The tensor is
    detached and moved to the CPU first, and its axes are reordered from
    (C, H, W) to (H, W, C) before being handed to ``Image.fromarray``.
    """
    # Detach from autograd, move to host memory, and limit to the [-1, 1] range
    clamped = image_tensor.detach().cpu().clamp(-1, 1)

    # [-1, 1] -> [0, 1]
    unit_range = (clamped + 1) / 2.0

    # CHW -> HWC
    hwc = unit_range.numpy().transpose(1, 2, 0)

    # [0, 1] -> [0, 255] as 8-bit integers
    pixels = (hwc * 255).astype(np.uint8)
    return Image.fromarray(pixels)
def generate_initial_image(x, y):
    """Sample a (1, 3, 2, 2) Gaussian tensor with mean ``x`` and std ``y`` on ``device``."""
    seed_shape = (1, 3, 2, 2)
    sample = torch.normal(x, y, size=seed_shape)
    return sample.to(device)
# Gradio Interface
def app(mean, std):
    """Gradio callback: sample a seed image and run the model on it.

    Returns a two-element list of PIL images: the seed upscaled to 256x256
    (first output, left) and the model's result (second output, right).
    """
    seed = generate_initial_image(mean, std)
    result = model(seed)

    # Enlarge the tiny seed with nearest-exact interpolation for display
    seed_preview = F.interpolate(seed, size=256, mode='nearest-exact')

    # Drop the batch dimension of each tensor before converting it to an image
    return [tensor_to_pil(batch.squeeze(0)) for batch in (seed_preview, result)]
# Custom stylesheet handed to gr.Blocks(css=...). The selectors target the
# element ids / classes assigned to the components in the layout below.
# The slider track's border-radius previously read "50x", which is not a valid
# CSS length, so the declaration was dropped by browsers; it is now "50px".
css = """
/* Make all sliders and labels larger */
.slider-container {
margin: 0 auto !important;
width: 50% !important;
}
.slider-container label {
font-size: 24px !important;
font-weight: bold !important;
}
.minimalist-slider input[type=range] {
height: 8px !important;
background: #e0e0e0 !important;
border-radius: 50px !important;
}
.minimalist-slider input[type=range]::-webkit-slider-thumb {
width: 20px !important;
height: 20px !important;
background: #4a90e2 !important;
border: none !important;
border-radius: 50% !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important;
}
.minimalist-slider label {
font-size: 8px !important;
color: #333 !important;
margin-bottom: 8px !important;
}
input[type=range] {
height: 20px !important;
width: 100% !important;
}
input[type=range]::-webkit-slider-thumb {
width: 20px !important;
height: 20px !important;
}
/* Style the generate button */
button {
padding: 12px 24px !important;
font-size: 18px !important;
margin: 20px auto !important;
display: block !important;
min-width: 200px !important;
}
/* Center the input section */
#input-section {
text-align: center !important;
margin: 0 auto !important;
width: 100% !important;
}
/* Style the output images */
#output-images {
display: flex !important;
justify-content: space-around !important;
margin-top: 20px !important;
}
.output-image {
width: 45% !important;
text-align: center !important;
}
.output-image label {
font-size: 18px !important;
font-weight: bold !important;
margin-bottom: 10px !important;
}
"""
# Page layout: a title, two sliders plus a button, and two side-by-side images
with gr.Blocks(css=css, title="DTGM Demo") as demo:
    gr.Markdown("# DTGM Demo")
    gr.Markdown("Input two values to generate initial and final images")

    # Input controls: mean and std of the Gaussian seed, and the trigger button
    with gr.Column(elem_id="input-section"):
        mean_slider = gr.Slider(
            minimum=-0.75,
            maximum=0.75,
            value=0,
            label="Choose the mean value (-0.75 to 0.75)",
            elem_classes="minimalist-slider",
        )
        std_slider = gr.Slider(
            minimum=0.01,
            maximum=0.5,
            value=0.25,
            label="Choose the std (0.01 to 0.5)",
            elem_classes="minimalist-slider",
        )
        generate_btn = gr.Button("Generate", variant="primary")

    # Read-only outputs: seed image on the left, generated image on the right
    with gr.Row(elem_id="output-images"):
        with gr.Column(elem_classes="output-image"):
            initial_out = gr.Image(label="Initial_Image", interactive=False)
        with gr.Column(elem_classes="output-image"):
            final_out = gr.Image(label="Final_Image", interactive=False)

    # Clicking the button feeds both slider values to `app` and shows its two images
    generate_btn.click(
        fn=app,
        inputs=[mean_slider, std_slider],
        outputs=[initial_out, final_out],
    )

# Start the web server and open the demo in a browser tab
demo.launch(inbrowser=True)