from PIL import Image
from modules.control.util import HWC3, resize_image  # not used below; kept in case other modules import them from here
from modules import devices
from modules.shared import opts
from .marigold_pipeline import MarigoldPipeline


class MarigoldDetector:
    """Thin callable wrapper around a ``MarigoldPipeline`` depth estimator."""

    def __init__(self, model):
        """Store an already-constructed pipeline.

        Args:
            model: The ``MarigoldPipeline`` instance to run on each call.
        """
        self.model: MarigoldPipeline = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, cache_dir=None, **load_config):
        """Load a pipeline and wrap it in a detector.

        Args:
            pretrained_model_or_path: Forwarded unchanged to
                ``MarigoldPipeline.from_pretrained``.
            cache_dir: Forwarded as the ``cache_dir`` keyword.
            **load_config: Extra keyword arguments forwarded to
                ``MarigoldPipeline.from_pretrained``.

        Returns:
            A new detector instance wrapping the loaded pipeline.
        """
        model = MarigoldPipeline.from_pretrained(pretrained_model_or_path, cache_dir=cache_dir, **load_config)
        return cls(model)

    def to(self, device):
        """Move the wrapped pipeline to ``device`` and return ``self`` for chaining."""
        self.model.to(device)
        return self

    def __call__(
        self,
        input_image: Image.Image,
        denoising_steps: int = 10,
        ensemble_size: int = 10,
        processing_res: int = 768,
        match_input_res: bool = True,
        color_map: str = "Spectral",
        output_type=None,
    ):
        """Run the pipeline on ``input_image`` and return its depth output.

        Args:
            input_image: Image passed as the first positional argument to the
                pipeline.
            denoising_steps: Forwarded to the pipeline.
            ensemble_size: Forwarded to the pipeline.
            processing_res: Forwarded to the pipeline.
            match_input_res: Forwarded to the pipeline.
            color_map: Forwarded to the pipeline. The string ``'None'`` selects
                the pipeline's ``depth_np`` output instead of ``depth_colored``;
                in that case ``'Spectral'`` is still passed to the pipeline as
                a placeholder value.
            output_type: When equal to ``"pil"`` the selected output is passed
                through ``Image.fromarray``; any other value returns it as-is.

        Returns:
            ``Image.fromarray(depth_map)`` when ``output_type == "pil"``,
            otherwise the selected pipeline output unchanged.
            NOTE(review): the concrete types of ``depth_colored`` and
            ``depth_np`` are defined by ``MarigoldPipeline`` and are not
            visible from this file - confirm before relying on them.
        """
        try:
            # The pipeline is placed on the shared device/dtype for every call.
            self.model.to(device=devices.device, dtype=devices.dtype)
            res = self.model(
                input_image,
                denoising_steps=denoising_steps,
                ensemble_size=ensemble_size,
                processing_res=processing_res,
                match_input_res=match_input_res,
                color_map=color_map if color_map != 'None' else 'Spectral',
                batch_size=1,
                show_progress_bar=True,
            )
            depth_map = res.depth_colored if color_map != 'None' else res.depth_np
        finally:
            # Offload in `finally` so a failed inference does not leave the
            # model on the device when the user asked for it to be moved back.
            if opts.control_move_processor:
                self.model.to('cpu')
        if output_type == "pil":
            return Image.fromarray(depth_map)
        return depth_map