# ImageGen / core / pipelines / workflow_executor.py
# Uploaded by RioShiina using huggingface_hub (commit dc0cea5, verified).
# Original file size: 5.38 kB.
import torch
from collections import defaultdict, deque
from typing import Dict, Any, List
from comfy_integration.nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import get_value_at_index
class WorkflowExecutor:
    """Run a ComfyUI-style workflow graph in dependency order.

    A workflow maps node ids to node descriptions of the form
    ``{"class_type": <str>, "inputs": {<name>: <value or link>}}``. An input
    value that is a two-element list whose first element is a string is
    treated as a link ``[source_node_id, output_index]`` to another node's
    output; any other value is passed to the node unchanged.
    """

    @staticmethod
    def _is_link(value: Any) -> bool:
        """Return True if *value* has the ``[source_node_id, output_index]`` link shape."""
        return isinstance(value, list) and len(value) == 2 and isinstance(value[0], str)

    @staticmethod
    def topological_sort(workflow: Dict[str, Any]) -> List[str]:
        """Return the node ids of *workflow* ordered so producers precede consumers.

        Uses Kahn's algorithm over the links found in each node's ``inputs``.
        Links whose source id is not a key of *workflow* (for example outputs
        supplied externally) add no ordering constraint.

        Args:
            workflow: Mapping of node id -> node description dict.

        Returns:
            A list containing every node id of *workflow* exactly once.

        Raises:
            RuntimeError: if the links form a cycle.
        """
        dependents = defaultdict(list)
        in_degree = {node_id: 0 for node_id in workflow}
        for node_id, node_info in workflow.items():
            for input_value in node_info.get('inputs', {}).values():
                if WorkflowExecutor._is_link(input_value) and input_value[0] in workflow:
                    # One edge per link; a node linking twice to the same
                    # source gets two edges, which are released together.
                    dependents[input_value[0]].append(node_id)
                    in_degree[node_id] += 1

        queue = deque(node_id for node_id, degree in in_degree.items() if degree == 0)
        sorted_nodes = []
        while queue:
            current = queue.popleft()
            sorted_nodes.append(current)
            for neighbor in dependents[current]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Nodes on a cycle never reach in-degree 0, so they are never emitted.
        if len(sorted_nodes) != len(workflow):
            raise RuntimeError("Workflow contains a cycle and cannot be executed.")
        return sorted_nodes

    @staticmethod
    def _dump_graph(workflow: Dict[str, Any]) -> None:
        """Print every node and its incoming links to help diagnose a failed sort."""
        print("--- [Workflow Executor] ERROR: Failed to sort workflow. Dumping graph details. ---")
        for node_id, node_info in workflow.items():
            # .get() so a malformed node cannot replace the original error with a KeyError.
            print(f" Node {node_id} ({node_info.get('class_type')}):")
            for input_name, input_value in node_info.get('inputs', {}).items():
                if WorkflowExecutor._is_link(input_value):
                    print(f" - {input_name} <- [{input_value[0]}, {input_value[1]}]")

    @staticmethod
    def _resolve_inputs(node_id: str, node_info: Dict[str, Any],
                        computed_outputs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the keyword arguments for one node, replacing links with upstream outputs.

        Raises:
            RuntimeError: if a linked source node has no computed output.
        """
        kwargs: Dict[str, Any] = {}
        for param_name, param_value in node_info.get('inputs', {}).items():
            if WorkflowExecutor._is_link(param_value):
                source_node_id, output_index = param_value
                if source_node_id not in computed_outputs:
                    raise RuntimeError(f"Workflow integrity error: Output of node {source_node_id} needed for {node_id} but not yet computed.")
                actual_value = get_value_at_index(computed_outputs[source_node_id], output_index)
            else:
                actual_value = param_value

            if '.' in param_name:
                # "parent.child" inputs are grouped into kwargs["parent"]["child"];
                # a non-dict value already stored under "parent" is replaced.
                parent_key, child_key = param_name.split('.', 1)
                if not isinstance(kwargs.get(parent_key), dict):
                    kwargs[parent_key] = {}
                kwargs[parent_key][child_key] = actual_value
            else:
                kwargs[param_name] = actual_value
        return kwargs

    @staticmethod
    def _find_save_image_node(workflow: Dict[str, Any], sorted_node_ids: List[str]) -> str:
        """Return the id of the last ``SaveImage`` node in execution order.

        Raises:
            RuntimeError: if the workflow has no ``SaveImage`` node.
        """
        for node_id in reversed(sorted_node_ids):
            if workflow[node_id].get('class_type') == 'SaveImage':
                return node_id
        raise RuntimeError("Workflow does not contain a 'SaveImage' node as the output.")

    @staticmethod
    def execute_workflow(workflow: Dict[str, Any], initial_objects: Dict[str, Any]):
        """Execute *workflow* node by node and return the image fed to its output node.

        Everything runs inside ``torch.no_grad()``. Each node class is looked up
        in ``NODE_CLASS_MAPPINGS`` by its ``class_type``, instantiated with no
        arguments, and the method named by its ``FUNCTION`` attribute is called
        with the node's resolved inputs as keyword arguments.

        Args:
            workflow: Mapping of node id -> {"class_type": str, "inputs": dict}.
            initial_objects: Precomputed node outputs keyed by node id. Nodes
                whose id appears here are not executed; the stored value is used
                wherever other nodes link to them. This dict is copied, not
                modified.

        Returns:
            The value at the link feeding the ``images`` input of the last
            ``SaveImage`` node in execution order, read via
            ``get_value_at_index``. The ``SaveImage`` node itself is executed
            like any other node.

        Raises:
            RuntimeError: on a cyclic graph, an unknown ``class_type``, a link to
                an output that is not available, or a missing or unlinked
                ``SaveImage`` output.
        """
        with torch.no_grad():
            # Copy so the caller's dict is not filled with this run's outputs;
            # reusing an aliased dict for a later run would silently skip every node.
            computed_outputs = dict(initial_objects)
            try:
                sorted_node_ids = WorkflowExecutor.topological_sort(workflow)
                print(f"--- [Workflow Executor] Execution order: {sorted_node_ids}")
            except RuntimeError:
                WorkflowExecutor._dump_graph(workflow)
                raise

            for node_id in sorted_node_ids:
                # Outputs supplied in initial_objects are used as-is; the node is not run.
                if node_id in computed_outputs:
                    continue
                node_info = workflow[node_id]
                class_type = node_info['class_type']
                node_class = NODE_CLASS_MAPPINGS.get(class_type)
                if node_class is None:
                    raise RuntimeError(f"Could not find node class '{class_type}'. Is it imported in comfy_integration/nodes.py?")
                node_instance = node_class()
                kwargs = WorkflowExecutor._resolve_inputs(node_id, node_info, computed_outputs)
                execution_method = getattr(node_instance, node_class.FUNCTION)
                computed_outputs[node_id] = execution_method(**kwargs)

            final_node_id = WorkflowExecutor._find_save_image_node(workflow, sorted_node_ids)
            image_link = workflow[final_node_id].get('inputs', {}).get('images')
            if not WorkflowExecutor._is_link(image_link):
                raise RuntimeError(f"SaveImage node {final_node_id} has no linked 'images' input.")
            image_source_node_id, image_source_index = image_link
            return get_value_at_index(computed_outputs[image_source_node_id], image_source_index)