focuzz commited on
Commit
36ab632
·
verified ·
1 Parent(s): 573f528

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +4 -0
  2. app.py +133 -0
  3. example1.jpg +3 -0
  4. example2.jpg +3 -0
  5. example3.jpg +3 -0
  6. example4.jpg +3 -0
  7. requirements.txt +8 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example1.jpg filter=lfs diff=lfs merge=lfs -text
37
+ example2.jpg filter=lfs diff=lfs merge=lfs -text
38
+ example3.jpg filter=lfs diff=lfs merge=lfs -text
39
+ example4.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import gradio as gr
5
+ from diffusers import DiffusionPipeline
6
+ from huggingface_hub import hf_hub_download
7
+ import os
8
+
9
+ # Настройки
10
+ use_custom_weights = True
11
+ custom_weights_path = hf_hub_download(
12
+ repo_id="focuzz/depth-estimation",
13
+ filename="unet_weights.pth"
14
+ )
15
+
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ dtype = torch.float16 if device == "cuda" else torch.float32
18
+
19
+ # Загрузка пайплайна
20
+ pipe = DiffusionPipeline.from_pretrained(
21
+ "prs-eth/marigold-v1-0",
22
+ custom_pipeline="marigold_depth_estimation",
23
+ torch_dtype=dtype
24
+ ).to(device)
25
+
26
+ # Загрузка дообученных весов
27
+ if use_custom_weights:
28
+ state_dict = torch.load(custom_weights_path, map_location=device)
29
+ prefix = "unet.conv_in." if any(k.startswith("unet.conv_in.") for k in state_dict) else "conv_in."
30
+ conv_in_dict = {
31
+ k.replace(prefix, ""): v
32
+ for k, v in state_dict.items()
33
+ if k.startswith(prefix)
34
+ }
35
+ pipe.unet.conv_in.load_state_dict(conv_in_dict)
36
+ print("Загружены дообученные веса conv_in из:", custom_weights_path)
37
+
38
+ # Добавление overlay-текста
39
+ def add_overlay(image: Image.Image, label: str) -> Image.Image:
40
+ image = image.copy()
41
+ draw = ImageDraw.Draw(image)
42
+ font = ImageFont.load_default()
43
+ draw.text((10, 10), label, fill="white", font=font)
44
+ return image
45
+
46
+ # Генерация галереи из примеров
47
+ TARGET_SIZE = (768, 768)
48
+ def normalize_depth(depth_np):
49
+ d = np.copy(depth_np)
50
+ d_min = np.percentile(d, 1)
51
+ d_max = np.percentile(d, 99)
52
+ d = np.clip((d - d_min) / (d_max - d_min), 0, 1)
53
+ return (d * 255).astype(np.uint8)
54
+
55
+ def generate_gallery():
56
+ example_files = ["example1.jpg", "example2.jpg", "example3.jpg", "example4.jpg"]
57
+ rgbs = []
58
+ depths_gray = []
59
+ depths_color = []
60
+
61
+ for path in example_files:
62
+ if not os.path.exists(path):
63
+ continue
64
+
65
+ rgb = Image.open(path).convert("RGB").resize(TARGET_SIZE)
66
+
67
+ with torch.no_grad():
68
+ output = pipe(
69
+ rgb,
70
+ denoising_steps=4,
71
+ ensemble_size=5,
72
+ processing_res=768,
73
+ match_input_res=True,
74
+ batch_size=0,
75
+ color_map="Spectral",
76
+ show_progress_bar=False,
77
+ )
78
+
79
+ depth_np = output.depth_np
80
+ gray_normalized = normalize_depth(depth_np)
81
+ depth_gray = Image.fromarray(gray_normalized).convert("RGB").resize(TARGET_SIZE, Image.BILINEAR)
82
+ depth_color = output.depth_colored.resize(TARGET_SIZE, Image.BILINEAR)
83
+
84
+ rgbs.append(add_overlay(rgb, "RGB"))
85
+ depths_gray.append(add_overlay(depth_gray, "Depth (gray)"))
86
+ depths_color.append(add_overlay(depth_color, "Depth (color)"))
87
+
88
+ return rgbs + depths_color + depths_gray
89
+
90
+ # Интерфейс Blocks с галереей и инференсом
91
+ with gr.Blocks() as demo:
92
+ gr.Markdown("## Генерация карт глубины")
93
+ gr.Markdown(
94
+ "Модель основана на Marigold (ETH), дообучена на indoor-сценах из NYUv2. "
95
+ "Сохраняет способность обрабатывать произвольные изображения благодаря наличию оригинальных U-Net весов."
96
+ )
97
+
98
+ with gr.Row():
99
+ with gr.Column(scale=1):
100
+ input_image = gr.Image(label="Загрузите RGB изображение", type="pil")
101
+ denoise = gr.Slider(1, 50, value=4, step=1, label="Denoising Steps")
102
+ ensemble = gr.Slider(1, 10, value=5, step=1, label="Ensemble Size")
103
+ resolution = gr.Slider(256, 1024, value=768, step=64, label="Processing Resolution")
104
+ match_res = gr.Checkbox(value=True, label="Match Input Resolution")
105
+ with gr.Column(scale=1):
106
+ output_image = gr.Image(label="Карта глубины")
107
+
108
+ def predict_depth(image, denoising_steps, ensemble_size, processing_res, match_input_res):
109
+ with torch.no_grad():
110
+ output = pipe(
111
+ image,
112
+ denoising_steps=denoising_steps,
113
+ ensemble_size=ensemble_size,
114
+ processing_res=processing_res,
115
+ match_input_res=match_input_res,
116
+ batch_size=0,
117
+ color_map="Spectral",
118
+ show_progress_bar=False,
119
+ )
120
+ return output.depth_colored
121
+
122
+ submit_btn = gr.Button("Выполнить предсказание")
123
+ submit_btn.click(
124
+ predict_depth,
125
+ inputs=[input_image, denoise, ensemble, resolution, match_res],
126
+ outputs=output_image
127
+ )
128
+
129
+ gr.Markdown("### Примеры:")
130
+ gallery = gr.Gallery(label="Сравнение RGB и Depth", columns=4)
131
+ demo.load(fn=generate_gallery, outputs=gallery)
132
+
133
+ demo.launch(share=True)
example1.jpg ADDED

Git LFS Details

  • SHA256: d1dab6e854df0842ac0cabc6df32edf9cc3ceea1d3de4f3ce8dc998464dd99a2
  • Pointer size: 131 Bytes
  • Size of remote file: 720 kB
example2.jpg ADDED

Git LFS Details

  • SHA256: 16ca201a9a374a167b5c8e6b7a98918a3b168af3dcb371452207c12a5050f66a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
example3.jpg ADDED

Git LFS Details

  • SHA256: c06a80983dddd372d551bf9f27539ab83930602df176bfa039e0e92a98f9413a
  • Pointer size: 131 Bytes
  • Size of remote file: 316 kB
example4.jpg ADDED

Git LFS Details

  • SHA256: 248f0838c5763af9c606ce992de77fa2c8a192a2ee41dfb0e6539af58deccdef
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ diffusers>=0.25.0
3
+ transformers
4
+ accelerate
5
+ gradio
6
+ huggingface_hub
7
+ matplotlib
8
+ scipy