Spaces:
Running
Running
| import importlib | |
| from functools import partial | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, List | |
# Module categories known to the registry. Each category gets its own
# ``name -> module`` mapping in ``Registry.mapping``, and
# ``Registry.__getattr__`` exposes ``register_<category>`` /
# ``get_<category>`` shortcuts for exactly these names.
CATEGORIES = [
    "prompt", "llm", "node", "worker",
    "tool", "encoder", "connector", "component",
]
| class Registry: | |
| """Class for module registration and retrieval.""" | |
| def __init__(self): | |
| # Initializes a mapping for different categories of modules. | |
| self.mapping = {key: {} for key in CATEGORIES} | |
| def __getattr__(self, name: str) -> Callable: | |
| if name.startswith(("register_", "get_")): | |
| prefix, category = name.split("_", 1) | |
| if category in CATEGORIES: | |
| if prefix == "register": | |
| return partial(self.register, category) | |
| elif prefix == "get": | |
| return partial(self.get, category) | |
| raise AttributeError( | |
| f"'{self.__class__.__name__}' object has no attribute '{name}'" | |
| ) | |
| def _register(self, category: str, name: str = None): | |
| """ | |
| Registers a module under a specific category. | |
| :param category: The category to register the module under. | |
| :param name: The name to register the module as. | |
| """ | |
| def wrap(module): | |
| nonlocal name | |
| name = name or module.__name__ | |
| if name in self.mapping[category]: | |
| raise ValueError( | |
| f"Module {name} [{self.mapping[category].get(name)}] already registered in category {category}. Please use a different class name." | |
| ) | |
| self.mapping.setdefault(category, {})[name] = module | |
| return module | |
| return wrap | |
| def _get(self, category: str, name: str): | |
| """ | |
| Retrieves a module from a specified category. | |
| :param category: The category to search in. | |
| :param name: The name of the module to retrieve. | |
| :raises KeyError: If the module is not found. | |
| """ | |
| try: | |
| return self.mapping[category][name] | |
| except KeyError: | |
| raise KeyError(f"Module {name} not found in category {category}") | |
| def register(self, category: str, name: str = None): | |
| """ | |
| Registers a module under a general category. | |
| :param category: The category to register the module under. | |
| :param name: Optional name to register the module as. | |
| """ | |
| return self._register(category, name) | |
| def get(self, category: str, name: str): | |
| """ | |
| Retrieves a module from a general category. | |
| :param category: The category to search in. | |
| :param name: The name of the module to retrieve. | |
| """ | |
| return self._get(category, name) | |
| def import_module(self, project_path: List[str] | str = None): | |
| """Import modules from default paths and optional project paths. | |
| Args: | |
| project_path: Optional path or list of paths to import modules from | |
| """ | |
| # Handle default paths | |
| root_path = Path(__file__).parents[1] | |
| default_path = [ | |
| root_path.joinpath("models"), | |
| root_path.joinpath("tool_system"), | |
| root_path.joinpath("services"), | |
| root_path.joinpath("memories"), | |
| root_path.joinpath("advanced_components"), | |
| root_path.joinpath("clients"), | |
| ] | |
| for path in default_path: | |
| for module in path.rglob("*.[ps][yo]"): | |
| if module.name == "workflow.py": | |
| continue | |
| module = str(module) | |
| if "__init__" in module or "base.py" in module or "entry.py" in module: | |
| continue | |
| module = "omagent_core" + module.rsplit("omagent_core", 1)[1].rsplit( | |
| ".", 1 | |
| )[0].replace("/", ".") | |
| importlib.import_module(module) | |
| # Handle project paths | |
| if project_path: | |
| if isinstance(project_path, (str, Path)): | |
| project_path = [project_path] | |
| for path in project_path: | |
| path = Path(path).absolute() | |
| project_root = path.parent | |
| for module in path.rglob("*.[ps][yo]"): | |
| module = str(module) | |
| if "__init__" in module: | |
| continue | |
| module = ( | |
| module.replace(str(project_root) + "/", "") | |
| .rsplit(".", 1)[0] | |
| .replace("/", ".") | |
| ) | |
| importlib.import_module(module) | |
# Shared module-level registry instance; other modules register into and
# look up from this object.
registry = Registry()

if __name__ == "__main__":
    # Minimal self-check: the class must be registered before it can be
    # looked up, otherwise get_node raises KeyError.
    @registry.register_node()
    class TestNode:
        name: str = "TestNode"

    print(registry.get_node("TestNode"))