|
from __future__ import annotations |
|
import torch |
|
|
|
import comfy.utils |
|
from comfy.patcher_extension import WrappersMP |
|
from typing import TYPE_CHECKING, Callable, Optional |
|
if TYPE_CHECKING: |
|
from comfy.model_patcher import ModelPatcher |
|
from comfy.patcher_extension import WrapperExecutor |
|
|
|
|
|
# Key under which the torch.compile wrapper is registered on (and removed from)
# the model patcher's APPLY_MODEL wrappers.
COMPILE_KEY = "torch.compile"
# Key in model.model_options under which the kwargs given to torch.compile are stored.
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
|
|
|
|
|
def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
    '''
    Create a wrapper that swaps compiled modules in for the duration of one call.

    compiled_module_dict maps an attribute path on executor.class_obj (resolved
    through comfy.utils.get_attr / comfy.utils.set_attr) to the compiled object
    that should temporarily replace that attribute.

    The returned wrapper installs every replacement, calls the executor with the
    arguments it was given, returns the executor's result, and always restores the
    original attributes afterwards, including when patching or the call raises.
    '''
    def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
        # Initialised outside the try so the finally block can rely on it.
        orig_modules = {}
        try:
            for key, value in compiled_module_dict.items():
                # Record the original before replacing it, so a failure part-way
                # through patching still restores what was already swapped.
                orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                comfy.utils.set_attr(executor.class_obj, key, value)
            return executor(*args, **kwargs)
        finally:
            # Undo in reverse (LIFO) order. When keys overlap (e.g. "a" and
            # "a.b"), a nested path has to be restored while its parent is in
            # the same state it was in when the nested original was captured;
            # restoring in patch order would write through the wrong parent.
            for key, value in reversed(list(orig_modules.items())):
                comfy.utils.set_attr(executor.class_obj, key, value)
    return apply_torch_compile_wrapper
|
|
|
|
|
def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
                              keys: Optional[list[str]]=None, *args, **kwargs):
    '''
    Wrap selected model objects with torch.compile and register a wrapper that swaps them in.

    When keys is None or empty, it defaults to ["diffusion_model"], compiling the whole diffusion_model.
    When a list of keys is provided, torch.compile is applied only to the objects
    returned by model.get_model_object(key) for those keys.

    backend, options, mode, fullgraph and dynamic are forwarded unchanged to
    torch.compile and are also stored in model.model_options[TORCH_COMPILE_KWARGS].
    Extra positional and keyword arguments are accepted but not used by this function.

    Any wrapper previously registered under COMPILE_KEY for APPLY_MODEL is removed
    before the new one is added.
    '''
    # Remove an earlier registration under the same key before adding the new one.
    model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)

    # None (the default) and an empty list both mean "the whole diffusion model".
    if not keys:
        keys = ["diffusion_model"]

    compile_kwargs = {
        "backend": backend,
        "options": options,
        "mode": mode,
        "fullgraph": fullgraph,
        "dynamic": dynamic,
    }

    # One torch.compile-wrapped object per requested key.
    compiled_modules = {
        key: torch.compile(model=model.get_model_object(key), **compile_kwargs)
        for key in keys
    }

    wrapper_func = apply_torch_compile_factory(
        compiled_module_dict=compiled_modules,
    )
    model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)

    # Record the settings used on the patcher's model_options.
    model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
|
|