|
from __future__ import annotations |
|
|
|
from typing import TypedDict, Dict, Optional, Tuple |
|
from typing_extensions import override |
|
from PIL import Image |
|
from enum import Enum |
|
from abc import ABC |
|
from tqdm import tqdm |
|
from typing import TYPE_CHECKING |
|
if TYPE_CHECKING: |
|
from comfy_execution.graph import DynamicPrompt |
|
from protocol import BinaryEventTypes |
|
from comfy_api import feature_flags |
|
|
|
# Preview payload passed to progress handlers: a (str, PIL image, optional int) tuple.
# NOTE(review): presumably (image format, image, max preview size) - confirm against producers.
PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]
|
|
|
class NodeState(Enum):
    """Lifecycle states a node can be in; each member's value is its lowercase name."""

    # Entry exists but the node has not started running.
    Pending = "pending"
    # Node has started or has reported progress.
    Running = "running"
    # Node completed.
    Finished = "finished"
    # Declared here; not assigned anywhere in this module.
    Error = "error"
|
|
|
|
|
class NodeProgressState(TypedDict):
    """
    Dictionary shape describing how far along a single node is.
    """

    # Lifecycle state of the node.
    state: NodeState
    # Progress reached so far.
    value: float
    # Progress value at which the node counts as complete.
    max: float
|
|
|
|
|
class ProgressHandler(ABC):
    """
    Base class for objects that consume node progress notifications.

    Every hook defined here does nothing; subclasses override the ones they
    care about. The ``enabled`` flag is toggled by ``enable``/``disable``.
    """

    def __init__(self, name: str):
        # Handlers start out enabled.
        self.enabled = True
        self.name = name

    def set_registry(self, registry: "ProgressRegistry"):
        """Receive the registry this handler is being attached to (ignored by default)."""

    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Hook invoked when a node begins processing."""

    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """Hook invoked when a node reports new progress, optionally with a preview image."""

    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Hook invoked when a node is done processing."""

    def reset(self):
        """Hook invoked when the progress registry resets its handlers."""

    def enable(self):
        """Turn this handler on."""
        self.enabled = True

    def disable(self):
        """Turn this handler off."""
        self.enabled = False
|
|
|
|
|
class CLIProgressHandler(ProgressHandler):
    """
    Handler that displays progress using tqdm progress bars in the CLI.

    One bar is kept per node id in ``progress_bars``. Each new bar is given
    the lowest row offset that no other live bar of this handler is using, so
    a bar created after an earlier one has finished does not reuse the row of
    a bar that is still running.
    """

    def __init__(self):
        super().__init__("cli")
        self.progress_bars: Dict[str, tqdm] = {}
        # Row offset passed to tqdm for each bar, keyed by node id.
        self._bar_positions: Dict[str, int] = {}

    def _free_position(self) -> int:
        """Return the smallest row offset not held by a bar still in ``progress_bars``."""
        taken = {
            position
            for bar_id, position in self._bar_positions.items()
            if bar_id in self.progress_bars
        }
        slot = 0
        while slot in taken:
            slot += 1
        return slot

    def _open_bar(self, node_id: str, total: float) -> tqdm:
        """Create, register and return the bar for ``node_id`` with the given total."""
        position = self._free_position()
        bar = tqdm(
            total=total,
            desc=f"Node {node_id}",
            unit="steps",
            leave=True,
            position=position,
        )
        self.progress_bars[node_id] = bar
        self._bar_positions[node_id] = position
        return bar

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Open a bar for ``node_id`` sized by ``state["max"]`` unless one already exists."""
        if node_id not in self.progress_bars:
            self._open_bar(node_id, state["max"])

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """
        Move the bar for ``node_id`` forward to ``value``.

        A bar is created on the fly when none exists yet. For an existing bar
        the total is replaced when ``max_value`` differs, and the bar is only
        ever advanced; a ``value`` at or below the current count is ignored.
        ``state``, ``prompt_id`` and ``image`` are not used by this handler.
        """
        bar = self.progress_bars.get(node_id)
        if bar is None:
            # No start notification was seen for this node: open the bar now.
            self._open_bar(node_id, max_value).update(value)
            return

        if max_value != bar.total:
            bar.total = max_value

        # tqdm's update() takes an increment, so convert the absolute value.
        delta = value - bar.n
        if delta > 0:
            bar.update(delta)

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Fill the bar for ``node_id`` up to ``state["max"]``, close it and forget it."""
        bar = self.progress_bars.get(node_id)
        if bar is None:
            return

        remaining = state["max"] - bar.n
        if remaining > 0:
            bar.update(remaining)
        bar.close()
        del self.progress_bars[node_id]
        # Release the row so a later bar may use it.
        self._bar_positions.pop(node_id, None)

    @override
    def reset(self):
        """Close every open bar and drop all bookkeeping."""
        for bar in self.progress_bars.values():
            bar.close()
        self.progress_bars.clear()
        self._bar_positions.clear()
|
|
|
|
|
class WebUIProgressHandler(ProgressHandler):
    """
    Handler that sends progress updates to the WebUI via WebSockets.

    Messages go out through ``server_instance.send_sync``. Node metadata
    (display/parent/real ids) is looked up on the attached registry's
    ``dynprompt``. Until ``set_registry`` has been called, every hook is a
    no-op.
    """

    def __init__(self, server_instance):
        super().__init__("webui")
        self.server_instance = server_instance
        # Set explicitly so hooks that fire before set_registry() see None
        # instead of raising AttributeError.
        self.registry: Optional[ProgressRegistry] = None

    @override
    def set_registry(self, registry: "ProgressRegistry"):
        """Remember the registry whose node table and dynprompt are reported."""
        self.registry = registry

    def _node_identity(self, node_id: str) -> dict:
        """Return the display/parent/real id entries for ``node_id`` from the registry's dynprompt."""
        dynprompt = self.registry.dynprompt
        return {
            "display_node_id": dynprompt.get_display_node_id(node_id),
            "parent_node_id": dynprompt.get_parent_node_id(node_id),
            "real_node_id": dynprompt.get_real_node_id(node_id),
        }

    def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
        """
        Send the current progress state to the client.

        Emits one ``"progress_state"`` event containing every node in
        ``nodes`` whose state is not ``NodeState.Pending``. Does nothing when
        there is no server instance or no registry.
        """
        if self.server_instance is None or self.registry is None:
            return

        active_nodes = {
            node_id: {
                "value": state["value"],
                "max": state["max"],
                "state": state["state"].value,
                "node_id": node_id,
                "prompt_id": prompt_id,
                **self._node_identity(node_id),
            }
            for node_id, state in nodes.items()
            if state["state"] != NodeState.Pending
        }

        self.server_instance.send_sync(
            "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
        )

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Report the registry's full node table when a node starts."""
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """
        Report the registry's full node table and, if given, a preview image.

        The preview is sent as ``BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA``
        to ``server_instance.client_id``, and only when
        ``feature_flags.supports_feature`` reports
        ``"supports_preview_metadata"`` for that client. It is skipped when
        the server instance or the registry is missing.
        """
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)

        if not image:
            return
        # The metadata needs the registry and the send needs a server; without
        # either there is nothing to do (mirrors _send_progress_state).
        if self.server_instance is None or self.registry is None:
            return

        if not feature_flags.supports_feature(
            self.server_instance.sockets_metadata,
            self.server_instance.client_id,
            "supports_preview_metadata",
        ):
            return

        metadata = {
            "node_id": node_id,
            "prompt_id": prompt_id,
            **self._node_identity(node_id),
        }
        self.server_instance.send_sync(
            BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
            (image, metadata),
            self.server_instance.client_id,
        )

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Report the registry's full node table when a node finishes."""
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)
|
|
|
class ProgressRegistry:
    """
    Holds per-node progress entries for one prompt and forwards every change
    to the registered handlers that are currently enabled.
    """

    def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
        self.prompt_id = prompt_id
        self.dynprompt = dynprompt
        # Progress entry per node id.
        self.nodes: Dict[str, NodeProgressState] = {}
        # Handlers keyed by their ``name``.
        self.handlers: Dict[str, ProgressHandler] = {}

    def _enabled_handlers(self):
        """Lazily yield handlers whose ``enabled`` flag is set at the time they are reached."""
        return (candidate for candidate in self.handlers.values() if candidate.enabled)

    def register_handler(self, handler: ProgressHandler) -> None:
        """Store ``handler`` under its name, replacing any handler already using that name."""
        self.handlers[handler.name] = handler

    def unregister_handler(self, handler_name: str) -> None:
        """Reset and then remove the named handler; unknown names are ignored."""
        if handler_name not in self.handlers:
            return
        # Reset first so the handler can clean up before it is dropped.
        self.handlers[handler_name].reset()
        del self.handlers[handler_name]

    def enable_handler(self, handler_name: str) -> None:
        """Enable the named handler; unknown names are ignored."""
        if handler_name not in self.handlers:
            return
        self.handlers[handler_name].enable()

    def disable_handler(self, handler_name: str) -> None:
        """Disable the named handler; unknown names are ignored."""
        if handler_name not in self.handlers:
            return
        self.handlers[handler_name].disable()

    def ensure_entry(self, node_id: str) -> NodeProgressState:
        """Return the entry for ``node_id``, creating a Pending 0/1 entry when absent."""
        if node_id in self.nodes:
            return self.nodes[node_id]
        self.nodes[node_id] = NodeProgressState(
            state=NodeState.Pending, value=0, max=1
        )
        return self.nodes[node_id]

    def start_progress(self, node_id: str) -> None:
        """Mark ``node_id`` as Running at 0.0 of 1.0 and notify enabled handlers."""
        entry = self.ensure_entry(node_id)
        entry.update(state=NodeState.Running, value=0.0, max=1.0)

        for listener in self._enabled_handlers():
            listener.start_handler(node_id, entry, self.prompt_id)

    def update_progress(
        self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
    ) -> None:
        """Record ``value``/``max_value`` for ``node_id`` as Running and notify enabled handlers."""
        entry = self.ensure_entry(node_id)
        entry.update(state=NodeState.Running, value=value, max=max_value)

        for listener in self._enabled_handlers():
            listener.update_handler(
                node_id, value, max_value, entry, self.prompt_id, image
            )

    def finish_progress(self, node_id: str) -> None:
        """Mark ``node_id`` as Finished with its value set to its max and notify enabled handlers."""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Finished
        entry["value"] = entry["max"]

        for listener in self._enabled_handlers():
            listener.finish_handler(node_id, entry, self.prompt_id)

    def reset_handlers(self) -> None:
        """Call ``reset`` on every registered handler, enabled or not."""
        for listener in self.handlers.values():
            listener.reset()
|
|
|
|
|
# Module-wide registry instance. Starts as None; reset_progress_state() replaces
# it and get_progress_state() creates one on first use.
global_progress_registry: ProgressRegistry | None = None
|
|
|
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
    """
    Replace the module-wide registry with a fresh one for ``prompt_id``.

    Handlers of the previous registry, if there was one, are reset first.
    The new registry starts with no handlers registered.
    """
    global global_progress_registry

    previous = global_progress_registry
    if previous is not None:
        # Let the outgoing registry's handlers clean up before it is replaced.
        previous.reset_handlers()

    global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
|
|
|
|
def add_progress_handler(handler: ProgressHandler) -> None:
    """Attach ``handler`` to the module-wide registry, giving it the registry first."""
    active = get_progress_state()
    # The handler receives the registry before it is registered.
    handler.set_registry(active)
    active.register_handler(handler)
|
|
|
|
|
def get_progress_state() -> ProgressRegistry:
    """
    Return the module-wide registry, creating it on first use.

    The lazily created registry has an empty prompt id and a
    ``DynamicPrompt`` built from an empty dict.
    """
    global global_progress_registry

    if global_progress_registry is not None:
        return global_progress_registry

    # Imported here rather than at module level; the top-level import of this
    # name is guarded by TYPE_CHECKING.
    from comfy_execution.graph import DynamicPrompt

    global_progress_registry = ProgressRegistry(
        prompt_id="", dynprompt=DynamicPrompt({})
    )
    return global_progress_registry
|
|