# MIT License

# Copyright (c) Microsoft

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Copyright (c) [2025] [Microsoft]
# SPDX-License-Identifier: MIT

from typing import *
import torch
import torch.nn as nn
from .. import models


class Pipeline:
    """
    A base class for pipelines.

    Stores a mapping from component name to ``nn.Module`` in ``self.models``
    and provides a config-driven loader (``from_pretrained``) plus helpers to
    query and change the device the components live on.
    """
    def __init__(
        self,
        models: Optional[dict[str, nn.Module]] = None,
    ):
        """
        Args:
            models: Mapping from component name to module. Each module is put
                into eval mode via ``model.eval()``. If ``None``, the
                constructor returns immediately and ``self.models`` is not
                set at all.
        """
        if models is None:
            # Deliberate early exit that leaves the instance without a
            # ``models`` attribute. NOTE(review): presumably so a caller can
            # fill in the instance state afterwards — confirm with subclasses.
            return
        self.models = models
        for model in self.models.values():
            model.eval()

    @staticmethod
    def from_pretrained(path: str) -> "Pipeline":
        """
        Load a pretrained model.

        Reads ``pipeline.json`` from ``path`` if that file exists locally;
        otherwise ``path`` is treated as a Hugging Face Hub repo id and the
        file is fetched with ``hf_hub_download``. Each entry of
        ``args['models']`` (name -> relative location) is loaded through
        ``models.from_pretrained(f"{path}/{location}")``.

        Args:
            path: Local directory or Hugging Face Hub repo id.

        Returns:
            A ``Pipeline`` holding the loaded models. The parsed ``args``
            dict from the config is kept on ``_pretrained_args``.
        """
        import os
        import json

        is_local = os.path.exists(f"{path}/pipeline.json")

        if is_local:
            config_file = f"{path}/pipeline.json"
        else:
            from huggingface_hub import hf_hub_download
            config_file = hf_hub_download(path, "pipeline.json")

        # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
        with open(config_file, 'r', encoding='utf-8') as f:
            args = json.load(f)['args']

        _models = {
            k: models.from_pretrained(f"{path}/{v}")
            for k, v in args['models'].items()
        }

        new_pipeline = Pipeline(_models)
        new_pipeline._pretrained_args = args
        return new_pipeline

    @property
    def device(self) -> torch.device:
        """
        Return the device of the pipeline's components.

        First returns the ``device`` attribute of the first model that has
        one. Failing that, returns the device of the first parameter of the
        first model that has any parameters.

        Raises:
            RuntimeError: If no model exposes a device or owns a parameter.
        """
        for model in self.models.values():
            if hasattr(model, 'device'):
                return model.device
        for model in self.models.values():
            if hasattr(model, 'parameters'):
                # A parameter-less module yields nothing; skip it instead of
                # leaking a bare StopIteration out of the property.
                first_param = next(iter(model.parameters()), None)
                if first_param is not None:
                    return first_param.device
        raise RuntimeError("No device found.")

    def to(self, device: torch.device) -> None:
        """Move every model to ``device`` by calling ``model.to(device)``."""
        for model in self.models.values():
            model.to(device)

    def cuda(self) -> None:
        """Move every model to the ``"cuda"`` device."""
        self.to(torch.device("cuda"))

    def cpu(self) -> None:
        """Move every model to the ``"cpu"`` device."""
        self.to(torch.device("cpu"))