|
import torch |
|
import numpy as np |
|
from PIL import Image, ImageDraw, ImageFont |
|
import gradio as gr |
|
from diffusers import DiffusionPipeline |
|
from huggingface_hub import hf_hub_download |
|
import os |
|
|
|
|
|
# When True, fine-tuned ``conv_in`` weights are fetched from the Hugging Face
# Hub and loaded on top of the base Marigold UNet.
use_custom_weights = True

# Only hit the network when the fine-tuned weights will actually be used;
# otherwise the path is None.
custom_weights_path = (
    hf_hub_download(
        repo_id="focuzz/depth-estimation",
        filename="unet_weights.pth",
    )
    if use_custom_weights
    else None
)

# Use the GPU in half precision when available, otherwise the CPU in full
# precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Base Marigold v1.0 checkpoint loaded through the community
# ``marigold_depth_estimation`` custom pipeline, moved to the chosen device.
pipe = DiffusionPipeline.from_pretrained(
    "prs-eth/marigold-v1-0",
    custom_pipeline="marigold_depth_estimation",
    torch_dtype=dtype,
).to(device)
|
|
|
|
|
if use_custom_weights:
    # The checkpoint is downloaded from a remote repository, so restrict
    # unpickling to tensors and plain containers instead of allowing arbitrary
    # pickled objects to be reconstructed on load.
    state_dict = torch.load(custom_weights_path, map_location=device, weights_only=True)

    # Accept either a pipeline-style checkpoint ("unet.conv_in.*") or a bare
    # UNet state dict ("conv_in.*").
    prefix = "unet.conv_in." if any(k.startswith("unet.conv_in.") for k in state_dict) else "conv_in."

    # Strip only the leading prefix (str.replace would also rewrite any later
    # occurrence inside a key) so keys match the ``conv_in`` module's own names.
    conv_in_dict = {
        k[len(prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }
    if not conv_in_dict:
        raise RuntimeError(
            f"No conv_in weights found in {custom_weights_path!r} "
            f"(expected keys starting with {prefix!r})"
        )

    # Only the input convolution is replaced; every other UNet weight stays as
    # loaded from the base checkpoint.
    pipe.unet.conv_in.load_state_dict(conv_in_dict)

    print("Загружены дообученные веса conv_in из:", custom_weights_path)
|
|
|
|
|
def _load_label_font(size: int):
    """Return a font for overlay labels, preferring one with Cyrillic glyphs.

    Tries DejaVu Sans (which covers Cyrillic and is often installed on Linux)
    first, then Pillow's built-in default font. Returns ``None`` if neither
    can be loaded, in which case ``ImageDraw.text`` picks its own default.
    """
    try:
        return ImageFont.truetype("DejaVuSans.ttf", size)
    except OSError:
        pass
    try:
        return ImageFont.load_default()
    except OSError:
        return None


def add_overlay(image: Image.Image, label: str, font_size: int = 20) -> Image.Image:
    """Return a copy of *image* with *label* drawn in white at (10, 10).

    The input image is left unmodified.

    Args:
        image: Source image; it is copied before drawing.
        label: Text to draw; may contain non-Latin (e.g. Cyrillic) characters.
        font_size: Size used when a TrueType font is available.

    Returns:
        A new image with the label drawn on it.
    """
    image = image.copy()
    draw = ImageDraw.Draw(image)
    font = _load_label_font(font_size)
    try:
        draw.text((10, 10), label, fill="white", font=font)
    except UnicodeEncodeError:
        # Pillow's legacy bitmap font can only encode Latin-1; degrade the
        # label (unsupported characters become '?') instead of failing.
        fallback = label.encode("latin-1", "replace").decode("latin-1")
        draw.text((10, 10), fallback, fill="white", font=font)
    return image
|
|
|
|
|
# Side length (width, height) every gallery image is resized to.
TARGET_SIZE = (768, 768)


def normalize_depth(depth_np, low_pct=1.0, high_pct=99.0):
    """Scale a depth map to 8-bit grayscale using robust percentile bounds.

    Values at or below the ``low_pct`` percentile map to 0 and values at or
    above the ``high_pct`` percentile map to 255, with a linear ramp between.
    This keeps a few outlier pixels from compressing the contrast. NaN pixels
    are ignored when computing the bounds and are mapped to 0.

    Args:
        depth_np: Array-like depth map of any shape.
        low_pct: Lower percentile used as the black point.
        high_pct: Upper percentile used as the white point.

    Returns:
        ``np.uint8`` array with the same shape as *depth_np*. It is all zeros
        when the input is empty or has no usable range (e.g. a constant map).
    """
    d = np.asarray(depth_np)
    if d.size == 0:
        return np.zeros(d.shape, dtype=np.uint8)

    d_min, d_max = np.nanpercentile(d, [low_pct, high_pct])
    span = d_max - d_min
    if not np.isfinite(span) or span <= 0:
        # Constant or all-NaN map: there is no range to stretch, and dividing
        # by zero would produce NaNs whose uint8 cast is undefined.
        return np.zeros(d.shape, dtype=np.uint8)

    scaled = np.clip((d - d_min) / span, 0.0, 1.0)
    scaled = np.nan_to_num(scaled, nan=0.0)
    return (scaled * 255).astype(np.uint8)
|
|
|
def _predict_example_depths(rgb):
    """Run the depth pipeline on one example and return (gray, colored) PIL images at TARGET_SIZE."""
    with torch.no_grad():
        output = pipe(
            rgb,
            denoising_steps=4,
            ensemble_size=5,
            processing_res=768,
            match_input_res=True,
            batch_size=0,
            color_map="Spectral",
            show_progress_bar=False,
        )

    gray_normalized = normalize_depth(output.depth_np)
    depth_gray = Image.fromarray(gray_normalized).convert("RGB").resize(TARGET_SIZE, Image.BILINEAR)
    depth_color = output.depth_colored.resize(TARGET_SIZE, Image.BILINEAR)
    return depth_gray, depth_color


def generate_gallery():
    """Build the example gallery from the bundled example images.

    Example files that do not exist are skipped. For each remaining example,
    the RGB input, the colored depth map and the grayscale depth map are
    produced, each with a text label drawn on it.

    Returns:
        A flat list of PIL images: all RGB images first, then all colored
        depth maps, then all grayscale depth maps, each group in the same
        example order.
    """
    example_files = ["example1.jpg", "example2.jpg", "example3.jpg", "example4.jpg"]

    rgbs = []
    depths_gray = []
    depths_color = []

    for path in example_files:
        if not os.path.exists(path):
            continue

        # The context manager releases the file handle even if decoding fails;
        # convert() returns an independent image, so closing afterwards is safe.
        with Image.open(path) as img:
            rgb = img.convert("RGB").resize(TARGET_SIZE)

        depth_gray, depth_color = _predict_example_depths(rgb)

        rgbs.append(add_overlay(rgb, "RGB"))
        depths_gray.append(add_overlay(depth_gray, "Глубина (серая)"))
        depths_color.append(add_overlay(depth_color, "Глубина (цветная)"))

    # Grouped by kind; presumably intended so that, in a gallery with one
    # column per example, each kind occupies one row.
    return rgbs + depths_color + depths_gray
|
|
|
|
|
def predict_depth(image, denoising_steps, ensemble_size, processing_res, match_input_res):
    """Estimate depth for an uploaded image and return the colored depth map.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        denoising_steps: Number of denoising steps (slider value).
        ensemble_size: Number of ensemble runs per image (slider value).
        processing_res: Processing resolution (slider value).
        match_input_res: Whether to resize the result back to the input size.

    Returns:
        The pipeline's colored depth map.

    Raises:
        gr.Error: If no image was uploaded, so the UI shows a readable message
            instead of a pipeline stack trace.
    """
    if image is None:
        raise gr.Error("Сначала загрузите RGB изображение.")

    # Slider values are not guaranteed to arrive as ints, but the pipeline uses
    # them as step counts, ensemble sizes and pixel sizes.
    with torch.no_grad():
        output = pipe(
            image,
            denoising_steps=int(round(denoising_steps)),
            ensemble_size=int(round(ensemble_size)),
            processing_res=int(round(processing_res)),
            match_input_res=bool(match_input_res),
            batch_size=0,
            color_map="Spectral",
            show_progress_bar=False,
        )
    return output.depth_colored


with gr.Blocks() as demo:
    gr.Markdown("## Генерация карт глубины")
    gr.Markdown(
        "Модель основана на Marigold (ETH), дообучена на indoor-сценах из NYUv2. "
        "Сохраняет способность обрабатывать произвольные изображения благодаря наличию оригинальных U-Net весов."
    )

    with gr.Row():
        # Left column: input image and inference settings.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Загрузите RGB изображение", type="pil")
            denoise = gr.Slider(1, 50, value=4, step=1, label="Шаги денойзинга")
            ensemble = gr.Slider(1, 10, value=5, step=1, label="Размер ансамбля (количество запусков для одной картинки)")
            resolution = gr.Slider(256, 1024, value=768, step=64, label="Разрешение обработки изображений")
            match_res = gr.Checkbox(value=True, label="Сохранять исходное разрешение")
        # Right column: prediction result.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Карта глубины")

    submit_btn = gr.Button("Выполнить предсказание")
    submit_btn.click(
        predict_depth,
        inputs=[input_image, denoise, ensemble, resolution, match_res],
        outputs=output_image,
    )

    # Example gallery is computed when the page loads.
    gr.Markdown("### Примеры:")
    gallery = gr.Gallery(label="Сравнение RGB и Глубины", columns=4)
    demo.load(fn=generate_gallery, outputs=gallery)

demo.launch(ssr_mode=False)