"""Tests for config/schedule and state types: TapPosition, TapSchedule,
InferenceRamp, InferenceRampSchedule, TrainingRun, TrainingSchedule,
GridState, OnlineDatacenterState, ActivationPolicy, RampActivationPolicy."""
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import pytest |
|
|
| from openg2g.common import ThreePhase |
| from openg2g.datacenter.config import ( |
| InferenceRamp, |
| InferenceRampSchedule, |
| TrainingRun, |
| TrainingSchedule, |
| ) |
| from openg2g.datacenter.layout import ActivationPolicy, RampActivationPolicy |
| from openg2g.datacenter.online import OnlineDatacenterState |
| from openg2g.datacenter.workloads.training import TrainingTrace |
| from openg2g.grid.base import BusVoltages, GridState, PhaseVoltages |
| from openg2g.grid.config import TapPosition, TapSchedule |
|
|
# Two-sample placeholder trace shared by the TrainingRun/TrainingSchedule
# tests; the tests visible in this file never inspect its values.
_DUMMY_TRACE = TrainingTrace(
    t_s=np.array([0.0, 1.0]),
    power_w=np.array([100.0, 200.0]),
)
|
|
|
|
class TestTapPosition:
    """Checks construction, validation and scheduling of ``TapPosition``."""

    def test_full_three_phase(self) -> None:
        """A position built with A, B and C exposes each value unchanged."""
        tap = TapPosition(a=1.0, b=1.05, c=1.1)
        assert (tap.a, tap.b, tap.c) == (1.0, 1.05, 1.1)

    def test_partial_a_only(self) -> None:
        """Giving just phase A leaves phases B and C as ``None``."""
        tap = TapPosition(a=1.1)
        assert tap.a == 1.1
        assert tap.b is None and tap.c is None

    def test_partial_a_and_c(self) -> None:
        """Giving A and C while omitting B leaves only B as ``None``."""
        setting = 1.0625
        tap = TapPosition(a=setting, c=setting)
        assert tap.b is None
        assert tap.a == setting
        assert tap.c == setting

    def test_no_phase_raises(self) -> None:
        """A ``TapPosition`` with no phase argument is rejected with ValueError."""
        expected_failure = pytest.raises(ValueError, match="at least one phase")
        with expected_failure:
            TapPosition()

    def test_at_returns_schedule(self) -> None:
        """``.at(t=...)`` yields a ``TapSchedule`` holding exactly one entry."""
        schedule = TapPosition(a=1.0, b=1.0, c=1.0).at(t=100.0)
        assert isinstance(schedule, TapSchedule)
        assert len(schedule) == 1
|
|
|
|
class TestTapSchedule:
    """Checks ``|`` composition, ordering, container protocol and validation of ``TapSchedule``."""

    def test_pipe_composition(self) -> None:
        """Joining three single-entry schedules with ``|`` gives three entries."""
        combined = TapPosition(a=1.0, b=1.0, c=1.0).at(t=0)
        combined = combined | TapPosition(a=1.1).at(t=100)
        combined = combined | TapPosition(a=1.05, c=1.05).at(t=200)
        assert len(combined) == 3

    def test_sorted_by_time(self) -> None:
        """Iteration order follows the timestamps, not the order of the ``|`` operands."""
        later = TapPosition(a=1.1).at(t=200)
        earlier = TapPosition(a=1.0).at(t=0)
        assert [when for when, _ in later | earlier] == [0, 200]

    def test_iteration(self) -> None:
        """Iterating a schedule produces ``(time, TapPosition)`` pairs."""
        pairs = list(TapPosition(a=1.0, b=1.0, c=1.0).at(t=50))
        assert len(pairs) == 1
        when, position = pairs[0]
        assert when == 50.0
        assert position.a == 1.0

    def test_bool_empty(self) -> None:
        """A schedule built from no entries evaluates as false."""
        assert not TapSchedule(())

    def test_bool_nonempty(self) -> None:
        """A schedule holding one entry evaluates as true."""
        assert TapPosition(a=1.0).at(t=0)

    def test_repr_partial(self) -> None:
        """The repr mentions phase A but not the unspecified phases B and C."""
        text = repr(TapPosition(a=1.1).at(t=100))
        assert "a=1.1" in text
        for unset_phase in ("b=", "c="):
            assert unset_phase not in text

    def test_duplicate_timestamps_raises(self) -> None:
        """Merging two entries that share t=0 raises ValueError."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            _ = TapPosition(a=1.0).at(t=0) | TapPosition(a=1.1).at(t=0)

    def test_duplicate_timestamps_three_entries(self) -> None:
        """A clash between the first and third of three entries still raises ValueError."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            _ = (
                TapPosition(a=1.0).at(t=0)
                | TapPosition(a=1.1).at(t=100)
                | TapPosition(a=1.05).at(t=0)
            )

    def test_duplicate_timestamps_direct_construction(self) -> None:
        """Passing clashing entries straight to ``TapSchedule`` raises as well."""
        clashing_entries = (
            (0.0, TapPosition(a=1.0)),
            (0.0, TapPosition(a=1.1)),
        )
        with pytest.raises(ValueError, match="duplicate timestamps"):
            TapSchedule(clashing_entries)
|
|
|
|
class TestInferenceRamp:
    """Checks ``InferenceRamp`` validation and how ramps turn into schedules."""

    def test_basic(self) -> None:
        """The ``target`` passed to the constructor is readable back."""
        ramp = InferenceRamp(target=0.5)
        assert ramp.target == 0.5

    def test_invalid_target_low(self) -> None:
        """A target below zero is rejected with ValueError."""
        with pytest.raises(ValueError, match="target must be in"):
            InferenceRamp(target=-0.1)

    def test_invalid_target_high(self) -> None:
        """A target greater than 1.0 is rejected with ValueError."""
        with pytest.raises(ValueError, match="target must be in"):
            InferenceRamp(target=1.5)

    def test_at_returns_schedule(self) -> None:
        """``.at()`` yields an ``InferenceRampSchedule`` with one entry."""
        ramp = InferenceRamp(target=0.5)
        schedule = ramp.at(t_start=1000, t_end=2000)
        assert isinstance(schedule, InferenceRampSchedule)
        assert len(schedule) == 1

    def test_at_invalid_time_order(self) -> None:
        """Scheduling with ``t_end`` earlier than ``t_start`` raises ValueError."""
        ramp = InferenceRamp(target=0.5)
        with pytest.raises(ValueError, match=r"t_end.*must be >= t_start"):
            ramp.at(t_start=2000, t_end=1000)

    def test_pipe_creates_schedule(self) -> None:
        """``|`` on two scheduled ramps gives a two-entry ``InferenceRampSchedule``."""
        first = InferenceRamp(target=0.5).at(t_start=100, t_end=200)
        second = InferenceRamp(target=1.0).at(t_start=300, t_end=400)
        merged = first | second
        assert isinstance(merged, InferenceRampSchedule)
        assert len(merged) == 2

    def test_pipe_three(self) -> None:
        """Chaining ``|`` across three scheduled ramps keeps all three entries."""
        merged = InferenceRamp(target=0.5).at(t_start=100, t_end=200)
        merged = merged | InferenceRamp(target=1.0).at(t_start=300, t_end=400)
        merged = merged | InferenceRamp(target=0.3).at(t_start=500, t_end=600)
        assert isinstance(merged, InferenceRampSchedule)
        assert len(merged) == 3
|
|
|
|
class TestInferenceRampSchedule:
    """Checks ``fraction_at`` before, during and after ramps, and entry ordering."""

    def test_fraction_before_first_ramp(self) -> None:
        """Any time earlier than the first ramp's start reports a fraction of 1.0."""
        sched = InferenceRamp(target=0.5).at(t_start=1000, t_end=2000)
        for t in (0.0, 999.0):
            assert sched.fraction_at(t) == 1.0

    def test_fraction_during_ramp(self) -> None:
        """Inside a ramp from 1.0 to 0.0 the fraction is 1.0 at the start,
        about 0.5 at the midpoint and about 0.0 at the end."""
        sched = InferenceRamp(target=0.0).at(t_start=1000, t_end=2000)
        assert sched.fraction_at(1000.0) == 1.0
        for t, expected in ((1500.0, 0.5), (2000.0, 0.0)):
            assert sched.fraction_at(t) == pytest.approx(expected)

    def test_fraction_after_ramp(self) -> None:
        """Once the ramp has ended the fraction stays at the ramp's target."""
        sched = InferenceRamp(target=0.2).at(t_start=1000, t_end=2000)
        assert sched.fraction_at(3000.0) == pytest.approx(0.2)

    def test_two_ramps(self) -> None:
        """A ramp down to 0.2 followed by a ramp back up to 1.0: the value is
        held at 0.2 in the gap between the two ramps."""
        ramp_down = InferenceRamp(target=0.2).at(t_start=1000, t_end=2000)
        ramp_up = InferenceRamp(target=1.0).at(t_start=3000, t_end=3500)
        sched = ramp_down | ramp_up
        assert sched.fraction_at(0.0) == 1.0
        for t, expected in ((1500.0, 0.6), (2500.0, 0.2), (3250.0, 0.6), (4000.0, 1.0)):
            assert sched.fraction_at(t) == pytest.approx(expected)

    def test_instant_ramp(self) -> None:
        """With ``t_start == t_end`` the fraction steps straight to the target."""
        step = InferenceRamp(target=0.5).at(t_start=1000, t_end=1000)
        assert step.fraction_at(999.0) == 1.0
        for t in (1000.0, 1001.0):
            assert step.fraction_at(t) == 0.5

    def test_fraction_array(self) -> None:
        """``fraction_at`` given a numpy array returns one value per sample time."""
        sched = InferenceRamp(target=0.0).at(t_start=1000, t_end=2000)
        sample_times = np.array([0.0, 1000.0, 1500.0, 2000.0, 3000.0])
        np.testing.assert_allclose(
            sched.fraction_at(sample_times),
            np.array([1.0, 1.0, 0.5, 0.0, 0.0]),
        )

    def test_sorted_by_start(self) -> None:
        """Entries iterate in ``t_start`` order even when piped latest-first."""
        late = InferenceRamp(target=1.0).at(t_start=3000, t_end=3500)
        early = InferenceRamp(target=0.2).at(t_start=1000, t_end=2000)
        assert [begin for _, begin, _ in late | early] == [1000, 3000]
|
|
|
|
class TestTrainingRun:
    """Checks ``TrainingRun`` fields, validation and conversion into schedules."""

    def test_basic(self) -> None:
        """The GPU count given at construction is readable back."""
        run = TrainingRun(n_gpus=2400, trace=_DUMMY_TRACE)
        assert run.n_gpus == 2400

    def test_negative_gpus(self) -> None:
        """A GPU count of -1 is rejected with ValueError."""
        with pytest.raises(ValueError, match="n_gpus must be > 0"):
            TrainingRun(n_gpus=-1, trace=_DUMMY_TRACE)

    def test_default_target_peak(self) -> None:
        """Leaving ``target_peak_W_per_gpu`` unspecified gives 400.0."""
        run = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE)
        assert run.target_peak_W_per_gpu == 400.0

    def test_at_returns_schedule(self) -> None:
        """``.at()`` yields a ``TrainingSchedule`` with one entry."""
        run = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE)
        schedule = run.at(t_start=0, t_end=100)
        assert isinstance(schedule, TrainingSchedule)
        assert len(schedule) == 1

    def test_at_invalid_time_order(self) -> None:
        """Scheduling with ``t_end`` earlier than ``t_start`` raises ValueError."""
        run = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE)
        with pytest.raises(ValueError, match=r"t_end.*must be >= t_start"):
            run.at(t_start=200, t_end=100)

    def test_scheduled_fields(self) -> None:
        """A schedule entry unpacks as ``(run, t_start, t_end)`` and carries the same run object."""
        original = TrainingRun(n_gpus=2400, trace=_DUMMY_TRACE, target_peak_W_per_gpu=500.0)
        schedule = original.at(t_start=100, t_end=200)
        run, t_start, t_end = next(iter(schedule))
        assert run is original
        assert run.n_gpus == 2400
        assert run.target_peak_W_per_gpu == 500.0
        assert (t_start, t_end) == (100, 200)

    def test_pipe_creates_schedule(self) -> None:
        """``|`` on two scheduled runs gives a two-entry ``TrainingSchedule``."""
        first = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        second = TrainingRun(n_gpus=50, trace=_DUMMY_TRACE).at(t_start=200, t_end=300)
        merged = first | second
        assert isinstance(merged, TrainingSchedule)
        assert len(merged) == 2
|
|
|
|
class TestTrainingSchedule:
    """Checks ordering, length, truthiness and repr of ``TrainingSchedule``."""

    def test_sorted_by_start(self) -> None:
        """Entries iterate in ``t_start`` order even when piped latest-first."""
        late = TrainingRun(n_gpus=50, trace=_DUMMY_TRACE).at(t_start=200, t_end=300)
        early = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        assert [begin for _, begin, _ in late | early] == [0, 200]

    def test_pipe_three(self) -> None:
        """Chaining ``|`` across three scheduled runs keeps all three entries."""
        merged = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        merged = merged | TrainingRun(n_gpus=50, trace=_DUMMY_TRACE).at(t_start=200, t_end=300)
        merged = merged | TrainingRun(n_gpus=25, trace=_DUMMY_TRACE).at(t_start=400, t_end=500)
        assert len(merged) == 3

    def test_bool(self) -> None:
        """A default-constructed schedule is falsy; one with an entry is truthy."""
        empty = TrainingSchedule()
        assert not empty
        populated = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        assert populated

    def test_repr(self) -> None:
        """The schedule's repr contains the text 'TrainingRun'."""
        schedule = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        text = repr(schedule)
        assert "TrainingRun" in text
|
|
|
|
class TestGridState:
    """Checks the ``tap_positions`` field of ``GridState``."""

    def test_tap_positions_default_none(self) -> None:
        """A ``GridState`` built without ``tap_positions`` reports ``None`` for it."""
        voltages = BusVoltages({"671": PhaseVoltages(a=1.0, b=1.0, c=1.0)})
        state = GridState(time_s=0.0, voltages=voltages)
        assert state.tap_positions is None

    def test_tap_positions_populated(self) -> None:
        """A ``TapPosition`` passed as ``tap_positions`` is stored as the same object."""
        taps = TapPosition(a=1.0875, b=1.0375, c=1.09375)
        voltages = BusVoltages({"671": PhaseVoltages(a=0.98, b=0.99, c=1.01)})
        state = GridState(time_s=1.0, voltages=voltages, tap_positions=taps)
        stored = state.tap_positions
        assert stored is taps
        assert state.tap_positions.a == 1.0875
|
|
|
|
class TestOnlineDatacenterState:
    """Checks explicit and default field values of ``OnlineDatacenterState``."""

    def test_construction_with_all_fields(self) -> None:
        """Power-related fields passed to the constructor are readable back."""
        model = "8B"
        state = OnlineDatacenterState(
            time_s=0.5,
            power_w=ThreePhase(a=500e3, b=500e3, c=500e3),
            batch_size_by_model={model: 128},
            active_replicas_by_model={model: 4},
            observed_itl_s_by_model={model: 0.05},
            measured_power_w=ThreePhase(a=50e3, b=50e3, c=50e3),
            measured_power_w_by_model={model: 120e3},
            augmented_power_w_by_model={model: 1200e3},
            augmentation_factor_by_model={model: 10.0},
        )

        # Per-phase values are compared through their three-phase sum.
        total = state.power_w
        assert total.a + total.b + total.c == 1500e3
        measured = state.measured_power_w
        assert measured.a + measured.b + measured.c == 150e3

        assert state.augmented_power_w_by_model[model] == 1200e3
        assert state.measured_power_w_by_model[model] == 120e3
        assert state.augmentation_factor_by_model[model] == 10.0

    def test_defaults(self) -> None:
        """With only ``time_s`` and ``power_w`` given, measured power sums to
        zero and the per-model mappings compare equal to an empty dict."""
        state = OnlineDatacenterState(
            time_s=0.0,
            power_w=ThreePhase(a=0.0, b=0.0, c=0.0),
        )
        measured = state.measured_power_w
        assert measured.a + measured.b + measured.c == 0.0
        per_model_fields = (
            state.measured_power_w_by_model,
            state.augmented_power_w_by_model,
            state.augmentation_factor_by_model,
        )
        for mapping in per_model_fields:
            assert mapping == {}
|
|
|
|
class TestRampActivationPolicy:
    """Checks ``RampActivationPolicy`` masks driven by an inference ramp
    schedule, plus a minimal custom ``ActivationPolicy`` subclass."""

    def test_all_active_before_ramp(self) -> None:
        """Before the first ramp, all servers should be active."""
        schedule = InferenceRamp(target=0.5).at(t_start=1000, t_end=2000)
        policy = RampActivationPolicy(schedule, num_servers=10, rng=np.random.default_rng(42))
        mask = policy.active_mask(0.0)
        assert mask.sum() == 10

    def test_ramp_down(self) -> None:
        """After ramp completes, only the target fraction of servers should be active."""
        schedule = InferenceRamp(target=0.5).at(t_start=0, t_end=0)
        policy = RampActivationPolicy(schedule, num_servers=10, rng=np.random.default_rng(42))
        mask = policy.active_mask(1.0)
        assert mask.sum() == 5

    def test_ramp_up(self) -> None:
        """Ramp up: start at 0.2, ramp to 1.0."""
        schedule = InferenceRamp(target=0.2).at(t_start=0, t_end=0) | InferenceRamp(target=1.0).at(
            t_start=100, t_end=100
        )
        policy = RampActivationPolicy(schedule, num_servers=10, rng=np.random.default_rng(42))
        mask_low = policy.active_mask(50.0)
        mask_high = policy.active_mask(200.0)
        assert mask_low.sum() == 2
        assert mask_high.sum() == 10

    def test_ramp_up_servers_superset(self) -> None:
        """Servers active at low fraction should be a subset of those active at high fraction."""
        schedule = InferenceRamp(target=0.3).at(t_start=0, t_end=0) | InferenceRamp(target=0.7).at(
            t_start=100, t_end=100
        )
        policy = RampActivationPolicy(schedule, num_servers=10, rng=np.random.default_rng(42))
        mask_low = policy.active_mask(50.0)
        mask_high = policy.active_mask(200.0)
        # Guard against a vacuous pass: np.all() over an empty selection is
        # True, so an all-False low mask would otherwise satisfy the check.
        assert mask_low.any()
        assert np.all(mask_high[mask_low])

    def test_deterministic_with_same_seed(self) -> None:
        """Same seed should produce the same activation mask."""
        schedule = InferenceRamp(target=0.5).at(t_start=0, t_end=0)
        p1 = RampActivationPolicy(schedule, num_servers=20, rng=np.random.default_rng(0))
        p2 = RampActivationPolicy(schedule, num_servers=20, rng=np.random.default_rng(0))
        np.testing.assert_array_equal(p1.active_mask(1.0), p2.active_mask(1.0))

    def test_custom_policy_subclass(self) -> None:
        """A subclass that defines only ``active_mask`` should be usable, and
        the ``active_indices`` it inherits should list every index when the
        mask is all True."""

        class _AllOnPolicy(ActivationPolicy):
            """Policy that reports all ``n`` servers as active at any time."""

            def __init__(self, n: int) -> None:
                self._n = n

            def active_mask(self, t: float) -> np.ndarray:
                # The time argument is ignored: the mask is always all True.
                return np.ones(self._n, dtype=bool)

        policy = _AllOnPolicy(10)
        mask = policy.active_mask(0.0)
        assert mask.sum() == 10
        assert policy.active_indices(0.0).tolist() == list(range(10))
|
|