live / tests /test_types.py
github-actions[bot]
deploy: sync from GitHub 2026-04-18T00:48:45Z
96bb363
"""Tests for config/schedule types: TapPosition, TapSchedule, InferenceRamp,
InferenceRampSchedule, ActivationPolicy, TrainingRun, TrainingSchedule."""
from __future__ import annotations
import numpy as np
import pytest
from openg2g.common import ThreePhase
from openg2g.datacenter.config import (
InferenceRamp,
InferenceRampSchedule,
TrainingRun,
TrainingSchedule,
)
from openg2g.datacenter.layout import ActivationPolicy, RampActivationPolicy
from openg2g.datacenter.online import OnlineDatacenterState
from openg2g.datacenter.workloads.training import TrainingTrace
from openg2g.grid.base import BusVoltages, GridState, PhaseVoltages
from openg2g.grid.config import TapPosition, TapSchedule
# Minimal two-sample training trace reused by the TrainingRun / TrainingSchedule
# tests; those tests only need *a* valid trace object and never inspect its values.
_DUMMY_TRACE = TrainingTrace(t_s=np.array([0.0, 1.0]), power_w=np.array([100.0, 200.0]))
class TestTapPosition:
    """Construction, validation, and scheduling of a single ``TapPosition``."""

    def test_full_three_phase(self) -> None:
        """A position built with a, b and c keeps every value unchanged."""
        position = TapPosition(a=1.0, b=1.05, c=1.1)

        assert (position.a, position.b, position.c) == (1.0, 1.05, 1.1)

    def test_partial_a_only(self) -> None:
        """Only phase A given: B and C must be reported as None."""
        position = TapPosition(a=1.1)

        assert position.a == 1.1
        assert position.b is None and position.c is None

    def test_partial_a_and_c(self) -> None:
        """Phases A and C given: the omitted phase B must be None."""
        position = TapPosition(a=1.0625, c=1.0625)

        assert position.b is None
        assert (position.a, position.c) == (1.0625, 1.0625)

    def test_no_phase_raises(self) -> None:
        """A position without any phase is rejected with ValueError."""
        with pytest.raises(ValueError, match="at least one phase"):
            TapPosition()

    def test_at_returns_schedule(self) -> None:
        """``.at(t)`` yields a TapSchedule holding exactly one entry."""
        schedule = TapPosition(a=1.0, b=1.0, c=1.0).at(t=100.0)

        assert isinstance(schedule, TapSchedule)
        assert len(schedule) == 1
class TestTapSchedule:
    """Composition, ordering, truthiness, repr and validation of ``TapSchedule``."""

    def test_pipe_composition(self) -> None:
        """Joining three single-entry schedules with ``|`` gives three entries."""
        first = TapPosition(a=1.0, b=1.0, c=1.0).at(t=0)
        second = TapPosition(a=1.1).at(t=100)
        third = TapPosition(a=1.05, c=1.05).at(t=200)

        combined = first | second | third

        assert len(combined) == 3

    def test_sorted_by_time(self) -> None:
        """Entries come out in ascending time even when piped out of order."""
        later = TapPosition(a=1.1).at(t=200)
        earlier = TapPosition(a=1.0).at(t=0)

        schedule = later | earlier

        assert [when for when, _position in schedule] == [0, 200]

    def test_iteration(self) -> None:
        """Iteration produces ``(time, TapPosition)`` pairs."""
        schedule = TapPosition(a=1.0, b=1.0, c=1.0).at(t=50)

        items = list(schedule)

        assert len(items) == 1
        when, position = items[0]
        assert when == 50.0
        assert position.a == 1.0

    def test_bool_empty(self) -> None:
        """A schedule with no entries evaluates as False."""
        assert not TapSchedule(())

    def test_bool_nonempty(self) -> None:
        """A schedule with at least one entry evaluates as True."""
        assert TapPosition(a=1.0).at(t=0)

    def test_repr_partial(self) -> None:
        """repr mentions only the phases that were actually specified."""
        text = repr(TapPosition(a=1.1).at(t=100))

        assert "a=1.1" in text
        assert "b=" not in text and "c=" not in text

    def test_duplicate_timestamps_raises(self) -> None:
        """Piping two entries that share a timestamp raises ValueError."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            _ = TapPosition(a=1.0).at(t=0) | TapPosition(a=1.1).at(t=0)

    def test_duplicate_timestamps_three_entries(self) -> None:
        """A clash between the first and third of three entries is still caught."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            _ = (
                TapPosition(a=1.0).at(t=0)
                | TapPosition(a=1.1).at(t=100)
                | TapPosition(a=1.05).at(t=0)
            )

    def test_duplicate_timestamps_direct_construction(self) -> None:
        """Passing duplicate timestamps straight to the constructor also raises."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            TapSchedule(
                (
                    (0.0, TapPosition(a=1.0)),
                    (0.0, TapPosition(a=1.1)),
                )
            )
class TestInferenceRamp:
    """Validation of ``InferenceRamp`` and its conversion into schedules."""

    def test_basic(self) -> None:
        """The target fraction passed in is readable back from the ramp."""
        ramp = InferenceRamp(target=0.5)

        assert ramp.target == 0.5

    def test_invalid_target_low(self) -> None:
        """A target below zero is rejected with ValueError."""
        with pytest.raises(ValueError, match="target must be in"):
            InferenceRamp(target=-0.1)

    def test_invalid_target_high(self) -> None:
        """A target greater than 1.0 is rejected with ValueError."""
        with pytest.raises(ValueError, match="target must be in"):
            InferenceRamp(target=1.5)

    def test_at_returns_schedule(self) -> None:
        """``.at()`` yields an InferenceRampSchedule holding exactly one entry."""
        schedule = InferenceRamp(target=0.5).at(t_start=1000, t_end=2000)

        assert isinstance(schedule, InferenceRampSchedule)
        assert len(schedule) == 1

    def test_at_invalid_time_order(self) -> None:
        """An end time earlier than the start time is rejected with ValueError."""
        with pytest.raises(ValueError, match=r"t_end.*must be >= t_start"):
            InferenceRamp(target=0.5).at(t_start=2000, t_end=1000)

    def test_pipe_creates_schedule(self) -> None:
        """Two scheduled ramps joined with ``|`` form a two-entry schedule."""
        first = InferenceRamp(target=0.5).at(t_start=100, t_end=200)
        second = InferenceRamp(target=1.0).at(t_start=300, t_end=400)

        combined = first | second

        assert isinstance(combined, InferenceRampSchedule)
        assert len(combined) == 2

    def test_pipe_three(self) -> None:
        """Three scheduled ramps joined with ``|`` keep all three entries."""
        first = InferenceRamp(target=0.5).at(t_start=100, t_end=200)
        second = InferenceRamp(target=1.0).at(t_start=300, t_end=400)
        third = InferenceRamp(target=0.3).at(t_start=500, t_end=600)

        combined = first | second | third

        assert isinstance(combined, InferenceRampSchedule)
        assert len(combined) == 3
class TestInferenceRampSchedule:
    """Behaviour of ``InferenceRampSchedule.fraction_at`` and entry ordering."""

    def test_fraction_before_first_ramp(self) -> None:
        """Any time ahead of the first ramp reports a fraction of exactly 1.0."""
        schedule = InferenceRamp(target=0.5).at(t_start=1000, t_end=2000)

        for instant in (0.0, 999.0):
            assert schedule.fraction_at(instant) == 1.0

    def test_fraction_during_ramp(self) -> None:
        """Inside a ramp the fraction moves linearly from the prior level
        to the target."""
        schedule = InferenceRamp(target=0.0).at(t_start=1000, t_end=2000)

        # The start of the ramp is still at the prior level, compared exactly.
        assert schedule.fraction_at(1000.0) == 1.0
        assert schedule.fraction_at(1500.0) == pytest.approx(0.5)
        assert schedule.fraction_at(2000.0) == pytest.approx(0.0)

    def test_fraction_after_ramp(self) -> None:
        """Once a ramp has finished the fraction stays at its target."""
        schedule = InferenceRamp(target=0.2).at(t_start=1000, t_end=2000)

        assert schedule.fraction_at(3000.0) == pytest.approx(0.2)

    def test_two_ramps(self) -> None:
        """A ramp down to 0.2 followed by a ramp back up to 1.0, with the
        level held constant in the gap between them."""
        ramp_down = InferenceRamp(target=0.2).at(t_start=1000, t_end=2000)
        ramp_up = InferenceRamp(target=1.0).at(t_start=3000, t_end=3500)
        schedule = ramp_down | ramp_up

        # Before any ramp the value is compared exactly, not approximately.
        assert schedule.fraction_at(0.0) == 1.0

        approximate_checks = (
            (1500.0, 0.6),  # halfway down the first ramp
            (2500.0, 0.2),  # holding between the ramps
            (3250.0, 0.6),  # halfway up the second ramp
            (4000.0, 1.0),  # after the second ramp
        )
        for instant, expected in approximate_checks:
            assert schedule.fraction_at(instant) == pytest.approx(expected)

    def test_instant_ramp(self) -> None:
        """Equal start and end times behave as an immediate step."""
        schedule = InferenceRamp(target=0.5).at(t_start=1000, t_end=1000)

        assert schedule.fraction_at(999.0) == 1.0
        for instant in (1000.0, 1001.0):
            assert schedule.fraction_at(instant) == 0.5

    def test_fraction_array(self) -> None:
        """A numpy array of times is evaluated element by element."""
        schedule = InferenceRamp(target=0.0).at(t_start=1000, t_end=2000)
        times = np.array([0.0, 1000.0, 1500.0, 2000.0, 3000.0])
        wanted = np.array([1.0, 1.0, 0.5, 0.0, 0.0])

        np.testing.assert_allclose(schedule.fraction_at(times), wanted)

    def test_sorted_by_start(self) -> None:
        """Entries are ordered by ``t_start`` even when piped in reverse."""
        late = InferenceRamp(target=1.0).at(t_start=3000, t_end=3500)
        early = InferenceRamp(target=0.2).at(t_start=1000, t_end=2000)

        schedule = late | early

        assert [begin for _ramp, begin, _end in schedule] == [1000, 3000]
class TestTrainingRun:
    """Validation of ``TrainingRun`` and its conversion into schedules."""

    def test_basic(self) -> None:
        """The GPU count passed in is readable back from the run."""
        run = TrainingRun(n_gpus=2400, trace=_DUMMY_TRACE)

        assert run.n_gpus == 2400

    def test_negative_gpus(self) -> None:
        """A negative GPU count is rejected with ValueError."""
        with pytest.raises(ValueError, match="n_gpus must be > 0"):
            TrainingRun(n_gpus=-1, trace=_DUMMY_TRACE)

    def test_default_target_peak(self) -> None:
        """Without an explicit value, ``target_peak_W_per_gpu`` is 400.0."""
        run = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE)

        assert run.target_peak_W_per_gpu == 400.0

    def test_at_returns_schedule(self) -> None:
        """``.at()`` yields a TrainingSchedule holding exactly one entry."""
        schedule = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)

        assert isinstance(schedule, TrainingSchedule)
        assert len(schedule) == 1

    def test_at_invalid_time_order(self) -> None:
        """An end time earlier than the start time is rejected with ValueError."""
        with pytest.raises(ValueError, match=r"t_end.*must be >= t_start"):
            TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=200, t_end=100)

    def test_scheduled_fields(self) -> None:
        """A schedule entry unpacks to the original run plus its start and end."""
        original = TrainingRun(n_gpus=2400, trace=_DUMMY_TRACE, target_peak_W_per_gpu=500.0)
        schedule = original.at(t_start=100, t_end=200)

        scheduled_run, begin, finish = next(iter(schedule))

        # The entry carries the very same run object, not a copy.
        assert scheduled_run is original
        assert scheduled_run.n_gpus == 2400
        assert scheduled_run.target_peak_W_per_gpu == 500.0
        assert (begin, finish) == (100, 200)

    def test_pipe_creates_schedule(self) -> None:
        """Two scheduled runs joined with ``|`` form a two-entry schedule."""
        first = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        second = TrainingRun(n_gpus=50, trace=_DUMMY_TRACE).at(t_start=200, t_end=300)

        combined = first | second

        assert isinstance(combined, TrainingSchedule)
        assert len(combined) == 2
class TestTrainingSchedule:
    """Ordering, composition, truthiness and repr of ``TrainingSchedule``."""

    def test_sorted_by_start(self) -> None:
        """Entries are ordered by ``t_start`` even when piped in reverse."""
        late = TrainingRun(n_gpus=50, trace=_DUMMY_TRACE).at(t_start=200, t_end=300)
        early = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)

        schedule = late | early

        assert [begin for _run, begin, _end in schedule] == [0, 200]

    def test_pipe_three(self) -> None:
        """Three scheduled runs joined with ``|`` keep all three entries."""
        first = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        second = TrainingRun(n_gpus=50, trace=_DUMMY_TRACE).at(t_start=200, t_end=300)
        third = TrainingRun(n_gpus=25, trace=_DUMMY_TRACE).at(t_start=400, t_end=500)

        combined = first | second | third

        assert len(combined) == 3

    def test_bool(self) -> None:
        """No entries means falsy; one or more entries means truthy."""
        empty = TrainingSchedule()
        assert not empty

        populated = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)
        assert populated

    def test_repr(self) -> None:
        """The repr of a schedule names the ``TrainingRun`` it contains."""
        schedule = TrainingRun(n_gpus=100, trace=_DUMMY_TRACE).at(t_start=0, t_end=100)

        assert "TrainingRun" in repr(schedule)
class TestGridState:
    """The optional ``tap_positions`` field carried by ``GridState``."""

    def test_tap_positions_default_none(self) -> None:
        """Leaving ``tap_positions`` out of the constructor gives None."""
        bus_voltages = BusVoltages({"671": PhaseVoltages(a=1.0, b=1.0, c=1.0)})

        grid_state = GridState(time_s=0.0, voltages=bus_voltages)

        assert grid_state.tap_positions is None

    def test_tap_positions_populated(self) -> None:
        """A supplied TapPosition is stored as the same object."""
        tap_position = TapPosition(a=1.0875, b=1.0375, c=1.09375)
        bus_voltages = BusVoltages({"671": PhaseVoltages(a=0.98, b=0.99, c=1.01)})

        grid_state = GridState(
            time_s=1.0,
            voltages=bus_voltages,
            tap_positions=tap_position,
        )

        # Identity, not just equality: the state must not copy the position.
        assert grid_state.tap_positions is tap_position
        assert grid_state.tap_positions.a == 1.0875
class TestOnlineDatacenterState:
    """Construction of ``OnlineDatacenterState`` with explicit and default fields."""

    def test_construction_with_all_fields(self) -> None:
        """Every explicitly supplied field is readable back from the state."""
        datacenter_state = OnlineDatacenterState(
            time_s=0.5,
            power_w=ThreePhase(a=500e3, b=500e3, c=500e3),
            batch_size_by_model={"8B": 128},
            active_replicas_by_model={"8B": 4},
            observed_itl_s_by_model={"8B": 0.05},
            measured_power_w=ThreePhase(a=50e3, b=50e3, c=50e3),
            measured_power_w_by_model={"8B": 120e3},
            augmented_power_w_by_model={"8B": 1200e3},
            augmentation_factor_by_model={"8B": 10.0},
        )

        total = datacenter_state.power_w
        measured = datacenter_state.measured_power_w

        # Per-phase values are summed and compared against the three-phase total.
        assert total.a + total.b + total.c == 1500e3
        assert measured.a + measured.b + measured.c == 150e3
        assert datacenter_state.augmented_power_w_by_model["8B"] == 1200e3
        assert datacenter_state.measured_power_w_by_model["8B"] == 120e3
        assert datacenter_state.augmentation_factor_by_model["8B"] == 10.0

    def test_defaults(self) -> None:
        """With only time and power given, measured power sums to zero and
        the per-model mappings are empty."""
        datacenter_state = OnlineDatacenterState(
            time_s=0.0,
            power_w=ThreePhase(a=0.0, b=0.0, c=0.0),
        )

        measured = datacenter_state.measured_power_w

        assert measured.a + measured.b + measured.c == 0.0
        assert datacenter_state.measured_power_w_by_model == {}
        assert datacenter_state.augmented_power_w_by_model == {}
        assert datacenter_state.augmentation_factor_by_model == {}
class TestRampActivationPolicy:
    """Server activation masks produced by ``RampActivationPolicy`` and by
    custom ``ActivationPolicy`` subclasses."""

    @staticmethod
    def _make_policy(schedule: InferenceRampSchedule, num_servers: int = 10, seed: int = 42) -> RampActivationPolicy:
        """Build a policy over ``schedule`` with a freshly seeded generator."""
        return RampActivationPolicy(schedule, num_servers=num_servers, rng=np.random.default_rng(seed))

    def test_all_active_before_ramp(self) -> None:
        """Before the first ramp, all servers should be active."""
        schedule = InferenceRamp(target=0.5).at(t_start=1000, t_end=2000)
        policy = self._make_policy(schedule)

        assert policy.active_mask(0.0).sum() == 10

    def test_ramp_down(self) -> None:
        """After ramp completes, only the target fraction of servers should be active."""
        schedule = InferenceRamp(target=0.5).at(t_start=0, t_end=0)
        policy = self._make_policy(schedule)

        assert policy.active_mask(1.0).sum() == 5

    def test_ramp_up(self) -> None:
        """Ramp up: start at 0.2, ramp to 1.0."""
        step_down = InferenceRamp(target=0.2).at(t_start=0, t_end=0)
        step_up = InferenceRamp(target=1.0).at(t_start=100, t_end=100)
        policy = self._make_policy(step_down | step_up)

        mask_low = policy.active_mask(50.0)
        mask_high = policy.active_mask(200.0)

        assert mask_low.sum() == 2
        assert mask_high.sum() == 10

    def test_ramp_up_servers_superset(self) -> None:
        """Servers active at low fraction should be a subset of those active at high fraction."""
        step_down = InferenceRamp(target=0.3).at(t_start=0, t_end=0)
        step_up = InferenceRamp(target=0.7).at(t_start=100, t_end=100)
        policy = self._make_policy(step_down | step_up)

        mask_low = policy.active_mask(50.0)
        mask_high = policy.active_mask(200.0)

        # Guard against a vacuous pass: np.all over an empty selection is True,
        # so the subset check below proves nothing unless some server is
        # actually active at the low fraction.
        assert mask_low.any()
        assert np.all(mask_high[mask_low])

    def test_deterministic_with_same_seed(self) -> None:
        """Same seed should produce the same activation mask."""
        schedule = InferenceRamp(target=0.5).at(t_start=0, t_end=0)
        first = self._make_policy(schedule, num_servers=20, seed=0)
        second = self._make_policy(schedule, num_servers=20, seed=0)

        np.testing.assert_array_equal(first.active_mask(1.0), second.active_mask(1.0))

    def test_custom_policy_subclass(self) -> None:
        """A custom ActivationPolicy subclass that only defines ``active_mask``
        should also provide ``active_indices`` consistent with that mask."""

        class _AllOnPolicy(ActivationPolicy):
            """Policy that reports every one of its ``n`` servers as active."""

            def __init__(self, n: int) -> None:
                self._n = n

            def active_mask(self, t: float) -> np.ndarray:
                return np.ones(self._n, dtype=bool)

        policy = _AllOnPolicy(10)

        assert policy.active_mask(0.0).sum() == 10
        # active_indices is not defined on the subclass, so it is presumably
        # supplied by the ActivationPolicy base class.
        assert policy.active_indices(0.0).tolist() == list(range(10))