| import torch |
| from collections import defaultdict, deque |
| from typing import Dict, Any, List |
| from comfy_integration.nodes import NODE_CLASS_MAPPINGS |
| from utils.app_utils import get_value_at_index |
|
|
class WorkflowExecutor:
    """Execute a ComfyUI-style workflow graph node by node.

    A workflow maps node IDs to dicts holding a ``'class_type'`` (looked up in
    ``NODE_CLASS_MAPPINGS``) and an optional ``'inputs'`` mapping. An input
    value that is a link (``[source_node_id, output_index]``, see
    ``_is_link``) is resolved to an output of another node; any other value
    is passed to the node unchanged.
    """

    @staticmethod
    def _is_link(value: Any) -> bool:
        """Return True if *value* is a ``[source_node_id, output_index]`` link.

        Requiring a str source ID *and* an int index keeps literal two-element
        lists (e.g. ``["a", "b"]``) from being mistaken for graph edges.
        """
        return (
            isinstance(value, list)
            and len(value) == 2
            and isinstance(value[0], str)
            and isinstance(value[1], int)
        )

    @staticmethod
    def topological_sort(workflow: Dict[str, Any]) -> List[str]:
        """Return the node IDs of *workflow* in dependency order (Kahn's algorithm).

        Every node appears after all nodes it links to. Links whose source ID is
        not a key of *workflow* are ignored here; at execution time they must be
        satisfied by pre-computed objects. The order is deterministic for a
        given workflow dict (ready nodes are processed first-in, first-out).

        Raises:
            RuntimeError: if the links form a cycle.
        """
        dependents = defaultdict(list)
        in_degree = {node_id: 0 for node_id in workflow}

        for node_id, node_info in workflow.items():
            for input_value in node_info.get('inputs', {}).values():
                if WorkflowExecutor._is_link(input_value) and input_value[0] in workflow:
                    dependents[input_value[0]].append(node_id)
                    in_degree[node_id] += 1

        queue = deque(node_id for node_id, degree in in_degree.items() if degree == 0)
        sorted_nodes = []
        while queue:
            current_node_id = queue.popleft()
            sorted_nodes.append(current_node_id)
            for neighbor_node_id in dependents[current_node_id]:
                in_degree[neighbor_node_id] -= 1
                if in_degree[neighbor_node_id] == 0:
                    queue.append(neighbor_node_id)

        # Nodes still holding a positive in-degree sit on (or behind) a cycle.
        if len(sorted_nodes) != len(workflow):
            raise RuntimeError("Workflow contains a cycle and cannot be executed.")

        return sorted_nodes

    @staticmethod
    def _dump_graph(workflow: Dict[str, Any]) -> None:
        """Print every node's class and incoming links to diagnose a failed sort."""
        print("--- [Workflow Executor] ERROR: Failed to sort workflow. Dumping graph details. ---")
        for node_id, node_info in workflow.items():
            # .get() so a malformed node cannot mask the original error with a KeyError.
            print(f" Node {node_id} ({node_info.get('class_type')}):")
            for input_name, input_value in node_info.get('inputs', {}).items():
                if WorkflowExecutor._is_link(input_value):
                    print(f" - {input_name} <- [{input_value[0]}, {input_value[1]}]")

    @staticmethod
    def _build_kwargs(node_id: str, inputs: Dict[str, Any], computed_outputs: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve a node's inputs into keyword arguments for its FUNCTION.

        Links are replaced by the referenced output of an already-available node
        (via ``get_value_at_index``); other values pass through unchanged. A
        dotted input name ``"parent.child"`` is grouped into a nested dict
        ``kwargs["parent"]["child"]``.

        Raises:
            RuntimeError: if a linked source node has no output available yet.
        """
        kwargs = {}
        for param_name, param_value in inputs.items():
            if WorkflowExecutor._is_link(param_value):
                source_node_id, output_index = param_value
                if source_node_id not in computed_outputs:
                    raise RuntimeError(f"Workflow integrity error: Output of node {source_node_id} needed for {node_id} but not yet computed.")
                actual_value = get_value_at_index(computed_outputs[source_node_id], output_index)
            else:
                actual_value = param_value

            if '.' in param_name:
                parent_key, child_key = param_name.split('.', 1)
                if not isinstance(kwargs.get(parent_key), dict):
                    kwargs[parent_key] = {}
                kwargs[parent_key][child_key] = actual_value
            else:
                kwargs[param_name] = actual_value
        return kwargs

    @staticmethod
    def _find_output_node(workflow: Dict[str, Any], sorted_node_ids: List[str]):
        """Return the last 'SaveImage' node ID in execution order, or None."""
        return next(
            (node_id for node_id in reversed(sorted_node_ids)
             if workflow[node_id].get('class_type') == 'SaveImage'),
            None,
        )

    @staticmethod
    def execute_workflow(workflow: Dict[str, Any], initial_objects: Dict[str, Any]):
        """Execute *workflow* and return the value feeding its last SaveImage node.

        Args:
            workflow: mapping of node ID -> {'class_type': ..., 'inputs': {...}}.
            initial_objects: pre-computed outputs keyed by node ID. Nodes present
                here are not executed; their stored value is used wherever other
                nodes link to them. The mapping is copied, never modified.

        Returns:
            ``get_value_at_index(output, index)`` for the ``[node, index]`` link
            wired into the ``images`` input of the last SaveImage node.

        Raises:
            RuntimeError: on a cyclic graph, an unknown ``class_type``, a link to
                an output that is not available, or when no SaveImage node exists.
        """
        with torch.no_grad():
            # Work on a copy: writing results into the caller's dict would make
            # a reused dict contain every node on the next run, so nothing would
            # execute and stale results would be returned.
            computed_outputs = dict(initial_objects)

            try:
                sorted_node_ids = WorkflowExecutor.topological_sort(workflow)
            except RuntimeError:
                WorkflowExecutor._dump_graph(workflow)
                raise
            print(f"--- [Workflow Executor] Execution order: {sorted_node_ids}")

            for node_id in sorted_node_ids:
                # Pre-supplied objects (including loaders) are never re-run.
                if node_id in computed_outputs:
                    continue

                node_info = workflow[node_id]
                class_type = node_info['class_type']
                node_class = NODE_CLASS_MAPPINGS.get(class_type)
                if node_class is None:
                    raise RuntimeError(f"Could not find node class '{class_type}'. Is it imported in comfy_integration/nodes.py?")

                node_instance = node_class()
                kwargs = WorkflowExecutor._build_kwargs(node_id, node_info.get('inputs', {}), computed_outputs)

                # FUNCTION names the method on the node class that does the work.
                execution_method = getattr(node_instance, getattr(node_class, 'FUNCTION'))
                computed_outputs[node_id] = execution_method(**kwargs)

            final_node_id = WorkflowExecutor._find_output_node(workflow, sorted_node_ids)
            if final_node_id is None:
                raise RuntimeError("Workflow does not contain a 'SaveImage' node as the output.")

            image_source_node_id, image_source_index = workflow[final_node_id]['inputs']['images']
            return get_value_at_index(computed_outputs[image_source_node_id], image_source_index)
|
|