File size: 5,379 Bytes
dc0cea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from collections import defaultdict, deque
from typing import Dict, Any, List
from comfy_integration.nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import get_value_at_index

class WorkflowExecutor:
    """Execute a node-graph workflow by running each node in dependency order.

    A workflow is a dict mapping node id -> node info, where node info holds a
    ``'class_type'`` (looked up in ``NODE_CLASS_MAPPINGS``) and an ``'inputs'``
    dict. An input value that is a two-element list whose first item is a
    string, e.g. ``["4", 0]``, is treated as a link to output ``0`` of node
    ``"4"``; every other input value is passed through as a literal.
    """

    @staticmethod
    def _is_link(value: Any) -> bool:
        """Return True if *value* has the ``[source_node_id, output_index]`` link shape."""
        return isinstance(value, list) and len(value) == 2 and isinstance(value[0], str)

    @staticmethod
    def _node_inputs(node_info: Dict[str, Any]) -> Dict[str, Any]:
        """Return the node's ``'inputs'`` dict, or an empty dict if it is missing or None."""
        return node_info.get('inputs') or {}

    @staticmethod
    def topological_sort(workflow: Dict[str, Any]) -> List[str]:
        """Return the workflow's node ids ordered so every node comes after its inputs.

        Uses Kahn's algorithm (repeatedly emitting nodes whose in-degree is zero).
        Only links whose source node is itself in *workflow* create edges; links
        to other ids (e.g. objects supplied separately) are ignored here.

        Raises:
            RuntimeError: if not every node could be ordered, i.e. the graph
                contains a cycle.
        """
        graph = defaultdict(list)
        in_degree = {node_id: 0 for node_id in workflow}

        for node_id, node_info in workflow.items():
            for input_value in WorkflowExecutor._node_inputs(node_info).values():
                if WorkflowExecutor._is_link(input_value) and input_value[0] in workflow:
                    graph[input_value[0]].append(node_id)
                    in_degree[node_id] += 1

        queue = deque(node_id for node_id, degree in in_degree.items() if degree == 0)

        sorted_nodes = []
        while queue:
            current_node_id = queue.popleft()
            sorted_nodes.append(current_node_id)

            for neighbor_node_id in graph[current_node_id]:
                in_degree[neighbor_node_id] -= 1
                if in_degree[neighbor_node_id] == 0:
                    queue.append(neighbor_node_id)

        # Nodes left with a non-zero in-degree sit on (or behind) a cycle.
        if len(sorted_nodes) != len(workflow):
            raise RuntimeError("Workflow contains a cycle and cannot be executed.")

        return sorted_nodes

    @staticmethod
    def _dump_graph(workflow: Dict[str, Any]) -> None:
        """Print every node and its incoming links to help diagnose a failed sort."""
        print("--- [Workflow Executor] ERROR: Failed to sort workflow. Dumping graph details. ---")
        for node_id, node_info in workflow.items():
            # .get() so a malformed node cannot mask the original sort error.
            print(f"  Node {node_id} ({node_info.get('class_type')}):")
            for input_name, input_value in WorkflowExecutor._node_inputs(node_info).items():
                if WorkflowExecutor._is_link(input_value):
                    print(f"    - {input_name} <- [{input_value[0]}, {input_value[1]}]")

    @staticmethod
    def _build_kwargs(node_id: str, node_info: Dict[str, Any],
                      computed_outputs: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve a node's inputs into keyword arguments for its execution method.

        Links are replaced by the referenced upstream output (selected with
        ``get_value_at_index``). A dotted input name such as ``"parent.child"``
        is nested one level, producing ``kwargs["parent"]["child"]``.

        Raises:
            RuntimeError: if a linked source node has no computed output.
        """
        kwargs = {}
        for param_name, param_value in WorkflowExecutor._node_inputs(node_info).items():
            if WorkflowExecutor._is_link(param_value):
                source_node_id, output_index = param_value
                if source_node_id not in computed_outputs:
                    raise RuntimeError(f"Workflow integrity error: Output of node {source_node_id} needed for {node_id} but not yet computed.")
                actual_value = get_value_at_index(computed_outputs[source_node_id], output_index)
            else:
                actual_value = param_value

            if '.' in param_name:
                parent_key, child_key = param_name.split('.', 1)
                # Start a fresh dict if the parent is absent or holds a non-dict value.
                if not isinstance(kwargs.get(parent_key), dict):
                    kwargs[parent_key] = {}
                kwargs[parent_key][child_key] = actual_value
            else:
                kwargs[param_name] = actual_value
        return kwargs

    @staticmethod
    def execute_workflow(workflow: Dict[str, Any], initial_objects: Dict[str, Any]):
        """Run *workflow* and return the value feeding its last ``SaveImage`` node.

        Args:
            workflow: node id -> node info (``'class_type'`` and ``'inputs'``).
            initial_objects: node id -> precomputed output for nodes that must
                not be executed (presumably pre-loaded models). Any workflow node
                whose id appears here is skipped. This dict is NOT modified.

        Returns:
            The output selected by the ``SaveImage`` node's ``'images'`` link,
            taken from its source node's output via ``get_value_at_index``.

        Raises:
            RuntimeError: on a cyclic graph, an unknown ``class_type``, a link
                to an uncomputed node, or when no ``SaveImage`` node exists.
        """
        with torch.no_grad():
            # Work on a copy: writing this run's intermediate outputs into the
            # caller's dict would make a reused dict skip nodes on later runs
            # and return stale results.
            computed_outputs = dict(initial_objects)

            try:
                sorted_node_ids = WorkflowExecutor.topological_sort(workflow)
            except RuntimeError:
                WorkflowExecutor._dump_graph(workflow)
                raise
            print(f"--- [Workflow Executor] Execution order: {sorted_node_ids}")

            for node_id in sorted_node_ids:
                # Nodes supplied by the caller (e.g. pre-loaded loaders) are not re-run.
                if node_id in computed_outputs:
                    continue

                node_info = workflow[node_id]
                class_type = node_info['class_type']

                node_class = NODE_CLASS_MAPPINGS.get(class_type)
                if node_class is None:
                    raise RuntimeError(f"Could not find node class '{class_type}'. Is it imported in comfy_integration/nodes.py?")

                node_instance = node_class()
                kwargs = WorkflowExecutor._build_kwargs(node_id, node_info, computed_outputs)

                # Each node class names its entry-point method in FUNCTION.
                execution_method = getattr(node_instance, node_class.FUNCTION)
                computed_outputs[node_id] = execution_method(**kwargs)

            # The last SaveImage in execution order is treated as the output node.
            final_node_id = next(
                (node_id for node_id in reversed(sorted_node_ids)
                 if workflow[node_id].get('class_type') == 'SaveImage'),
                None,
            )
            if final_node_id is None:
                raise RuntimeError("Workflow does not contain a 'SaveImage' node as the output.")

            save_image_inputs = WorkflowExecutor._node_inputs(workflow[final_node_id])
            image_source_node_id, image_source_index = save_image_inputs['images']

            return get_value_at_index(computed_outputs[image_source_node_id], image_source_index)