mboss commited on
Commit
d232006
·
1 Parent(s): 4ccf490

Fix device placement

Browse files
Files changed (1) hide show
  1. marble.py +26 -6
marble.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Dict
3
 
4
  import numpy as np
5
  import torch
@@ -31,10 +31,18 @@ def get_session():
31
  return _session_cache
32
 
33
 
 
 
 
 
34
  def setup_control_mlps(
35
- features: int = 1024, device: str = "cuda", dtype: torch.dtype = torch.float16
 
 
36
  ) -> Dict[str, torch.nn.Module]:
37
  ret = {}
 
 
38
  for mlp in CONTROL_MLPS:
39
  ret[mlp] = setup_control_mlp(mlp, features, device, dtype)
40
  return ret
@@ -43,9 +51,12 @@ def setup_control_mlps(
43
  def setup_control_mlp(
44
  material_parameter: str,
45
  features: int = 1024,
46
- device: str = "cuda",
47
  dtype: torch.dtype = torch.float16,
48
  ):
 
 
 
49
  net = control_mlp(features)
50
  net.load_state_dict(
51
  torch.load(os.path.join(file_dir, f"model_weights/{material_parameter}.pt"))
@@ -95,9 +106,12 @@ def download_ip_adapter():
95
 
96
 
97
  def setup_pipeline(
98
- device: str = "cuda",
99
  dtype: torch.dtype = torch.float16,
100
  ):
 
 
 
101
  download_ip_adapter()
102
 
103
  cur_block = ("up", 0, 1)
@@ -135,7 +149,10 @@ def setup_pipeline(
135
  )
136
 
137
 
138
- def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
 
 
 
139
  image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
140
  model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
141
  model.to(device, dtype=dtype)
@@ -144,9 +161,12 @@ def get_dpt_model(device: str = "cuda", dtype: torch.dtype = torch.float16):
144
 
145
 
146
  def run_dpt_depth(
147
- image: Image.Image, model, processor, device: str = "cuda"
148
  ) -> Image.Image:
149
  """Run DPT depth estimation on an image."""
 
 
 
150
  # Prepare image
151
  inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)
152
 
 
1
  import os
2
+ from typing import Dict, Optional
3
 
4
  import numpy as np
5
  import torch
 
31
  return _session_cache
32
 
33
 
34
+ def get_device():
35
+ return "cuda" if torch.cuda.is_available() else "cpu"
36
+
37
+
38
  def setup_control_mlps(
39
+ features: int = 1024,
40
+ device: Optional[str] = None,
41
+ dtype: torch.dtype = torch.float16,
42
  ) -> Dict[str, torch.nn.Module]:
43
  ret = {}
44
+ if device is None:
45
+ device = get_device()
46
  for mlp in CONTROL_MLPS:
47
  ret[mlp] = setup_control_mlp(mlp, features, device, dtype)
48
  return ret
 
51
  def setup_control_mlp(
52
  material_parameter: str,
53
  features: int = 1024,
54
+ device: Optional[str] = None,
55
  dtype: torch.dtype = torch.float16,
56
  ):
57
+ if device is None:
58
+ device = get_device()
59
+
60
  net = control_mlp(features)
61
  net.load_state_dict(
62
  torch.load(os.path.join(file_dir, f"model_weights/{material_parameter}.pt"))
 
106
 
107
 
108
  def setup_pipeline(
109
+ device: Optional[str] = None,
110
  dtype: torch.dtype = torch.float16,
111
  ):
112
+ if device is None:
113
+ device = get_device()
114
+
115
  download_ip_adapter()
116
 
117
  cur_block = ("up", 0, 1)
 
149
  )
150
 
151
 
152
+ def get_dpt_model(device: Optional[str] = None, dtype: torch.dtype = torch.float16):
153
+ if device is None:
154
+ device = get_device()
155
+
156
  image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
157
  model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
158
  model.to(device, dtype=dtype)
 
161
 
162
 
163
  def run_dpt_depth(
164
+ image: Image.Image, model, processor, device: Optional[str] = None
165
  ) -> Image.Image:
166
  """Run DPT depth estimation on an image."""
167
+ if device is None:
168
+ device = get_device()
169
+
170
  # Prepare image
171
  inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)
172