Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Seeded workflow DAG generator and derived static metrics for WorkflowArena."""
| from __future__ import annotations | |
| import random | |
| from workflow_arena.models import ( | |
| EpisodeConfig, | |
| TaskStatus, | |
| WorkflowEnvStateSnapshot, | |
| WorkflowEpisodeSpec, | |
| WorkflowTaskSpec, | |
| ) | |
| from workflow_arena.presets import get_preset_config | |
| def _task_id(index: int) -> str: | |
| return f"task_{index:02d}" | |
| def _compute_earliest_start(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int: | |
| task = task_map[task_id] | |
| if not task.dependencies: | |
| return 0 | |
| return max( | |
| _compute_earliest_start(task_map, dep_id) + task_map[dep_id].duration | |
| for dep_id in task.dependencies | |
| ) | |
| def _compute_critical_path(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int: | |
| task = task_map[task_id] | |
| if not task.dependents: | |
| return task.duration | |
| return task.duration + max( | |
| _compute_critical_path(task_map, child_id) for child_id in task.dependents | |
| ) | |
| def _compute_downstream_count( | |
| task_map: dict[str, WorkflowTaskSpec], task_id: str, seen: set[str] | None = None | |
| ) -> int: | |
| task = task_map[task_id] | |
| local_seen = set() if seen is None else seen | |
| count = 0 | |
| for child_id in task.dependents: | |
| if child_id in local_seen: | |
| continue | |
| local_seen.add(child_id) | |
| count += 1 + _compute_downstream_count(task_map, child_id, local_seen) | |
| return count | |
| def _estimate_deadline( | |
| task: WorkflowTaskSpec, | |
| workflow_critical_path: int, | |
| rng: random.Random, | |
| tightness: float, | |
| ) -> int: | |
| slack_allowance = max(1, int(round((workflow_critical_path - task.earliest_start) * (1.15 - tightness)))) | |
| jitter = rng.randint(0, max(1, task.duration // 2)) | |
| return task.earliest_start + task.duration + slack_allowance + jitter | |
def generate_episode(
    config: EpisodeConfig,
) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]:
    """Generate a deterministic workflow episode from a preset and seed.

    All randomness flows through a single ``random.Random`` seeded with
    ``config.seed``, so the same (preset, seed) pair always produces the
    same DAG topology, durations, priorities, and deadlines.

    Returns:
        A tuple of the episode specification and the initial environment
        snapshot (time 0, root tasks ready, all others blocked).
    """
    preset_config = get_preset_config(config.preset)
    # An explicit worker_count on the config overrides the preset default.
    worker_count = config.worker_count or preset_config.worker_count
    resolved_config = config.model_copy(update={"worker_count": worker_count})
    rng = random.Random(resolved_config.seed)
    task_count = rng.randint(preset_config.min_tasks, preset_config.max_tasks)
    # Build a random DAG over task_01..task_NN. Dependency edges only point
    # from earlier ids to later ids, which guarantees acyclicity.
    dependency_map: dict[str, list[str]] = {}
    dependent_map: dict[str, list[str]] = {}
    task_ids = [_task_id(index + 1) for index in range(task_count)]
    for index, task_id in enumerate(task_ids):
        candidates = task_ids[:index]
        dependencies: list[str] = []
        if candidates:
            for candidate in candidates:
                if rng.random() < preset_config.edge_probability:
                    dependencies.append(candidate)
            # With 60% probability, give an otherwise-isolated non-root task
            # one random upstream edge so the graph is not all roots.
            if not dependencies and index > 0 and rng.random() < 0.6:
                dependencies.append(rng.choice(candidates))
        # Deduplicate while keeping dependencies ordered by creation index.
        dependency_map[task_id] = sorted(set(dependencies), key=task_ids.index)
        dependent_map[task_id] = []
    # Invert the dependency map to get each task's downstream dependents.
    for task_id, dependencies in dependency_map.items():
        for dependency in dependencies:
            dependent_map[dependency].append(task_id)
    tasks = [
        WorkflowTaskSpec(
            task_id=task_id,
            duration=rng.randint(preset_config.duration_min, preset_config.duration_max),
            priority=rng.randint(preset_config.priority_min, preset_config.priority_max),
            dependencies=dependency_map[task_id],
            dependents=sorted(dependent_map[task_id], key=task_ids.index),
            # Deadline is filled in below, once the critical path is known.
            deadline=None,
        )
        for task_id in task_ids
    ]
    task_map = {task.task_id: task for task in tasks}
    # Derived static metrics per task, plus the workflow-wide critical path.
    workflow_critical_path = 0
    for task in tasks:
        task.earliest_start = _compute_earliest_start(task_map, task.task_id)
        task.critical_path_length = _compute_critical_path(task_map, task.task_id)
        task.downstream_count = _compute_downstream_count(task_map, task.task_id)
        workflow_critical_path = max(
            workflow_critical_path, task.earliest_start + task.duration
        )
    workflow_critical_path = max(
        workflow_critical_path,
        max(task.critical_path_length for task in tasks),
    )
    max_downstream = max(task.downstream_count for task in tasks) if tasks else 1
    max_critical_path = max(task.critical_path_length for task in tasks) if tasks else 1
    for task in tasks:
        # Latest start that still lets the task's remaining chain (which
        # includes its own duration) finish within the workflow critical path.
        latest_start = max(
            task.earliest_start, workflow_critical_path - task.critical_path_length
        )
        task.slack = max(0, latest_start - task.earliest_start)
        # Criticality blends the task's critical-path share (70%) with its
        # downstream fan-out share (30%), normalized to the workflow maxima.
        task.criticality = round(
            0.7 * (task.critical_path_length / max_critical_path)
            + 0.3 * (task.downstream_count / max(1, max_downstream)),
            4,
        )
        task.deadline = _estimate_deadline(
            task=task,
            workflow_critical_path=workflow_critical_path,
            rng=rng,
            tightness=preset_config.deadline_tightness,
        )
    episode = WorkflowEpisodeSpec(
        config=resolved_config,
        preset_config=preset_config,
        tasks=tasks,
    )
    # Initial snapshot: dependency-free tasks start READY, the rest BLOCKED.
    ready_task_ids = [task.task_id for task in tasks if not task.dependencies]
    blocked_task_ids = [task.task_id for task in tasks if task.dependencies]
    state = WorkflowEnvStateSnapshot(
        episode_id=f"seed-{resolved_config.seed}",
        current_time=0,
        task_statuses={
            task.task_id: (
                TaskStatus.READY if not task.dependencies else TaskStatus.BLOCKED
            )
            for task in tasks
        },
        running_task_ids=[],
        completed_task_ids=[],
        ready_task_ids=ready_task_ids,
        blocked_task_ids=blocked_task_ids,
        task_start_times={},
        task_end_times={},
        task_remaining_dependencies={
            task.task_id: len(task.dependencies) for task in tasks
        },
        task_assigned_finish_times={},
        task_attempt_counts={task.task_id: 0 for task in tasks},
        cumulative_busy_time=0,
        time_budget=None,
        degraded_workers=0,
        active_worker_outage_until=None,
        recent_failure_events=[],
    )
    return episode, state