"""Tests for config/schedule types (TapPosition, TapSchedule, InferenceRamp,
InferenceRampSchedule, TrainingRun, TrainingSchedule), activation policies
(ActivationPolicy, RampActivationPolicy), and the GridState /
OnlineDatacenterState containers."""
from __future__ import annotations
import numpy as np
import pytest
from openg2g.common import ThreePhase
from openg2g.datacenter.config import (
InferenceRamp,
InferenceRampSchedule,
TrainingRun,
TrainingSchedule,
)
from openg2g.datacenter.layout import ActivationPolicy, RampActivationPolicy
from openg2g.datacenter.online import OnlineDatacenterState
from openg2g.datacenter.workloads.training import TrainingTrace
from openg2g.grid.base import BusVoltages, GridState, PhaseVoltages
from openg2g.grid.config import TapPosition, TapSchedule
# Minimal two-sample training trace shared by the TrainingRun/TrainingSchedule
# tests; the tests below never inspect its sample values.
_DUMMY_TRACE = TrainingTrace(
    t_s=np.array([0.0, 1.0]),
    power_w=np.array([100.0, 200.0]),
)
class TestTapPosition:
    """Construction and scheduling behaviour of a single TapPosition."""

    def test_full_three_phase(self) -> None:
        """Every phase passed to the constructor is kept verbatim."""
        tap = TapPosition(a=1.0, b=1.05, c=1.1)
        assert (tap.a, tap.b, tap.c) == (1.0, 1.05, 1.1)

    def test_partial_a_only(self) -> None:
        """Omitted phases B and C are left as None when only A is given."""
        tap = TapPosition(a=1.1)
        assert tap.a == 1.1
        for missing in (tap.b, tap.c):
            assert missing is None

    def test_partial_a_and_c(self) -> None:
        """Phase B stays None when only A and C are supplied."""
        tap = TapPosition(a=1.0625, c=1.0625)
        assert tap.b is None
        assert tap.a == tap.c == 1.0625

    def test_no_phase_raises(self) -> None:
        """A TapPosition with no phase at all is rejected with ValueError."""
        with pytest.raises(ValueError, match="at least one phase"):
            TapPosition()

    def test_at_returns_schedule(self) -> None:
        """.at(t) wraps the position in a TapSchedule holding one entry."""
        schedule = TapPosition(a=1.0, b=1.0, c=1.0).at(t=100.0)
        assert isinstance(schedule, TapSchedule)
        assert len(schedule) == 1
class TestTapSchedule:
    """Composition, ordering, iteration and validation of TapSchedule."""

    def test_pipe_composition(self) -> None:
        """Joining three single-entry schedules with | yields three entries."""
        first = TapPosition(a=1.0, b=1.0, c=1.0).at(t=0)
        second = TapPosition(a=1.1).at(t=100)
        third = TapPosition(a=1.05, c=1.05).at(t=200)
        combined = first | second | third
        assert len(combined) == 3

    def test_sorted_by_time(self) -> None:
        """Entries come back in ascending time order whatever the pipe order."""
        combined = TapPosition(a=1.1).at(t=200) | TapPosition(a=1.0).at(t=0)
        assert [when for when, _tap in combined] == [0, 200]

    def test_iteration(self) -> None:
        """Iteration produces (time, TapPosition) pairs."""
        entries = list(TapPosition(a=1.0, b=1.0, c=1.0).at(t=50))
        assert len(entries) == 1
        when, tap = entries[0]
        assert when == 50.0
        assert tap.a == 1.0

    def test_bool_empty(self) -> None:
        """A schedule with no entries evaluates as False."""
        assert not TapSchedule(())

    def test_bool_nonempty(self) -> None:
        """A schedule with at least one entry evaluates as True."""
        assert TapPosition(a=1.0).at(t=0)

    def test_repr_partial(self) -> None:
        """repr lists only the phases that were actually specified."""
        text = repr(TapPosition(a=1.1).at(t=100))
        assert "a=1.1" in text
        for absent in ("b=", "c="):
            assert absent not in text

    def test_duplicate_timestamps_raises(self) -> None:
        """Piping two entries that share a time raises ValueError."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            TapPosition(a=1.0).at(t=0) | TapPosition(a=1.1).at(t=0)

    def test_duplicate_timestamps_three_entries(self) -> None:
        """A clash between the first and third of three entries raises ValueError."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            (
                TapPosition(a=1.0).at(t=0)
                | TapPosition(a=1.1).at(t=100)
                | TapPosition(a=1.05).at(t=0)
            )

    def test_duplicate_timestamps_direct_construction(self) -> None:
        """Passing duplicate times straight to TapSchedule(...) also raises."""
        entries = ((0.0, TapPosition(a=1.0)), (0.0, TapPosition(a=1.1)))
        with pytest.raises(ValueError, match="duplicate timestamps"):
            TapSchedule(entries)
class TestInferenceRamp:
    """Validation and scheduling helpers on InferenceRamp."""

    def test_basic(self) -> None:
        """The target fraction passed in is exposed unchanged."""
        assert InferenceRamp(target=0.5).target == 0.5

    def test_invalid_target_low(self) -> None:
        """A target below zero is rejected with ValueError."""
        with pytest.raises(ValueError, match="target must be in"):
            InferenceRamp(target=-0.1)

    def test_invalid_target_high(self) -> None:
        """A target greater than one is rejected with ValueError."""
        with pytest.raises(ValueError, match="target must be in"):
            InferenceRamp(target=1.5)

    def test_at_returns_schedule(self) -> None:
        """.at(...) produces an InferenceRampSchedule containing a single ramp."""
        schedule = InferenceRamp(target=0.5).at(t_start=1000, t_end=2000)
        assert isinstance(schedule, InferenceRampSchedule)
        assert len(schedule) == 1

    def test_at_invalid_time_order(self) -> None:
        """An end time earlier than the start time is rejected with ValueError."""
        ramp = InferenceRamp(target=0.5)
        with pytest.raises(ValueError, match=r"t_end.*must be >= t_start"):
            ramp.at(t_start=2000, t_end=1000)

    def test_pipe_creates_schedule(self) -> None:
        """| on two scheduled ramps produces a two-entry InferenceRampSchedule."""
        down = InferenceRamp(target=0.5).at(t_start=100, t_end=200)
        up = InferenceRamp(target=1.0).at(t_start=300, t_end=400)
        schedule = down | up
        assert isinstance(schedule, InferenceRampSchedule)
        assert len(schedule) == 2

    def test_pipe_three(self) -> None:
        """Chaining three ramps with | keeps every entry."""
        pieces = [
            InferenceRamp(target=0.5).at(t_start=100, t_end=200),
            InferenceRamp(target=1.0).at(t_start=300, t_end=400),
            InferenceRamp(target=0.3).at(t_start=500, t_end=600),
        ]
        schedule = pieces[0] | pieces[1] | pieces[2]
        assert isinstance(schedule, InferenceRampSchedule)
        assert len(schedule) == 3
class TestInferenceRampSchedule:
    """Piecewise-linear evaluation of InferenceRampSchedule.fraction_at."""

    def test_fraction_before_first_ramp(self) -> None:
        """The active fraction is 1.0 until the first ramp begins."""
        schedule = InferenceRamp(target=0.5).at(t_start=1000, t_end=2000)
        for t in (0.0, 999.0):
            assert schedule.fraction_at(t) == 1.0

    def test_fraction_during_ramp(self) -> None:
        """Within a ramp the fraction moves linearly from the prior level
        towards the target."""
        schedule = InferenceRamp(target=0.0).at(t_start=1000, t_end=2000)
        assert schedule.fraction_at(1000.0) == 1.0
        for t, expected in ((1500.0, 0.5), (2000.0, 0.0)):
            assert schedule.fraction_at(t) == pytest.approx(expected)

    def test_fraction_after_ramp(self) -> None:
        """Once a ramp has finished, the fraction stays at its target."""
        schedule = InferenceRamp(target=0.2).at(t_start=1000, t_end=2000)
        assert schedule.fraction_at(3000.0) == pytest.approx(0.2)

    def test_two_ramps(self) -> None:
        """Ramp down to 0.2, hold, then ramp back up to 1.0; check each phase,
        including the plateau between the two ramps."""
        down = InferenceRamp(target=0.2).at(t_start=1000, t_end=2000)
        up = InferenceRamp(target=1.0).at(t_start=3000, t_end=3500)
        schedule = down | up
        assert schedule.fraction_at(0.0) == 1.0
        checkpoints = {1500.0: 0.6, 2500.0: 0.2, 3250.0: 0.6, 4000.0: 1.0}
        for t, expected in checkpoints.items():
            assert schedule.fraction_at(t) == pytest.approx(expected)

    def test_instant_ramp(self) -> None:
        """t_start == t_end acts as an immediate step to the target."""
        schedule = InferenceRamp(target=0.5).at(t_start=1000, t_end=1000)
        assert schedule.fraction_at(999.0) == 1.0
        for t in (1000.0, 1001.0):
            assert schedule.fraction_at(t) == 0.5

    def test_fraction_array(self) -> None:
        """fraction_at evaluates element-wise over a numpy array of times."""
        schedule = InferenceRamp(target=0.0).at(t_start=1000, t_end=2000)
        times = np.array([0.0, 1000.0, 1500.0, 2000.0, 3000.0])
        np.testing.assert_allclose(
            schedule.fraction_at(times),
            np.array([1.0, 1.0, 0.5, 0.0, 0.0]),
        )

    def test_sorted_by_start(self) -> None:
        """Iteration follows t_start order even when ramps were piped in reverse."""
        late = InferenceRamp(target=1.0).at(t_start=3000, t_end=3500)
        early = InferenceRamp(target=0.2).at(t_start=1000, t_end=2000)
        schedule = late | early
        assert [start for _ramp, start, _end in schedule] == [1000, 3000]
class TestTrainingRun:
    """Validation, defaults and scheduling of TrainingRun."""

    def test_basic(self) -> None:
        """TrainingRun should store its GPU count and trace."""
        r = TrainingRun(n_gpus=2400, trace=_DUMMY_TRACE)
        assert r.n_gpus == 2400
        # Identity rather than ==: the trace holds numpy arrays, whose == is element-wise.
        assert r.trace is _DUMMY_TRACE

    def test_negative_gpus(self) -> None:
        """Negative GPU count should raise ValueError."""
        with pytest.raises(ValueError, match="n_gpus must be > 0"):
            TrainingRun(n_gpus=-1, trace=_DUMMY_TRACE)

    def test_zero_gpus(self) -> None:
        """Zero GPUs is the boundary case: the validation message requires n_gpus > 0."""
        with pytest.raises(ValueError, match="n_gpus must be > 0"):
            TrainingRun(n_gpus=0, trace=_DUMMY_TRACE)

    def test_default_target_peak(self) -> None:
        """target_peak_W_per_gpu should default to 400.0."""
        r = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE)
        assert r.target_peak_W_per_gpu == 400.0

    def test_at_returns_schedule(self) -> None:
        """Calling .at() should wrap the run in a single-entry TrainingSchedule."""
        s = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        assert isinstance(s, TrainingSchedule)
        assert len(s) == 1

    def test_at_invalid_time_order(self) -> None:
        """t_end before t_start should raise ValueError."""
        with pytest.raises(ValueError, match=r"t_end.*must be >= t_start"):
            TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=200, t_end=100)

    def test_scheduled_fields(self) -> None:
        """Schedule entries should expose (run, t_start, t_end) tuples, keeping the run by identity."""
        r = TrainingRun(n_gpus=2400, trace=_DUMMY_TRACE, target_peak_W_per_gpu=500.0)
        run, t_start, t_end = next(iter(r.at(t_start=100, t_end=200)))
        assert run is r
        assert run.n_gpus == 2400
        assert run.target_peak_W_per_gpu == 500.0
        assert t_start == 100
        assert t_end == 200

    def test_pipe_creates_schedule(self) -> None:
        """Piping two scheduled training runs should produce a TrainingSchedule."""
        first = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        second = TrainingRun(n_gpus=50, trace=_DUMMY_TRACE).at(t_start=200, t_end=300)
        s = first | second
        assert isinstance(s, TrainingSchedule)
        assert len(s) == 2
class TestTrainingSchedule:
    """Ordering, composition, truthiness and repr of TrainingSchedule."""

    def test_sorted_by_start(self) -> None:
        """Runs are ordered by t_start even when piped latest-first."""
        later = TrainingRun(n_gpus=50, trace=_DUMMY_TRACE).at(t_start=200, t_end=300)
        earlier = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        schedule = later | earlier
        assert [start for _run, start, _end in schedule] == [0, 200]

    def test_pipe_three(self) -> None:
        """Three scheduled runs joined with | give a three-entry schedule."""
        windows = [(100, 0, 100), (50, 200, 300), (25, 400, 500)]
        parts = [
            TrainingRun(n_gpus=gpus, trace=_DUMMY_TRACE).at(t_start=start, t_end=end)
            for gpus, start, end in windows
        ]
        schedule = parts[0] | parts[1] | parts[2]
        assert len(schedule) == 3

    def test_bool(self) -> None:
        """An empty schedule is falsy; a schedule holding a run is truthy."""
        assert not TrainingSchedule()
        assert TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)

    def test_repr(self) -> None:
        """The repr mentions TrainingRun so schedules are easy to debug."""
        schedule = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        assert "TrainingRun" in repr(schedule)
class TestGridState:
    """The optional tap-position field on GridState."""

    def test_tap_positions_default_none(self) -> None:
        """tap_positions is None when the caller does not provide it."""
        voltages = BusVoltages({"671": PhaseVoltages(a=1.0, b=1.0, c=1.0)})
        state = GridState(time_s=0.0, voltages=voltages)
        assert state.tap_positions is None

    def test_tap_positions_populated(self) -> None:
        """A supplied TapPosition is stored as the same object and readable back."""
        tap = TapPosition(a=1.0875, b=1.0375, c=1.09375)
        voltages = BusVoltages({"671": PhaseVoltages(a=0.98, b=0.99, c=1.01)})
        state = GridState(time_s=1.0, voltages=voltages, tap_positions=tap)
        assert state.tap_positions is tap
        assert state.tap_positions.a == 1.0875
class TestOnlineDatacenterState:
    """Construction and default values of OnlineDatacenterState."""

    def test_construction_with_all_fields(self) -> None:
        """Every field passed to the constructor should be readable back unchanged."""
        state = OnlineDatacenterState(
            time_s=0.5,
            power_w=ThreePhase(a=500e3, b=500e3, c=500e3),
            batch_size_by_model={"8B": 128},
            active_replicas_by_model={"8B": 4},
            observed_itl_s_by_model={"8B": 0.05},
            measured_power_w=ThreePhase(a=50e3, b=50e3, c=50e3),
            measured_power_w_by_model={"8B": 120e3},
            augmented_power_w_by_model={"8B": 1200e3},
            augmentation_factor_by_model={"8B": 10.0},
        )
        assert state.time_s == 0.5
        # Per-phase values are exactly representable, so exact sums are safe here.
        assert state.power_w.a + state.power_w.b + state.power_w.c == 1500e3
        assert state.batch_size_by_model == {"8B": 128}
        assert state.active_replicas_by_model == {"8B": 4}
        assert state.observed_itl_s_by_model == {"8B": 0.05}
        assert state.measured_power_w.a + state.measured_power_w.b + state.measured_power_w.c == 150e3
        assert state.augmented_power_w_by_model["8B"] == 1200e3
        assert state.measured_power_w_by_model["8B"] == 120e3
        assert state.augmentation_factor_by_model["8B"] == 10.0

    def test_defaults(self) -> None:
        """Omitted measurement/augmentation fields default to zero power and empty dicts."""
        state = OnlineDatacenterState(
            time_s=0.0,
            power_w=ThreePhase(a=0.0, b=0.0, c=0.0),
        )
        assert state.measured_power_w.a + state.measured_power_w.b + state.measured_power_w.c == 0.0
        assert state.measured_power_w_by_model == {}
        assert state.augmented_power_w_by_model == {}
        assert state.augmentation_factor_by_model == {}
class TestRampActivationPolicy:
    """Server activation masks from RampActivationPolicy, plus the
    ActivationPolicy base-class contract for custom subclasses."""

    def test_all_active_before_ramp(self) -> None:
        """Before the first ramp, all servers should be active."""
        schedule = InferenceRamp(target=0.5).at(t_start=1000, t_end=2000)
        policy = RampActivationPolicy(schedule, num_servers=10, rng=np.random.default_rng(42))
        mask = policy.active_mask(0.0)
        assert mask.sum() == 10

    def test_ramp_down(self) -> None:
        """After the ramp completes, only the target fraction of servers should be active."""
        schedule = InferenceRamp(target=0.5).at(t_start=0, t_end=0)
        policy = RampActivationPolicy(schedule, num_servers=10, rng=np.random.default_rng(42))
        mask = policy.active_mask(1.0)
        assert mask.sum() == 5

    def test_ramp_up(self) -> None:
        """Ramp up: step to 0.2 at t=0, then step to 1.0 at t=100."""
        schedule = InferenceRamp(target=0.2).at(t_start=0, t_end=0) | InferenceRamp(target=1.0).at(
            t_start=100, t_end=100
        )
        policy = RampActivationPolicy(schedule, num_servers=10, rng=np.random.default_rng(42))
        mask_low = policy.active_mask(50.0)
        mask_high = policy.active_mask(200.0)
        assert mask_low.sum() == 2
        assert mask_high.sum() == 10

    def test_ramp_up_servers_superset(self) -> None:
        """Servers active at the low fraction should all stay active at the high
        fraction, i.e. the high-fraction mask is a superset of the low one."""
        schedule = InferenceRamp(target=0.3).at(t_start=0, t_end=0) | InferenceRamp(target=0.7).at(
            t_start=100, t_end=100
        )
        policy = RampActivationPolicy(schedule, num_servers=10, rng=np.random.default_rng(42))
        mask_low = policy.active_mask(50.0)
        mask_high = policy.active_mask(200.0)
        # Guard against a vacuous pass: np.all over an empty selection is True.
        assert mask_low.any()
        assert mask_high.sum() > mask_low.sum()
        assert np.all(mask_high[mask_low])

    def test_deterministic_with_same_seed(self) -> None:
        """The same seed should produce the same activation mask."""
        schedule = InferenceRamp(target=0.5).at(t_start=0, t_end=0)
        p1 = RampActivationPolicy(schedule, num_servers=20, rng=np.random.default_rng(0))
        p2 = RampActivationPolicy(schedule, num_servers=20, rng=np.random.default_rng(0))
        np.testing.assert_array_equal(p1.active_mask(1.0), p2.active_mask(1.0))

    def test_custom_policy_subclass(self) -> None:
        """A custom ActivationPolicy subclass that defines only active_mask should
        get a working active_indices from the base class."""

        class _AllOnPolicy(ActivationPolicy):
            def __init__(self, n: int) -> None:
                self._n = n

            def active_mask(self, t: float) -> np.ndarray:
                return np.ones(self._n, dtype=bool)

        policy = _AllOnPolicy(10)
        mask = policy.active_mask(0.0)
        assert mask.sum() == 10
        assert policy.active_indices(0.0).tolist() == list(range(10))