|
import asyncio |
|
import concurrent.futures |
|
import contextvars |
|
import functools |
|
import inspect |
|
import logging |
|
import os |
|
import textwrap |
|
import threading |
|
from enum import Enum |
|
from typing import Optional, Type, get_origin, get_args |
|
|
|
|
|
class TypeTracker:
    """Tracks types discovered during stub generation for automatic import generation.

    Each recorded type is stored as ``name -> (module, qualname)`` so that
    ``get_imports`` can later emit ``from <module> import <name>`` lines.
    """

    def __init__(self):
        # name -> (defining module, __qualname__) for each type needing an import.
        self.discovered_types = {}
        # Names available in every generated stub (builtins, or covered by the
        # fixed ``typing`` import line); these are never tracked.
        self.builtin_types = set(
            "Any Dict List Optional Tuple Union Set Sequence cast NamedTuple "
            "str int float bool None bytes object type dict list tuple set".split()
        )
        # Names that some other mechanism already imports; also never tracked.
        self.already_imported = set()

    def track_type(self, annotation):
        """Track a type annotation and record its module/import info.

        ``None``/``NoneType``, names listed in ``builtin_types`` or
        ``already_imported``, anything from ``typing``, the ``types.UnionType``
        and ``types.GenericAlias`` helpers, and objects from ``builtins`` or
        ``__main__`` (or with no module or name) are ignored.
        """
        if annotation is None or annotation is type(None):
            return

        type_name = getattr(annotation, "__name__", None)
        skip_names = self.builtin_types | self.already_imported
        if type_name and type_name in skip_names:
            return

        module = getattr(annotation, "__module__", None)
        qualname = getattr(annotation, "__qualname__", type_name or "")

        # typing constructs are covered by the fixed typing import line.
        if module == "typing":
            return
        # ``X | Y`` and ``list[int]`` helper types are rendered, never imported.
        if module == "types" and type_name in ("UnionType", "GenericAlias"):
            return
        if not module or module in ("builtins", "__main__") or not type_name:
            return

        self.discovered_types[type_name] = (module, qualname)

    def get_imports(self, main_module_name: str) -> list[str]:
        """Generate import statements for all discovered types.

        Types defined in ``main_module_name`` are skipped (when it is
        non-empty). One ``from X import A, B`` line is produced per module,
        with both the modules and the names in each line sorted.
        """
        names_by_module = {}
        for type_name, (module, _qualname) in self.discovered_types.items():
            if main_module_name and module == main_module_name:
                continue
            names_by_module.setdefault(module, set()).add(type_name)

        return [
            f"from {module} import {', '.join(sorted(names))}"
            for module, names in sorted(names_by_module.items())
        ]
|
|
|
|
|
class AsyncToSyncConverter:
    """
    Provides utilities to convert async classes to sync classes with proper type hints.

    Coroutine methods of a wrapped class run on a fresh event loop inside a
    shared thread pool, and the calling thread blocks until they finish. The
    class can also write a ``.pyi`` stub describing a generated sync class.
    """

    # Shared executor for every generated sync class; created lazily by
    # get_thread_pool() under _thread_pool_lock.
    _thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
    _thread_pool_lock = threading.Lock()
    _thread_pool_initialized = False

    @classmethod
    def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
        """Get or create the shared thread pool with proper thread-safe initialization.

        ``max_workers`` is only used by the call that actually creates the
        pool; later calls return the existing pool unchanged.
        """
        # Fast path: no locking once the pool exists.
        if cls._thread_pool_initialized:
            assert cls._thread_pool is not None, "Thread pool should be initialized"
            return cls._thread_pool

        with cls._thread_pool_lock:
            # Re-check under the lock: another thread may have created it meanwhile.
            if not cls._thread_pool_initialized:
                cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers, thread_name_prefix="async_to_sync_"
                )
                cls._thread_pool_initialized = True

        assert cls._thread_pool is not None
        return cls._thread_pool

    @staticmethod
    def _shutdown_loop(loop: asyncio.AbstractEventLoop) -> None:
        """Cancel leftover tasks, finalize async generators, then close *loop*."""
        try:
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            if pending:
                # Let the cancellations run to completion.
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
        except Exception:
            # Best effort: cleanup problems must not mask the coroutine's outcome.
            logging.debug("Error while shutting down worker event loop", exc_info=True)
        finally:
            loop.close()
            asyncio.set_event_loop(None)

    @classmethod
    def run_async_in_thread(cls, coro_func, *args, **kwargs):
        """
        Run an async function in a separate thread from the thread pool.
        Blocks until the async function completes.
        Properly propagates contextvars between threads and manages event loops.

        Returns the awaited result of ``coro_func(*args, **kwargs)``. Any
        exception it raises is re-raised in the calling thread.
        """
        # Snapshot the caller's context so ContextVar values are visible to the
        # coroutine running on the worker thread.
        context = contextvars.copy_context()

        def run_in_thread():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:

                async def run_with_context():
                    return await coro_func(*args, **kwargs)

                # The task created by run_until_complete inherits `context`.
                return context.run(loop.run_until_complete, run_with_context())
            finally:
                cls._shutdown_loop(loop)

        future = cls.get_thread_pool().submit(run_in_thread)
        # Future.result() returns the value or re-raises the worker's exception.
        return future.result()

    @classmethod
    def _make_sync_proxy(cls, async_obj, async_type: Type):
        """Wrap *async_obj* in a sync class generated from *async_type*.

        The generated ``__init__`` is bypassed; the proxy only receives
        ``_async_instance``.
        """
        sync_attr_class = cls.create_sync_class(async_type)
        sync_attr = object.__new__(sync_attr_class)
        sync_attr._async_instance = async_obj
        return sync_attr

    @classmethod
    def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
        """
        Creates a new class with synchronous versions of all async methods.

        Args:
            async_class: The async class to convert
            thread_pool_size: Size of thread pool to use (only applied if the
                shared pool has not been created yet)

        Returns:
            A new class with sync versions of all async methods
        """
        sync_class_name = "ComfyAPISyncStub"
        cls.get_thread_pool(thread_pool_size)

        sync_class_dict = {
            "__doc__": async_class.__doc__,
            "__module__": async_class.__module__,
            "__qualname__": sync_class_name,
            "__orig_class__": async_class,
        }

        def __init__(self, *args, **kwargs):
            self._async_instance = async_class(*args, **kwargs)

            # Gather annotations from the whole MRO, base classes first so that
            # subclasses override their parents.
            all_annotations = {}
            for base_class in reversed(inspect.getmro(async_class)):
                if hasattr(base_class, "__annotations__"):
                    all_annotations.update(base_class.__annotations__)

            for attr_name, attr_type in all_annotations.items():
                if hasattr(self._async_instance, attr_name):
                    # Existing attribute: wrap singletons, copy anything else.
                    attr = getattr(self._async_instance, attr_name)
                    from comfy_api.internal.singleton import ProxiedSingleton

                    if isinstance(attr, ProxiedSingleton):
                        try:
                            setattr(self, attr_name, cls._make_sync_proxy(attr, attr.__class__))
                        except Exception:
                            # Fall back to exposing the async object directly.
                            setattr(self, attr_name, attr)
                    else:
                        setattr(self, attr_name, attr)
                elif isinstance(attr_type, type) and hasattr(async_class, attr_type.__name__):
                    # Annotated but unset: instantiate the matching inner class.
                    inner_class = getattr(async_class, attr_type.__name__)
                    from comfy_api.internal.singleton import ProxiedSingleton

                    try:
                        if issubclass(inner_class, ProxiedSingleton):
                            async_instance = inner_class.get_instance()
                        else:
                            async_instance = inner_class()
                        setattr(self, attr_name, cls._make_sync_proxy(async_instance, inner_class))
                        # Keep the async side pointing at the same instance.
                        setattr(self._async_instance, attr_name, async_instance)
                    except Exception as e:
                        logging.warning(
                            f"Failed to create instance for {attr_name}: {e}"
                        )

            # Remaining public, non-primitive members: wrap singletons.
            for name, attr in inspect.getmembers(self._async_instance):
                if name.startswith("_") or hasattr(self, name):
                    continue
                if isinstance(attr, (str, int, float, bool, list, dict, tuple)):
                    continue
                from comfy_api.internal.singleton import ProxiedSingleton

                if isinstance(attr, ProxiedSingleton):
                    try:
                        setattr(self, name, cls._make_sync_proxy(attr, attr.__class__))
                    except Exception:
                        setattr(self, name, attr)

        sync_class_dict["__init__"] = __init__

        for name, method in inspect.getmembers(async_class, predicate=inspect.isfunction):
            if name.startswith("_"):
                continue

            # `_method_name=name` binds the current name (avoids late binding).
            if inspect.iscoroutinefunction(method):

                @functools.wraps(method)
                def sync_method(self, *args, _method_name=name, **kwargs):
                    async_method = getattr(self._async_instance, _method_name)
                    return AsyncToSyncConverter.run_async_in_thread(
                        async_method, *args, **kwargs
                    )

                sync_class_dict[name] = sync_method
            else:

                @functools.wraps(method)
                def proxy_method(self, *args, _method_name=name, **kwargs):
                    return getattr(self._async_instance, _method_name)(*args, **kwargs)

                sync_class_dict[name] = proxy_method

        def make_property(prop_name, prop_obj):
            """Build a property forwarding to the async instance."""

            def getter(self):
                value = getattr(self._async_instance, prop_name)
                if inspect.iscoroutinefunction(value):
                    # Expose coroutine functions as blocking callables.
                    def sync_fn(*args, **kwargs):
                        return AsyncToSyncConverter.run_async_in_thread(
                            value, *args, **kwargs
                        )

                    return sync_fn
                return value

            def setter(self, value):
                setattr(self._async_instance, prop_name, value)

            return property(getter, setter if prop_obj.fset else None)

        for name, prop in inspect.getmembers(async_class, lambda x: isinstance(x, property)):
            sync_class_dict[name] = make_property(name, prop)

        return type(sync_class_name, (object,), sync_class_dict)

    @classmethod
    def _format_type_annotation(
        cls, annotation, type_tracker: Optional["TypeTracker"] = None
    ) -> str:
        """Convert a type annotation to its string representation for stub files.

        Handles plain classes, ``typing`` generics, ``X | Y`` unions,
        ``Callable`` parameter lists, ``Literal`` values, ``...`` and forward
        references. Types that need importing are reported to *type_tracker*.
        """
        from typing import ForwardRef, Literal

        if annotation is inspect.Parameter.empty or annotation is inspect.Signature.empty:
            return "Any"
        if annotation is type(None):
            return "None"
        # `...` appears in Callable[..., R] and Tuple[X, ...].
        if annotation is Ellipsis:
            return "..."
        # Unresolved string references, e.g. List["Foo"] -> Foo.
        if isinstance(annotation, ForwardRef):
            return annotation.__forward_arg__
        # Callable parameter lists arrive as plain Python lists.
        if isinstance(annotation, list):
            inner = ", ".join(
                cls._format_type_annotation(item, type_tracker) for item in annotation
            )
            return f"[{inner}]"

        if type_tracker:
            type_tracker.track_type(annotation)

        try:
            origin = get_origin(annotation)
            args = get_args(annotation)

            if origin is not None:
                if type_tracker:
                    type_tracker.track_type(origin)

                origin_name = getattr(origin, "__name__", str(origin))
                if "." in origin_name:
                    origin_name = origin_name.split(".")[-1]
                # PEP 604 unions (int | str) report types.UnionType as origin.
                if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
                    origin_name = "Union"

                if not args:
                    return origin_name

                if origin is Literal:
                    # Literal arguments are values, not types.
                    values = []
                    for value in args:
                        if isinstance(value, Enum):
                            if type_tracker:
                                type_tracker.track_type(type(value))
                            values.append(f"{type(value).__name__}.{value.name}")
                        else:
                            values.append(repr(value))
                    return f"{origin_name}[{', '.join(values)}]"

                formatted_args = []
                for arg in args:
                    if type_tracker:
                        type_tracker.track_type(arg)
                    formatted_args.append(cls._format_type_annotation(arg, type_tracker))
                return f"{origin_name}[{', '.join(formatted_args)}]"
        except (AttributeError, TypeError):
            # Exotic annotation objects: fall through to attribute-based handling.
            pass

        if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
            origin = annotation.__origin__
            origin_name = (
                origin.__name__
                if hasattr(origin, "__name__")
                else str(origin).split("'")[1]
            )
            args = [cls._format_type_annotation(arg, type_tracker) for arg in annotation.__args__]
            return f"{origin_name}[{', '.join(args)}]"

        if hasattr(annotation, "__name__"):
            return annotation.__name__

        if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
            return annotation.__qualname__

        type_str = str(annotation)
        if type_str.startswith("<class '") and type_str.endswith("'>"):
            type_str = type_str[8:-2]
        for prefix in ["typing.", "builtins.", "types."]:
            if type_str.startswith(prefix):
                type_str = type_str[len(prefix) :]

        if type_str in ("_empty", "inspect._empty"):
            return "None"
        if type_str == "NoneType":
            return "None"
        return type_str

    @classmethod
    def _extract_coroutine_return_type(cls, annotation):
        """Extract the actual return type from a Coroutine annotation.

        Only ``Coroutine[Y, S, R]`` (typing or collections.abc) is unwrapped to
        ``R``. Any other annotation, including multi-argument generics such as
        ``Tuple[int, str, float]`` or ``Callable[[A, B], R]``, is returned as is.
        """
        import collections.abc

        args = get_args(annotation)
        if get_origin(annotation) is collections.abc.Coroutine and len(args) == 3:
            return args[2]
        return annotation

    @classmethod
    def _format_parameter_default(cls, default_value) -> str:
        """Format a parameter's default value for stub files.

        Returns ``""`` when there is no default. Values that round-trip as
        Python literals (None, bools, numbers, str, bytes, and built-in
        containers of those) are rendered with ``repr``, so strings stay
        quoted. Enum members become ``EnumName.MEMBER``. Everything else
        becomes ``...``, the stub convention for a non-literal default.
        """
        import ast

        if default_value is inspect.Parameter.empty:
            return ""
        # Checked first: IntEnum/StrEnum members are also ints/strs.
        if isinstance(default_value, Enum):
            return f" = {type(default_value).__name__}.{default_value.name}"

        literal_types = (bool, int, float, complex, str, bytes, list, tuple, dict, set, frozenset)
        if default_value is None or isinstance(default_value, literal_types):
            text = repr(default_value)
            try:
                # Rejects e.g. inf/nan or containers holding arbitrary objects.
                ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return " = ..."
            return f" = {text}"
        return " = ..."

    @classmethod
    def _format_method_parameters(
        cls,
        sig: inspect.Signature,
        skip_self: bool = True,
        type_hints: Optional[dict] = None,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> str:
        """Format method parameters for stub files.

        Parameter kinds are preserved:
        - ``/`` follows positional-only parameters;
        - ``*args`` and ``**kwargs`` keep their prefixes;
        - a bare ``*`` precedes keyword-only parameters when there is no ``*args``.

        With *skip_self*, a leading ``self`` is emitted without annotation.
        Entries in *type_hints* take precedence over the signature's annotations.
        """
        if type_hints is None:
            type_hints = {}
        Param = inspect.Parameter

        parameters = list(sig.parameters.values())
        has_var_positional = any(p.kind is Param.VAR_POSITIONAL for p in parameters)
        params = []
        bare_star_added = False

        for i, param in enumerate(parameters):
            kind = param.kind
            if kind is Param.KEYWORD_ONLY and not has_var_positional and not bare_star_added:
                params.append("*")
                bare_star_added = True

            if i == 0 and param.name == "self" and skip_self:
                params.append("self")
            else:
                annotation = type_hints.get(param.name, param.annotation)
                type_str = cls._format_type_annotation(annotation, type_tracker)
                if kind is Param.VAR_POSITIONAL:
                    prefix = "*"
                elif kind is Param.VAR_KEYWORD:
                    prefix = "**"
                else:
                    prefix = ""
                default_str = cls._format_parameter_default(param.default)
                params.append(f"{prefix}{param.name}: {type_str}{default_str}")

            # Close the positional-only section.
            is_last = i + 1 == len(parameters)
            if kind is Param.POSITIONAL_ONLY and (
                is_last or parameters[i + 1].kind is not Param.POSITIONAL_ONLY
            ):
                params.append("/")

        return ", ".join(params)

    @classmethod
    def _generate_method_signature(
        cls,
        method_name: str,
        method,
        is_async: bool = False,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> str:
        """Generate a complete method signature for stub files.

        Resolved hints from ``typing.get_type_hints`` take precedence over raw
        annotations. A missing return annotation is rendered as ``None``.
        """
        from typing import get_type_hints

        sig = inspect.signature(method)
        try:
            type_hints = get_type_hints(method)
        except Exception:
            # Unresolvable forward references etc.: use raw annotations.
            type_hints = {}

        return_annotation = type_hints.get("return", sig.return_annotation)
        if is_async and inspect.iscoroutinefunction(method):
            return_annotation = cls._extract_coroutine_return_type(return_annotation)

        params_str = cls._format_method_parameters(
            sig, type_hints=type_hints, type_tracker=type_tracker
        )

        if return_annotation is inspect.Signature.empty:
            return_type = "None"
        else:
            return_type = cls._format_type_annotation(return_annotation, type_tracker)
        return f"def {method_name}({params_str}) -> {return_type}: ..."

    @classmethod
    def _generate_imports(cls, async_class: Type, type_tracker: "TypeTracker") -> list[str]:
        """Generate import statements for the stub file.

        Emits, in order:
        - the fixed ``typing`` line;
        - an import from the class's own module covering the class, its
          NamedTuple/Enum classes, and any module-local types referenced by
          the generated signatures;
        - the tracker's imports for other modules.
        """
        imports = [
            "from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
        ]

        main_module_name = async_class.__module__
        if main_module_name != "builtins":
            module = inspect.getmodule(async_class)
            additional_types = []

            if module:
                module_all = getattr(module, "__all__", None)
                discovered = type_tracker.discovered_types

                for name, obj in sorted(inspect.getmembers(module)):
                    if not isinstance(obj, type):
                        continue
                    # Respect __all__ unless a signature references the type.
                    if module_all is not None and name not in module_all and name not in discovered:
                        continue

                    is_named_tuple = issubclass(obj, tuple) and hasattr(obj, "_fields")
                    is_enum = issubclass(obj, Enum) and name != "Enum"
                    # TypeTracker.get_imports skips the main module, so types
                    # defined here and used in signatures are imported here.
                    is_local_reference = (
                        name in discovered
                        and discovered[name][0] == main_module_name
                        and getattr(obj, "__module__", None) == main_module_name
                        and obj is not async_class
                    )
                    if is_named_tuple or is_enum or is_local_reference:
                        additional_types.append(name)
                        type_tracker.already_imported.add(name)

            names = ", ".join([async_class.__name__] + additional_types)
            imports.append(f"from {main_module_name} import {names}")

        imports.extend(type_tracker.get_imports(main_module_name=main_module_name))

        # Make sure the top-level package is importable by name.
        module_obj = inspect.getmodule(async_class)
        if hasattr(module_obj, "__name__"):
            module_name = module_obj.__name__
            if "." in module_name:
                base_module = module_name.split(".")[0]
                if not any(imp.startswith(f"from {base_module}") for imp in imports):
                    imports.append(f"import {base_module}")

        return imports

    @classmethod
    def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
        """Extract class attributes that are classes themselves.

        Returns ``(name, class)`` pairs, in name order, for:
        - public members whose value is a class;
        - members annotated directly on *async_class* with a class annotation.
        """
        own_annotations = getattr(async_class, "__annotations__", {})
        class_attributes = []
        for name, attr in sorted(inspect.getmembers(async_class)):
            if isinstance(attr, type) and not name.startswith("_"):
                class_attributes.append((name, attr))
            elif name in own_annotations and isinstance(own_annotations[name], type):
                class_attributes.append((name, own_annotations[name]))
        return class_attributes

    @classmethod
    def _generate_init_stub(
        cls, target_class: Type, indent: str, type_tracker: Optional["TypeTracker"] = None
    ) -> list[str]:
        """Return stub lines (optional docstring + ``def __init__``) for *target_class*."""
        from typing import get_type_hints

        try:
            init_method = target_class.__init__
            init_sig = inspect.signature(init_method)
            try:
                init_hints = get_type_hints(init_method)
            except Exception:
                init_hints = {}

            params_str = cls._format_method_parameters(
                init_sig, type_hints=init_hints, type_tracker=type_tracker
            )
            lines = []
            # object.__init__'s generic docstring is just noise in a stub.
            if init_method is not object.__init__ and init_method.__doc__:
                lines.extend(cls._format_docstring_for_stub(init_method.__doc__, indent))
            lines.append(f"{indent}def __init__({params_str}) -> None: ...")
            return lines
        except (ValueError, TypeError):
            return [f"{indent}def __init__(self, *args, **kwargs) -> None: ..."]

    @classmethod
    def _generate_inner_class_stub(
        cls,
        name: str,
        attr: Type,
        indent: str = "    ",
        type_tracker: Optional["TypeTracker"] = None,
    ) -> list[str]:
        """Generate stub for an inner class (emitted as ``<name>Sync``)."""
        member_indent = f"{indent}    "
        stub_lines = [f"{indent}class {name}Sync:"]

        if getattr(attr, "__doc__", None):
            stub_lines.extend(cls._format_docstring_for_stub(attr.__doc__, member_indent))

        stub_lines.extend(cls._generate_init_stub(attr, member_indent, type_tracker))

        has_methods = False
        for method_name, method in sorted(inspect.getmembers(attr, predicate=inspect.isfunction)):
            if method_name.startswith("_"):
                continue
            has_methods = True
            try:
                if method.__doc__:
                    stub_lines.extend(
                        cls._format_docstring_for_stub(method.__doc__, member_indent)
                    )
                method_sig = cls._generate_method_signature(
                    method_name, method, is_async=True, type_tracker=type_tracker
                )
                stub_lines.append(f"{member_indent}{method_sig}")
            except (ValueError, TypeError):
                stub_lines.append(f"{member_indent}def {method_name}(self, *args, **kwargs): ...")

        if not has_methods:
            stub_lines.append(f"{member_indent}pass")
        return stub_lines

    @classmethod
    def _format_docstring_for_stub(cls, docstring: str, indent: str = "    ") -> list[str]:
        """Format a docstring for inclusion in a stub file with proper indentation.

        Returns the opening triple quote, the dedented docstring lines, and the
        closing triple quote, each prefixed by *indent*. Blank lines stay empty.
        Backslashes and embedded triple quotes are escaped so the emitted
        literal parses back to the same text.
        """
        if not docstring:
            return []

        text = textwrap.dedent(docstring).strip()
        # Escape backslashes first so the quote escapes are not doubled.
        text = text.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')

        result = [f'{indent}"""']
        for line in text.split("\n"):
            result.append(f"{indent}{line}" if line.strip() else "")
        result.append(f'{indent}"""')
        return result

    @classmethod
    def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
        """Add ``-> None`` to ``def`` lines that lack a return annotation.

        Only the trailing ``: ...`` is rewritten; import lines pass through.
        """
        suffix = ": ..."
        processed = []
        for line in stub_content:
            if line.startswith(("from ", "import ")):
                processed.append(line)
                continue
            stripped = line.strip()
            if stripped.startswith("def ") and stripped.endswith(suffix) and ") -> " not in line:
                line = line[: -len(suffix)] + " -> None" + suffix
            processed.append(line)
        return processed

    @classmethod
    def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
        """
        Generate a .pyi stub file for the sync class to help IDEs with type checking.

        The stub is written as UTF-8 to
        ``<module dir>/generated/<sync class name>.pyi``. Nothing is written
        for classes from ``__main__`` or from modules without a file. Errors
        are logged, never raised.
        """
        try:
            if async_class.__module__ == "__main__":
                return
            module = inspect.getmodule(async_class)
            if not module:
                return
            module_path = getattr(module, "__file__", None)
            if not module_path:
                return

            stub_dir = os.path.join(os.path.dirname(module_path), "generated")
            os.makedirs(stub_dir, exist_ok=True)
            sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")

            type_tracker = TypeTracker()

            # Placeholder replaced by the import block once all types are tracked.
            imports_placeholder_index = 0
            stub_content = [""]

            stub_content.append(f"class {sync_class.__name__}:")
            if async_class.__doc__:
                stub_content.extend(cls._format_docstring_for_stub(async_class.__doc__, "    "))

            stub_content.extend(cls._generate_init_stub(async_class, "    ", type_tracker))
            stub_content.append("")

            class_attributes = cls._get_class_attributes(async_class)
            for name, attr in class_attributes:
                stub_content.extend(
                    cls._generate_inner_class_stub(name, attr, type_tracker=type_tracker)
                )
                stub_content.append("")

            for name, method in sorted(
                inspect.getmembers(async_class, predicate=inspect.isfunction)
            ):
                if name.startswith("_"):
                    continue
                try:
                    method_sig = cls._generate_method_signature(
                        name, method, is_async=True, type_tracker=type_tracker
                    )
                    if method.__doc__:
                        stub_content.extend(
                            cls._format_docstring_for_stub(method.__doc__, "    ")
                        )
                    stub_content.append(f"    {method_sig}")
                except (ValueError, TypeError):
                    stub_content.append(f"    def {name}(self, *args, **kwargs): ...")
                stub_content.append("")

            for name, prop in sorted(
                inspect.getmembers(async_class, lambda x: isinstance(x, property))
            ):
                stub_content.append("    @property")
                stub_content.append(f"    def {name}(self) -> Any: ...")
                if prop.fset:
                    stub_content.append(f"    @{name}.setter")
                    stub_content.append(f"    def {name}(self, value: Any) -> None: ...")
                stub_content.append("")

            # Map each inner class to the annotated attribute that holds it.
            all_annotations = {}
            for base_class in reversed(inspect.getmro(async_class)):
                if hasattr(base_class, "__annotations__"):
                    all_annotations.update(base_class.__annotations__)

            attribute_mappings = {}
            for attr_name, attr_type in sorted(all_annotations.items()):
                for class_name, class_type in class_attributes:
                    if (
                        attr_type == class_type
                        or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
                        or (isinstance(attr_type, str) and attr_type == class_name)
                    ):
                        attribute_mappings[class_name] = attr_name

            for class_name, _class_type in class_attributes:
                attr_name = attribute_mappings.get(class_name, class_name)
                stub_content.append(f"    {attr_name}: {class_name}Sync")
            stub_content.append("")

            imports = cls._generate_imports(async_class, type_tracker)
            seen = set()
            unique_imports = []
            for imp in imports:
                if imp in seen:
                    logging.warning(f"Duplicate import detected: {imp}")
                    continue
                seen.add(imp)
                unique_imports.append(imp)

            stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
                unique_imports
            )
            stub_content = cls._post_process_stub_content(stub_content)

            # Explicit UTF-8: docstrings may contain non-ASCII text.
            with open(sync_stub_path, "w", encoding="utf-8") as f:
                f.write("\n".join(stub_content) + "\n")

            logging.info(f"Generated stub file: {sync_stub_path}")

        except Exception as e:
            import traceback

            logging.error(
                f"Error generating stub file for {sync_class.__name__}: {str(e)}"
            )
            logging.error(traceback.format_exc())
|
|
|
|
|
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
    """
    Build a synchronous wrapper class for ``async_class``.

    Convenience entry point that delegates to
    ``AsyncToSyncConverter.create_sync_class``.

    Args:
        async_class: The async class to convert
        thread_pool_size: Size of thread pool to use

    Returns:
        A new class with sync versions of all async methods
    """
    converter = AsyncToSyncConverter
    return converter.create_sync_class(async_class, thread_pool_size=thread_pool_size)
|
|