Spaces:
Running
on
Zero
Running
on
Zero
Fix device placement
Browse files
marble.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import os
|
2 |
-
from typing import Dict
|
3 |
|
4 |
import numpy as np
|
5 |
import torch
|
@@ -31,10 +31,18 @@ def get_session():
|
|
31 |
return _session_cache
|
32 |
|
33 |
|
|
|
|
|
|
|
|
|
34 |
def setup_control_mlps(
|
35 |
-
features: int = 1024,
|
|
|
|
|
36 |
) -> Dict[str, torch.nn.Module]:
|
37 |
ret = {}
|
|
|
|
|
38 |
for mlp in CONTROL_MLPS:
|
39 |
ret[mlp] = setup_control_mlp(mlp, features, device, dtype)
|
40 |
return ret
|
@@ -43,9 +51,12 @@ def setup_control_mlps(
|
|
43 |
def setup_control_mlp(
|
44 |
material_parameter: str,
|
45 |
features: int = 1024,
|
46 |
-
device: str =
|
47 |
dtype: torch.dtype = torch.float16,
|
48 |
):
|
|
|
|
|
|
|
49 |
net = control_mlp(features)
|
50 |
net.load_state_dict(
|
51 |
torch.load(os.path.join(file_dir, f"model_weights/{material_parameter}.pt"))
|
@@ -95,9 +106,12 @@ def download_ip_adapter():
|
|
95 |
|
96 |
|
97 |
def setup_pipeline(
|
98 |
-
device: str =
|
99 |
dtype: torch.dtype = torch.float16,
|
100 |
):
|
|
|
|
|
|
|
101 |
download_ip_adapter()
|
102 |
|
103 |
cur_block = ("up", 0, 1)
|
@@ -135,7 +149,10 @@ def setup_pipeline(
|
|
135 |
)
|
136 |
|
137 |
|
138 |
-
def get_dpt_model(device: str =
|
|
|
|
|
|
|
139 |
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
140 |
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
|
141 |
model.to(device, dtype=dtype)
|
@@ -144,9 +161,12 @@ def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
|
|
144 |
|
145 |
|
146 |
def run_dpt_depth(
|
147 |
-
image: Image.Image, model, processor, device: str =
|
148 |
) -> Image.Image:
|
149 |
"""Run DPT depth estimation on an image."""
|
|
|
|
|
|
|
150 |
# Prepare image
|
151 |
inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)
|
152 |
|
|
|
1 |
import os
|
2 |
+
from typing import Dict, Optional
|
3 |
|
4 |
import numpy as np
|
5 |
import torch
|
|
|
31 |
return _session_cache
|
32 |
|
33 |
|
34 |
+
def get_device():
|
35 |
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
36 |
+
|
37 |
+
|
38 |
def setup_control_mlps(
|
39 |
+
features: int = 1024,
|
40 |
+
device: Optional[str] = None,
|
41 |
+
dtype: torch.dtype = torch.float16,
|
42 |
) -> Dict[str, torch.nn.Module]:
|
43 |
ret = {}
|
44 |
+
if device is None:
|
45 |
+
device = get_device()
|
46 |
for mlp in CONTROL_MLPS:
|
47 |
ret[mlp] = setup_control_mlp(mlp, features, device, dtype)
|
48 |
return ret
|
|
|
51 |
def setup_control_mlp(
|
52 |
material_parameter: str,
|
53 |
features: int = 1024,
|
54 |
+
device: Optional[str] = None,
|
55 |
dtype: torch.dtype = torch.float16,
|
56 |
):
|
57 |
+
if device is None:
|
58 |
+
device = get_device()
|
59 |
+
|
60 |
net = control_mlp(features)
|
61 |
net.load_state_dict(
|
62 |
torch.load(os.path.join(file_dir, f"model_weights/{material_parameter}.pt"))
|
|
|
106 |
|
107 |
|
108 |
def setup_pipeline(
|
109 |
+
device: Optional[str] = None,
|
110 |
dtype: torch.dtype = torch.float16,
|
111 |
):
|
112 |
+
if device is None:
|
113 |
+
device = get_device()
|
114 |
+
|
115 |
download_ip_adapter()
|
116 |
|
117 |
cur_block = ("up", 0, 1)
|
|
|
149 |
)
|
150 |
|
151 |
|
152 |
+
def get_dpt_model(device: Optional[str] = None, dtype: torch.dtype = torch.float16):
|
153 |
+
if device is None:
|
154 |
+
device = get_device()
|
155 |
+
|
156 |
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
157 |
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
|
158 |
model.to(device, dtype=dtype)
|
|
|
161 |
|
162 |
|
163 |
def run_dpt_depth(
|
164 |
+
image: Image.Image, model, processor, device: Optional[str] = None
|
165 |
) -> Image.Image:
|
166 |
"""Run DPT depth estimation on an image."""
|
167 |
+
if device is None:
|
168 |
+
device = get_device()
|
169 |
+
|
170 |
# Prepare image
|
171 |
inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)
|
172 |
|