File size: 16,476 Bytes
aea0016
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Typed models for WorkflowArena.
"""

from __future__ import annotations

from enum import Enum

from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field, model_validator


class TaskStatus(str, Enum):
    """Allowed lifecycle states for a workflow task."""

    # The ``str`` mixin makes every member a real string, so a member compares
    # equal to its raw value (``TaskStatus.READY == "ready"``).

    # Still waiting on upstream dependencies.
    BLOCKED = "blocked"
    # Eligible for dispatch.
    READY = "ready"
    # Currently consuming a worker.
    RUNNING = "running"
    # Finished.
    COMPLETED = "completed"


class DifficultyPreset(str, Enum):
    """Initial task presets required by the hackathon."""

    # Identifies which set of generator knobs (see DifficultyPresetConfig) an
    # episode uses; EpisodeConfig.preset defaults to EASY.
    # NOTE(review): field descriptions in this file call worker outages and
    # retry failures "hard-mode" features; presumably only HARD enables them,
    # confirm against the preset table defined elsewhere.
    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"


class WorkflowActionType(str, Enum):
    """Explicit action space for the scheduler agent."""

    # Start the tasks listed in WorkflowArenaAction.task_ids.
    DISPATCH = "dispatch"
    # Wait for the next completion event; task_ids is expected to be empty.
    WAIT = "wait"


class RewardBreakdown(BaseModel):
    """Named reward channels for shaped feedback."""

    # Every channel is a float defaulting to 0.0, so a freshly constructed
    # breakdown carries no reward signal.

    completion_reward: float = Field(
        description="Reward for completing tasks.",
        default=0.0,
    )
    utilization_reward: float = Field(
        description="Reward for keeping workers busy.",
        default=0.0,
    )
    deadline_reward: float = Field(
        description="Reward or penalty tied to deadlines.",
        default=0.0,
    )
    criticality_reward: float = Field(
        description="Reward for prioritizing critical-path work appropriately.",
        default=0.0,
    )
    idle_penalty: float = Field(
        description="Penalty for leaving workers idle.",
        default=0.0,
    )
    invalid_action_penalty: float = Field(
        description="Penalty for malformed or infeasible actions.",
        default=0.0,
    )

    # Terminal channels, described as applying at episode end.
    terminal_makespan_score: float = Field(
        description="Terminal score based on final schedule quality.",
        default=0.0,
    )
    unfinished_task_penalty: float = Field(
        description="Terminal penalty for unfinished work at episode end.",
        default=0.0,
    )


class FailureEventType(str, Enum):
    """Failure events surfaced to agents and the UI."""

    # Used as WorkflowFailureEvent.event_type.

    # A worker outage began.
    WORKER_OUTAGE_START = "worker_outage_start"
    # A worker outage ended.
    WORKER_OUTAGE_END = "worker_outage_end"
    # A task completion became a retry failure instead.
    TASK_RETRY_FAILURE = "task_retry_failure"


class WorkflowFailureEvent(BaseModel):
    """Structured failure event emitted by the environment."""

    # Required: what happened and when.
    event_type: FailureEventType = Field(
        ...,
        description="Failure category.",
    )
    time: int = Field(
        ...,
        description="Simulated time when the event was observed.",
        ge=0,
    )

    # Optional details; which ones are populated depends on the event type.
    task_id: str | None = Field(
        description="Task affected by the event, if any.",
        default=None,
    )
    worker_delta: int = Field(
        description="Net temporary change in usable workers.",
        default=0,
    )
    duration: int | None = Field(
        description="Outage duration when applicable.",
        default=None,
        ge=0,
    )
    detail: str = Field(
        description="Short human-readable summary.",
        default="",
    )


class WorkflowTaskView(BaseModel):
    """Compact task payload used in observations and the future UI."""

    # Identity and lifecycle state (all required).
    task_id: str = Field(..., description="Stable task identifier.")
    status: TaskStatus = Field(..., description="Current task lifecycle state.")
    duration: int = Field(
        ...,
        description="Task runtime in simulated time units.",
        ge=1,
    )
    priority: int = Field(
        ...,
        description="Priority weight for the task.",
        ge=0,
    )

    # Graph position and scheduling hints.
    dependencies: list[str] = Field(
        description="Upstream task ids that must complete first.",
        default_factory=list,
    )
    deadline: int | None = Field(
        description="Optional deadline in simulated time units.",
        default=None,
        ge=0,
    )
    criticality: float | None = Field(
        description="Derived importance score from the DAG structure.",
        default=None,
    )
    slack: float | None = Field(
        description="Derived slack estimate for scheduling decisions.",
        default=None,
    )
    downstream_count: int = Field(
        description="Count of downstream dependents reachable from this task.",
        default=0,
        ge=0,
    )

    # Execution bookkeeping; the times stay None until they apply.
    start_time: int | None = Field(
        description="Simulated start time if the task is running or completed.",
        default=None,
        ge=0,
    )
    end_time: int | None = Field(
        description="Simulated end time if the task is completed or scheduled to finish.",
        default=None,
        ge=0,
    )
    attempt_count: int = Field(
        description="Number of retry attempts already consumed by this task.",
        default=0,
        ge=0,
    )


class WorkflowTaskSpec(BaseModel):
    """Static task specification generated at episode reset."""

    # Required core attributes.
    task_id: str = Field(
        ...,
        description="Stable task identifier.",
    )
    duration: int = Field(
        ...,
        description="Task runtime in simulated time units.",
        ge=1,
    )
    priority: int = Field(
        ...,
        description="Priority weight for the task.",
        ge=0,
    )

    # Edges of the dependency graph, stored in both directions.
    dependencies: list[str] = Field(
        description="Upstream task ids that must complete first.",
        default_factory=list,
    )
    dependents: list[str] = Field(
        description="Downstream task ids that depend on this task.",
        default_factory=list,
    )

    deadline: int | None = Field(
        description="Optional deadline in simulated time units.",
        default=None,
        ge=0,
    )

    # Metrics derived from the graph; unlike WorkflowTaskView, slack and
    # criticality are non-optional here and default to zero.
    downstream_count: int = Field(
        description="Number of downstream tasks reachable from this node.",
        default=0,
        ge=0,
    )
    critical_path_length: int = Field(
        description="Duration-weighted path length from this task to a sink.",
        default=0,
        ge=0,
    )
    earliest_start: int = Field(
        description="Earliest feasible start time under dependency constraints.",
        default=0,
        ge=0,
    )
    slack: int = Field(
        description="Scheduling slack measured in simulated time units.",
        default=0,
        ge=0,
    )
    criticality: float = Field(
        description="Normalized importance score derived from critical path and downstream impact.",
        default=0.0,
    )


class ProgressSummary(BaseModel):
    """Counts by task lifecycle state."""

    # All counters are non-negative and start at zero.
    # NOTE(review): nothing here enforces that ``total`` equals the sum of the
    # per-state counters; presumably the producer keeps them in sync.
    total: int = Field(ge=0, default=0)
    blocked: int = Field(ge=0, default=0)
    ready: int = Field(ge=0, default=0)
    running: int = Field(ge=0, default=0)
    completed: int = Field(ge=0, default=0)


class EpisodeConfig(BaseModel):
    """Reset-time knobs that define the episode."""

    # Every knob has a default, so ``EpisodeConfig()`` is a valid easy-preset,
    # seed-0, two-worker configuration.
    preset: DifficultyPreset = Field(
        description="Difficulty preset for the episode generator.",
        default=DifficultyPreset.EASY,
    )
    seed: int = Field(
        description="Seed for deterministic episode generation.",
        default=0,
    )
    worker_count: int = Field(
        description="Number of identical workers available to the scheduler.",
        default=2,
        ge=1,
    )


class GraderTarget(BaseModel):
    """High-level target bands for each preset's grader."""

    # Both fields are required free-form text; there are no numeric thresholds.
    description: str = Field(
        ...,
        description="What good performance means for the preset.",
    )
    score_band_hint: str = Field(
        ...,
        description="Human-readable interpretation of scores.",
    )


class DifficultyPresetConfig(BaseModel):
    """Concrete generator knobs for a preset."""

    # Per-field constraints (``ge``/``le``/``gt``) are declared on each Field.
    # The paired ``*_min``/``*_max`` bounds are additionally cross-checked by
    # ``check_ranges_are_ordered`` so an inverted range is rejected when the
    # model is built.

    preset: DifficultyPreset = Field(..., description="Preset identifier.")

    # Graph size and density.
    min_tasks: int = Field(..., ge=2)
    max_tasks: int = Field(..., ge=2)
    edge_probability: float = Field(..., ge=0.0, le=1.0)

    # Per-task sampling ranges.
    duration_min: int = Field(..., ge=1)
    duration_max: int = Field(..., ge=1)
    priority_min: int = Field(..., ge=0)
    priority_max: int = Field(..., ge=0)

    # Capacity and time pressure.
    worker_count: int = Field(..., ge=1)
    deadline_tightness: float = Field(
        ...,
        ge=0.0,
        description="Larger values mean tighter deadlines.",
    )
    time_budget_multiplier: float | None = Field(
        default=None,
        gt=0.0,
        description="Optional multiplier over the theoretical lower-bound makespan.",
    )

    # Failure injection; every knob defaults to "disabled" (zero).
    worker_outage_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance of a hard-mode worker outage being sampled on a wait transition.",
    )
    worker_outage_duration_min: int = Field(
        default=0,
        ge=0,
        description="Minimum outage duration in simulated time units.",
    )
    worker_outage_duration_max: int = Field(
        default=0,
        ge=0,
        description="Maximum outage duration in simulated time units.",
    )
    task_retry_failure_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance that a hard-mode task completion becomes a retry failure.",
    )
    max_task_retries: int = Field(
        default=0,
        ge=0,
        description="Maximum number of retry failures a task may suffer before it must complete.",
    )

    grader_target: GraderTarget = Field(
        ...,
        description="Preset-specific grader interpretation.",
    )

    @model_validator(mode="after")
    def check_ranges_are_ordered(self) -> DifficultyPresetConfig:
        """Reject presets whose lower bound exceeds the matching upper bound.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If any ``min`` bound is greater than its ``max``
                partner. Pydantic reports this to callers as a
                ``ValidationError``.
        """
        bound_pairs = (
            ("min_tasks", "max_tasks"),
            ("duration_min", "duration_max"),
            ("priority_min", "priority_max"),
            ("worker_outage_duration_min", "worker_outage_duration_max"),
        )
        for low_name, high_name in bound_pairs:
            low = getattr(self, low_name)
            high = getattr(self, high_name)
            # Equal bounds are allowed: they describe a fixed, single value.
            if low > high:
                raise ValueError(
                    f"{low_name} ({low}) must not exceed {high_name} ({high})"
                )
        return self


class WorkflowEpisodeSpec(BaseModel):
    """Static episode description produced by the generator."""

    # All three parts are required: the requested config, the preset
    # parameters it resolved to, and the generated task list.
    config: EpisodeConfig = Field(
        ...,
        description="Reset-time configuration.",
    )
    preset_config: DifficultyPresetConfig = Field(
        ...,
        description="Resolved preset parameters.",
    )
    tasks: list[WorkflowTaskSpec] = Field(
        ...,
        description="Generated workflow tasks.",
    )


class WorkflowEnvStateSnapshot(BaseModel):
    """Serializable environment state for the current episode."""

    # Only ``episode_id`` is required; everything else defaults to an empty
    # or zero value.
    episode_id: str = Field(..., description="Stable current episode identifier.")
    current_time: int = Field(
        description="Current simulated time.",
        default=0,
        ge=0,
    )

    # Task status, as a mapping and as one id list per lifecycle state.
    task_statuses: dict[str, TaskStatus] = Field(
        description="Current task status by task id.",
        default_factory=dict,
    )
    running_task_ids: list[str] = Field(
        description="Tasks currently consuming workers.",
        default_factory=list,
    )
    completed_task_ids: list[str] = Field(
        description="Tasks that have completed.",
        default_factory=list,
    )
    ready_task_ids: list[str] = Field(
        description="Tasks currently ready for dispatch.",
        default_factory=list,
    )
    blocked_task_ids: list[str] = Field(
        description="Tasks still blocked on dependencies.",
        default_factory=list,
    )

    # Per-task bookkeeping keyed by task id.
    task_start_times: dict[str, int] = Field(
        description="Simulated start time by task id.",
        default_factory=dict,
    )
    task_end_times: dict[str, int] = Field(
        description="Simulated completion time by task id.",
        default_factory=dict,
    )
    task_remaining_dependencies: dict[str, int] = Field(
        description="Remaining unfinished prerequisites by task id.",
        default_factory=dict,
    )
    task_assigned_finish_times: dict[str, int] = Field(
        description="Predicted completion times for currently running tasks.",
        default_factory=dict,
    )
    task_attempt_counts: dict[str, int] = Field(
        description="Retry attempts consumed by each task.",
        default_factory=dict,
    )

    # Episode-level accounting.
    cumulative_busy_time: int = Field(
        description="Aggregate worker busy time accrued so far.",
        default=0,
        ge=0,
    )
    time_budget: int | None = Field(
        description="Optional terminal time budget for the episode.",
        default=None,
        ge=0,
    )

    # Failure-injection state.
    degraded_workers: int = Field(
        description="Workers temporarily removed from usable capacity.",
        default=0,
        ge=0,
    )
    active_worker_outage_until: int | None = Field(
        description="Time when the current worker outage expires, if any.",
        default=None,
        ge=0,
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        description="Failure events generated on the latest transition.",
        default_factory=list,
    )


class SuccessMetrics(BaseModel):
    """Primary quality metrics used for grading and demos."""

    # Optional metrics default to None; the two counters default to 0.
    makespan: int | None = Field(
        description="Total simulated completion time.",
        default=None,
    )
    worker_utilization: float | None = Field(
        description="Fraction of available worker time that was used.",
        default=None,
    )
    deadline_miss_count: int = Field(
        description="Missed task deadlines.",
        default=0,
        ge=0,
    )
    unfinished_task_count: int = Field(
        description="Tasks left incomplete at terminal time.",
        default=0,
        ge=0,
    )
    weighted_priority_completion: float | None = Field(
        description="Priority-weighted on-time completion score.",
        default=None,
    )
    # The 0.0-1.0 range is stated in the description only; no ge/le bound
    # enforces it on this field.
    benchmark_score: float | None = Field(
        description="Deterministic terminal benchmark score in the 0.0-1.0 range.",
        default=None,
    )


class WorkflowArenaAction(Action):
    """Strict action space for the workflow scheduler."""

    # ``action_type`` is required; ``task_ids`` defaults to a fresh empty list.
    # The "must be empty for wait()" rule is stated in the description only;
    # this model does not check it.
    action_type: WorkflowActionType = Field(
        ...,
        description="Dispatch ready tasks or wait for the next completion event.",
    )
    task_ids: list[str] = Field(
        description="Task ids to dispatch. Must be empty for wait().",
        default_factory=list,
    )


class WorkflowArenaObservation(Observation):
    """Compact, typed observation contract for WorkflowArena."""

    # Every field has a default, so the model can be built with no arguments
    # beyond whatever the ``Observation`` base class requires.

    # Prompt and episode configuration.
    instruction: str = Field(
        description="Short prompt shown to inference agents.",
        default=(
            "Schedule dependency-constrained workflow tasks on limited workers using "
            "dispatch(task_ids=[...]) or wait()."
        ),
    )
    config: EpisodeConfig = Field(
        description="Episode generation settings.",
        default_factory=EpisodeConfig,
    )

    # Clock and worker capacity.
    current_time: int = Field(
        description="Current simulated time.",
        default=0,
        ge=0,
    )
    total_workers: int = Field(
        description="Total identical workers.",
        default=2,
        ge=1,
    )
    effective_workers: int = Field(
        description="Usable workers after temporary degradation is applied.",
        default=2,
        ge=0,
    )
    degraded_workers: int = Field(
        description="Workers currently unavailable due to outages.",
        default=0,
        ge=0,
    )
    free_workers: int = Field(
        description="Currently idle workers.",
        default=2,
        ge=0,
    )

    # Time budget; ``time_remaining`` has no lower bound, unlike ``time_budget``.
    time_budget: int | None = Field(
        description="Optional terminal time budget for the current episode.",
        default=None,
        ge=0,
    )
    time_remaining: int | None = Field(
        description="Remaining time until the episode budget expires, if budgeted.",
        default=None,
    )

    # Task views, grouped by lifecycle state.
    progress: ProgressSummary = Field(
        description="Task counts by lifecycle state.",
        default_factory=ProgressSummary,
    )
    ready_tasks: list[WorkflowTaskView] = Field(
        description="Ready tasks eligible for dispatch.",
        default_factory=list,
    )
    running_tasks: list[WorkflowTaskView] = Field(
        description="Tasks currently consuming workers.",
        default_factory=list,
    )
    completed_tasks: list[WorkflowTaskView] = Field(
        description="Tasks already completed.",
        default_factory=list,
    )
    blocked_tasks: list[WorkflowTaskView] = Field(
        description="Tasks still waiting on dependencies.",
        default_factory=list,
    )

    # Reward and scoring.
    last_reward_breakdown: RewardBreakdown = Field(
        description="Per-step reward channel breakdown.",
        default_factory=RewardBreakdown,
    )
    cumulative_reward: float = Field(
        description="Running total reward.",
        default=0.0,
    )
    success_metrics: SuccessMetrics = Field(
        description="Primary schedule quality metrics.",
        default_factory=SuccessMetrics,
    )

    # Diagnostics about the latest transition.
    note: str | None = Field(
        description="Short environment note about the latest transition.",
        default=None,
    )
    validation_error: str | None = Field(
        description="Explicit invalid-action explanation when the previous action failed.",
        default=None,
    )
    termination_reason: str | None = Field(
        description="Terminal reason when the episode ended unsuccessfully.",
        default=None,
    )
    benchmark_score: float | None = Field(
        description="Top-level bounded benchmark score for easier client access.",
        default=None,
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        description="Failure events generated on the latest accepted transition.",
        default_factory=list,
    )
    received_action: dict[str, object] | None = Field(
        description="Last action accepted by the server for logging and prompting.",
        default=None,
    )