# workflow_arena/generator.py
# NOTE: the lines that previously appeared here were repository-page scrape
# residue (author avatar caption, commit message "init: WorkFlowArena",
# hash aea0016) and were not valid Python; preserved here as a comment.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Seeded workflow DAG generator and derived static metrics for WorkflowArena."""
from __future__ import annotations
import random
from workflow_arena.models import (
EpisodeConfig,
TaskStatus,
WorkflowEnvStateSnapshot,
WorkflowEpisodeSpec,
WorkflowTaskSpec,
)
from workflow_arena.presets import get_preset_config
def _task_id(index: int) -> str:
return f"task_{index:02d}"
def _compute_earliest_start(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int:
task = task_map[task_id]
if not task.dependencies:
return 0
return max(
_compute_earliest_start(task_map, dep_id) + task_map[dep_id].duration
for dep_id in task.dependencies
)
def _compute_critical_path(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int:
task = task_map[task_id]
if not task.dependents:
return task.duration
return task.duration + max(
_compute_critical_path(task_map, child_id) for child_id in task.dependents
)
def _compute_downstream_count(
task_map: dict[str, WorkflowTaskSpec], task_id: str, seen: set[str] | None = None
) -> int:
task = task_map[task_id]
local_seen = set() if seen is None else seen
count = 0
for child_id in task.dependents:
if child_id in local_seen:
continue
local_seen.add(child_id)
count += 1 + _compute_downstream_count(task_map, child_id, local_seen)
return count
def _estimate_deadline(
task: WorkflowTaskSpec,
workflow_critical_path: int,
rng: random.Random,
tightness: float,
) -> int:
slack_allowance = max(1, int(round((workflow_critical_path - task.earliest_start) * (1.15 - tightness))))
jitter = rng.randint(0, max(1, task.duration // 2))
return task.earliest_start + task.duration + slack_allowance + jitter
def generate_episode(
    config: EpisodeConfig,
) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]:
    """Generate a deterministic workflow episode from a preset and seed.

    All randomness flows through a single ``random.Random(seed)`` instance,
    so the exact sequence of ``rng.*`` calls below is part of the contract:
    the same (preset, seed, worker_count) always yields the same episode.
    Reordering or adding RNG calls would silently change generated episodes.

    Returns:
        A ``(episode_spec, initial_state)`` pair: the static task DAG with
        derived scheduling metrics, and the time-zero environment snapshot.
    """
    preset_config = get_preset_config(config.preset)
    # An explicit worker_count on the config overrides the preset default.
    worker_count = config.worker_count or preset_config.worker_count
    resolved_config = config.model_copy(update={"worker_count": worker_count})
    rng = random.Random(resolved_config.seed)
    task_count = rng.randint(preset_config.min_tasks, preset_config.max_tasks)
    dependency_map: dict[str, list[str]] = {}
    dependent_map: dict[str, list[str]] = {}
    task_ids = [_task_id(index + 1) for index in range(task_count)]
    # Build a random DAG. Each task may depend only on earlier tasks
    # (candidates = prefix of task_ids), so the graph is acyclic by
    # construction.
    for index, task_id in enumerate(task_ids):
        candidates = task_ids[:index]
        dependencies: list[str] = []
        if candidates:
            for candidate in candidates:
                if rng.random() < preset_config.edge_probability:
                    dependencies.append(candidate)
            # If no edge was sampled, attach one random dependency with
            # probability 0.6 so most tasks are not isolated roots.
            if not dependencies and index > 0 and rng.random() < 0.6:
                dependencies.append(rng.choice(candidates))
        # Deduplicate and keep dependencies in task-creation order.
        dependency_map[task_id] = sorted(set(dependencies), key=task_ids.index)
        dependent_map[task_id] = []
    # Invert the dependency map to obtain each task's direct dependents.
    for task_id, dependencies in dependency_map.items():
        for dependency in dependencies:
            dependent_map[dependency].append(task_id)
    tasks = [
        WorkflowTaskSpec(
            task_id=task_id,
            duration=rng.randint(preset_config.duration_min, preset_config.duration_max),
            priority=rng.randint(preset_config.priority_min, preset_config.priority_max),
            dependencies=dependency_map[task_id],
            dependents=sorted(dependent_map[task_id], key=task_ids.index),
            deadline=None,  # filled in below once the critical path is known
        )
        for task_id in task_ids
    ]
    task_map = {task.task_id: task for task in tasks}
    # Derive per-task static scheduling metrics from the finished DAG.
    workflow_critical_path = 0
    for task in tasks:
        task.earliest_start = _compute_earliest_start(task_map, task.task_id)
        task.critical_path_length = _compute_critical_path(task_map, task.task_id)
        task.downstream_count = _compute_downstream_count(task_map, task.task_id)
        workflow_critical_path = max(
            workflow_critical_path, task.earliest_start + task.duration
        )
    # NOTE(review): both passes should converge on the same overall critical
    # path; this second max looks redundant but is kept as a safety net.
    workflow_critical_path = max(
        workflow_critical_path,
        max(task.critical_path_length for task in tasks),
    )
    # Normalization denominators for criticality; the `if tasks` guards cover
    # an empty task list (presumably unreachable since min_tasks >= 1 — TODO
    # confirm against the preset definitions).
    max_downstream = max(task.downstream_count for task in tasks) if tasks else 1
    max_critical_path = max(task.critical_path_length for task in tasks) if tasks else 1
    for task in tasks:
        # Slack: how far the task can slip past its earliest start without
        # stretching the workflow's critical path.
        latest_start = max(
            task.earliest_start, workflow_critical_path - task.critical_path_length
        )
        task.slack = max(0, latest_start - task.earliest_start)
        # Criticality blends normalized path length (70%) and normalized
        # fan-out (30%) into a [0, 1] score, rounded to 4 decimals.
        task.criticality = round(
            0.7 * (task.critical_path_length / max_critical_path)
            + 0.3 * (task.downstream_count / max(1, max_downstream)),
            4,
        )
        task.deadline = _estimate_deadline(
            task=task,
            workflow_critical_path=workflow_critical_path,
            rng=rng,
            tightness=preset_config.deadline_tightness,
        )
    episode = WorkflowEpisodeSpec(
        config=resolved_config,
        preset_config=preset_config,
        tasks=tasks,
    )
    ready_task_ids = [task.task_id for task in tasks if not task.dependencies]
    blocked_task_ids = [task.task_id for task in tasks if task.dependencies]
    # Time-zero snapshot: nothing has run; dependency-free tasks are READY,
    # all others BLOCKED, and all counters/maps start empty or zeroed.
    state = WorkflowEnvStateSnapshot(
        episode_id=f"seed-{resolved_config.seed}",
        current_time=0,
        task_statuses={
            task.task_id: (
                TaskStatus.READY if not task.dependencies else TaskStatus.BLOCKED
            )
            for task in tasks
        },
        running_task_ids=[],
        completed_task_ids=[],
        ready_task_ids=ready_task_ids,
        blocked_task_ids=blocked_task_ids,
        task_start_times={},
        task_end_times={},
        task_remaining_dependencies={
            task.task_id: len(task.dependencies) for task in tasks
        },
        task_assigned_finish_times={},
        task_attempt_counts={task.task_id: 0 for task in tasks},
        cumulative_busy_time=0,
        time_budget=None,
        degraded_workers=0,
        active_worker_outage_until=None,
        recent_failure_events=[],
    )
    return episode, state