| | from __future__ import annotations |
| |
|
| | from typing import TypedDict, Dict, Optional, Tuple |
| | from typing_extensions import override |
| | from PIL import Image |
| | from enum import Enum |
| | from abc import ABC |
| | from tqdm import tqdm |
| | from typing import TYPE_CHECKING |
| | if TYPE_CHECKING: |
| | from comfy_execution.graph import DynamicPrompt |
| | from protocol import BinaryEventTypes |
| | from comfy_api import feature_flags |
| |
|
# Preview payload that can accompany a progress update (the ``image`` argument
# of the update hooks below).
# NOTE(review): presumably (image format, PIL image, optional max size) -
# confirm against the code that builds these tuples.
PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]
| |
|
class NodeState(Enum):
    """Lifecycle stage of a node as tracked by the progress registry.

    Each member's value is the lowercase string form of its name.
    """

    # Initial state given to a freshly created progress entry.
    Pending = "pending"
    # Set when progress for the node is started or updated.
    Running = "running"
    # Set when progress for the node is finished.
    Finished = "finished"
    Error = "error"
| |
|
| |
|
class NodeProgressState(TypedDict):
    """
    Progress snapshot for a single node.

    Keys:
        state: The node's current ``NodeState``.
        value: Progress reported so far.
        max: Upper bound that ``value`` is measured against.
    """

    state: NodeState
    value: float
    max: float
| |
|
| |
|
class ProgressHandler(ABC):
    """
    Base class for progress handlers.

    A handler is told when nodes start, advance and finish, and presents that
    information however it likes. Every hook defined here does nothing, so a
    subclass only overrides the events it cares about.
    """

    def __init__(self, name: str):
        # Handlers start out enabled; the registry stores them under `name`.
        self.enabled = True
        self.name = name

    def set_registry(self, registry: "ProgressRegistry"):
        """Receive the registry this handler is attached to (ignored by default)."""

    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Hook invoked when a node begins processing."""

    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """Hook invoked when a node reports new progress, optionally with a preview image."""

    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Hook invoked when a node is done processing."""

    def reset(self):
        """Hook invoked when the progress registry resets its handlers."""

    def enable(self):
        """Turn this handler on."""
        self.enabled = True

    def disable(self):
        """Turn this handler off."""
        self.enabled = False
| |
|
| |
|
class CLIProgressHandler(ProgressHandler):
    """
    Handler that displays progress using tqdm progress bars in the CLI.

    One bar is kept per running node, stored in ``progress_bars`` under the
    node id. Each live bar holds its own terminal row, so a bar created after
    another one has finished never shares a row with a bar that is still
    running.
    """

    def __init__(self):
        super().__init__("cli")
        self.progress_bars: Dict[str, tqdm] = {}
        # Terminal row (tqdm ``position``) held by each live bar, by node id.
        self._bar_positions: Dict[str, int] = {}

    def _free_position(self) -> int:
        """Return the lowest row that no live bar is using."""
        taken = set(self._bar_positions.values())
        row = 0
        while row in taken:
            row += 1
        return row

    def _open_bar(self, node_id: str, total: float) -> tqdm:
        """Create the bar for ``node_id``, record it and return it."""
        # len(self.progress_bars) is not a safe row: once an earlier bar has
        # been removed it can equal the row of a bar that is still open.
        row = self._free_position()
        bar = tqdm(
            total=total,
            desc=f"Node {node_id}",
            unit="steps",
            leave=True,
            position=row,
        )
        self.progress_bars[node_id] = bar
        self._bar_positions[node_id] = row
        return bar

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Open a bar for the node unless one already exists."""
        if node_id not in self.progress_bars:
            self._open_bar(node_id, state["max"])

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """Advance the node's bar to ``value``, creating the bar if needed.

        The bar only ever moves forward: a ``value`` at or below the bar's
        current count leaves it untouched.
        """
        # Compare against None explicitly; a tqdm bar defines its own truthiness.
        bar = self.progress_bars.get(node_id)
        if bar is None:
            self._open_bar(node_id, max_value).update(value)
            return

        if max_value != bar.total:
            bar.total = max_value

        # tqdm.update() takes an increment, so convert the absolute value.
        delta = value - bar.n
        if delta > 0:
            bar.update(delta)

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Fill the node's bar up to ``state["max"]``, close it and forget it."""
        bar = self.progress_bars.get(node_id)
        if bar is None:
            return

        remaining = state["max"] - bar.n
        if remaining > 0:
            bar.update(remaining)
        bar.close()
        del self.progress_bars[node_id]
        self._bar_positions.pop(node_id, None)

    @override
    def reset(self):
        """Close every open bar and drop all bookkeeping."""
        for bar in self.progress_bars.values():
            bar.close()
        self.progress_bars.clear()
        self._bar_positions.clear()
| |
|
| |
|
class WebUIProgressHandler(ProgressHandler):
    """
    Handler that sends progress updates to the WebUI via WebSockets.

    Nothing is sent while ``server_instance`` is ``None`` or before a registry
    has been attached with ``set_registry``.
    """

    def __init__(self, server_instance):
        super().__init__("webui")
        self.server_instance = server_instance
        # Assigned by set_registry(). Starting at None lets the hooks run
        # before the handler is attached instead of raising AttributeError.
        self.registry: ProgressRegistry | None = None

    @override
    def set_registry(self, registry: "ProgressRegistry"):
        """Remember the registry whose nodes and dynprompt are reported."""
        self.registry = registry

    def _node_identity(self, node_id: str) -> Dict[str, object]:
        """Look up the display, parent and real ids of ``node_id`` in the dynprompt.

        Callers must have checked that a registry is attached.
        """
        dynprompt = self.registry.dynprompt
        return {
            "display_node_id": dynprompt.get_display_node_id(node_id),
            "parent_node_id": dynprompt.get_parent_node_id(node_id),
            "real_node_id": dynprompt.get_real_node_id(node_id),
        }

    def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
        """Send the current progress state to the client.

        Emits one ``progress_state`` event that carries every node in
        ``nodes`` whose state is not ``NodeState.Pending``.
        """
        if self.server_instance is None or self.registry is None:
            return

        active_nodes = {
            node_id: {
                "value": state["value"],
                "max": state["max"],
                "state": state["state"].value,
                "node_id": node_id,
                "prompt_id": prompt_id,
                **self._node_identity(node_id),
            }
            for node_id, state in nodes.items()
            if state["state"] != NodeState.Pending
        }

        self.server_instance.send_sync(
            "progress_state",
            {"prompt_id": prompt_id, "nodes": active_nodes},
            self.server_instance.client_id,
        )

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Push the progress of all registry nodes when a node starts."""
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """Push the progress of all registry nodes and, if given, the preview image.

        The preview is sent only when the client advertises the
        ``supports_preview_metadata`` feature.
        """
        if self.registry is None:
            return
        self._send_progress_state(prompt_id, self.registry.nodes)

        # The preview needs a server to send through, just like the state event.
        if not image or self.server_instance is None:
            return

        if not feature_flags.supports_feature(
            self.server_instance.sockets_metadata,
            self.server_instance.client_id,
            "supports_preview_metadata",
        ):
            return

        metadata = {
            "node_id": node_id,
            "prompt_id": prompt_id,
            **self._node_identity(node_id),
        }
        self.server_instance.send_sync(
            BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
            (image, metadata),
            self.server_instance.client_id,
        )

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Push the progress of all registry nodes when a node finishes."""
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)
| |
|
class ProgressRegistry:
    """
    Keeps per-node progress for one prompt and fans events out to handlers.

    Handlers are stored by name; only enabled handlers are notified of
    start, update and finish events.
    """

    def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
        self.handlers: Dict[str, ProgressHandler] = {}
        self.nodes: Dict[str, NodeProgressState] = {}
        self.dynprompt = dynprompt
        self.prompt_id = prompt_id

    def _enabled_handlers(self):
        """Lazily yield the handlers whose ``enabled`` flag is set."""
        return (h for h in self.handlers.values() if h.enabled)

    def register_handler(self, handler: ProgressHandler) -> None:
        """Store ``handler`` under its name, replacing any previous one."""
        self.handlers[handler.name] = handler

    def unregister_handler(self, handler_name: str) -> None:
        """Reset and remove the named handler; unknown names are ignored."""
        try:
            target = self.handlers[handler_name]
        except KeyError:
            return
        # Give the handler a chance to clean up before it is dropped.
        target.reset()
        del self.handlers[handler_name]

    def enable_handler(self, handler_name: str) -> None:
        """Enable the named handler; unknown names are ignored."""
        target = self.handlers.get(handler_name)
        if target is not None:
            target.enable()

    def disable_handler(self, handler_name: str) -> None:
        """Disable the named handler; unknown names are ignored."""
        target = self.handlers.get(handler_name)
        if target is not None:
            target.disable()

    def ensure_entry(self, node_id: str) -> NodeProgressState:
        """Return the node's progress entry, creating a pending one if absent."""
        try:
            return self.nodes[node_id]
        except KeyError:
            fresh = NodeProgressState(state=NodeState.Pending, value=0, max=1)
            self.nodes[node_id] = fresh
            return fresh

    def start_progress(self, node_id: str) -> None:
        """Mark the node as running at 0.0 of 1.0 and notify enabled handlers."""
        entry = self.ensure_entry(node_id)
        entry.update(state=NodeState.Running, value=0.0, max=1.0)

        for handler in self._enabled_handlers():
            handler.start_handler(node_id, entry, self.prompt_id)

    def update_progress(
        self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
    ) -> None:
        """Record new progress for the node and notify enabled handlers."""
        entry = self.ensure_entry(node_id)
        entry.update(state=NodeState.Running, value=value, max=max_value)

        for handler in self._enabled_handlers():
            handler.update_handler(node_id, value, max_value, entry, self.prompt_id, image)

    def finish_progress(self, node_id: str) -> None:
        """Mark the node as finished at its maximum and notify enabled handlers."""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Finished
        entry["value"] = entry["max"]

        for handler in self._enabled_handlers():
            handler.finish_handler(node_id, entry, self.prompt_id)

    def reset_handlers(self) -> None:
        """Call ``reset`` on every registered handler, enabled or not."""
        for handler in self.handlers.values():
            handler.reset()
| |
|
| | |
# Module-wide registry. Starts as None; get_progress_state() creates one on
# first use and reset_progress_state() replaces it with a fresh instance.
global_progress_registry: ProgressRegistry | None = None
| |
|
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
    """Replace the module-wide registry with a fresh one for ``prompt_id``.

    Handlers of the outgoing registry, if there is one, are reset first.
    The new registry starts with no handlers and no node entries.
    """
    global global_progress_registry

    outgoing = global_progress_registry
    if outgoing is not None:
        outgoing.reset_handlers()

    global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
| |
|
| |
|
def add_progress_handler(handler: ProgressHandler) -> None:
    """Attach ``handler`` to the current module-wide registry.

    The handler is given the registry first and is then registered with it.
    """
    current = get_progress_state()
    handler.set_registry(current)
    current.register_handler(handler)
| |
|
| |
|
def get_progress_state() -> ProgressRegistry:
    """Return the module-wide registry, creating a placeholder if none exists.

    The placeholder has an empty prompt id and a ``DynamicPrompt`` built from
    an empty dict.
    """
    global global_progress_registry
    if global_progress_registry is not None:
        return global_progress_registry

    # Imported here rather than at module level, where the name is only
    # available to type checkers.
    from comfy_execution.graph import DynamicPrompt

    global_progress_registry = ProgressRegistry(
        prompt_id="", dynprompt=DynamicPrompt({})
    )
    return global_progress_registry
| |
|