# Source path: live/tests/test_batch_size_schedule.py
# Synced by github-actions[bot]: "deploy: sync from GitHub 2026-04-18T00:48:45Z" (commit 96bb363)
"""Tests for BatchSizeChange, BatchSizeSchedule, and BatchSizeScheduleController."""
from __future__ import annotations
from fractions import Fraction
from unittest.mock import MagicMock
import pytest
from openg2g.controller.batch_size_schedule import BatchSizeChange, BatchSizeSchedule, BatchSizeScheduleController
from openg2g.datacenter.command import SetBatchSize
class TestBatchSizeChange:
    """Construction, validation, and ``.at()`` scheduling of BatchSizeChange."""

    def test_basic_construction(self) -> None:
        """With only a batch size given, ramp_up_rate defaults to 0.0."""
        change = BatchSizeChange(batch_size=64)
        assert change.batch_size == 64
        assert change.ramp_up_rate == 0.0

    def test_with_ramp(self) -> None:
        """An explicit ramp_up_rate is stored next to the batch size."""
        change = BatchSizeChange(batch_size=32, ramp_up_rate=4.0)
        assert change.batch_size == 32
        assert change.ramp_up_rate == 4.0

    def test_invalid_batch_size(self) -> None:
        """Zero and negative batch sizes are both rejected with ValueError."""
        for bad_size in (0, -1):
            with pytest.raises(ValueError, match="batch_size must be positive"):
                BatchSizeChange(batch_size=bad_size)

    def test_invalid_ramp_up_rate(self) -> None:
        """A negative ramp_up_rate is rejected with ValueError."""
        with pytest.raises(ValueError, match="ramp_up_rate must be >= 0"):
            BatchSizeChange(batch_size=32, ramp_up_rate=-1.0)

    def test_at_creates_schedule(self) -> None:
        """``.at(t)`` wraps the change in a one-entry schedule whose entry is (t, change)."""
        schedule = BatchSizeChange(48).at(40.0)
        assert isinstance(schedule, BatchSizeSchedule)
        assert len(schedule) == 1
        first_entry = next(iter(schedule))
        assert first_entry[0] == 40.0
        assert first_entry[1].batch_size == 48
class TestBatchSizeSchedule:
    """Composition with ``|``, time ordering, truthiness, repr, and duplicate detection."""

    # Regex passed to pytest.raises for every duplicate-timestamp case below.
    _DUPLICATE_MSG = "duplicate timestamps"

    def test_pipe_composition(self) -> None:
        """Chaining three scheduled changes with ``|`` keeps all three, in time order."""
        combined = (
            BatchSizeChange(48).at(40)
            | BatchSizeChange(32).at(60)
            | BatchSizeChange(64).at(80)
        )
        assert len(combined) == 3
        times = [entry[0] for entry in combined]
        assert times == [40, 60, 80]

    def test_sorted_by_time(self) -> None:
        """Entries iterate in time order even when composed out of order."""
        combined = BatchSizeChange(32).at(60) | BatchSizeChange(48).at(40)
        times = [entry[0] for entry in combined]
        assert times == [40, 60]

    def test_bool(self) -> None:
        """A schedule with entries is truthy; an empty one is falsy."""
        non_empty = BatchSizeChange(32).at(0)
        empty = BatchSizeSchedule(())
        assert bool(non_empty)
        assert not bool(empty)

    def test_repr(self) -> None:
        """repr shows each change's batch size and any non-default ramp_up_rate."""
        text = repr(BatchSizeChange(48).at(40) | BatchSizeChange(32, ramp_up_rate=4).at(60))
        for fragment in ("BatchSizeChange(48)", "ramp_up_rate=4"):
            assert fragment in text

    def test_duplicate_timestamps_raises(self) -> None:
        """Composing two changes at the same time raises ValueError."""
        with pytest.raises(ValueError, match=self._DUPLICATE_MSG):
            BatchSizeChange(48).at(40) | BatchSizeChange(32).at(40)

    def test_duplicate_timestamps_three_entries(self) -> None:
        """A duplicate is detected even when it is not adjacent in the chain."""
        with pytest.raises(ValueError, match=self._DUPLICATE_MSG):
            BatchSizeChange(48).at(40) | BatchSizeChange(32).at(60) | BatchSizeChange(64).at(40)

    def test_duplicate_timestamps_direct_construction(self) -> None:
        """The constructor itself rejects duplicate timestamps, not only ``|``."""
        with pytest.raises(ValueError, match=self._DUPLICATE_MSG):
            BatchSizeSchedule(((40.0, BatchSizeChange(48)), (40.0, BatchSizeChange(32))))
class TestBatchSizeScheduleController:
    """Tests for BatchSizeScheduleController.step: emission timing, multi-model
    batching into a single SetBatchSize, and per-model ramp rates."""

    def _make_clock(self, time_s: float) -> MagicMock:
        """Return a mock clock whose ``time_s`` attribute is *time_s*."""
        clock = MagicMock()
        clock.time_s = time_s
        return clock

    @staticmethod
    def _make_env() -> tuple[MagicMock, MagicMock, MagicMock]:
        """Return ``(dc, grid, events)`` mocks for ``step``.

        These tests never configure or inspect these mocks. They only fill
        the positional arguments of ``step``.
        """
        return MagicMock(), MagicMock(), MagicMock()

    @staticmethod
    def _only_command(action) -> SetBatchSize:
        """Assert *action* holds exactly one SetBatchSize and return it.

        Checking the length first keeps the error readable when nothing was
        emitted: it is an assertion failure, not an IndexError. It also makes
        the test fail if the controller emits extra commands.
        """
        assert len(action) == 1
        cmd = action[0]
        assert isinstance(cmd, SetBatchSize)
        return cmd

    def test_emits_at_scheduled_time(self) -> None:
        """A command is emitted only on steps that land on a scheduled time."""
        schedule = BatchSizeChange(48).at(10) | BatchSizeChange(32).at(20)
        ctrl = BatchSizeScheduleController(
            schedules={"model-a": schedule},
            dt_s=Fraction(1),
        )
        dc, grid, events = self._make_env()

        # Before the first scheduled change: nothing to emit.
        action = ctrl.step(self._make_clock(5.0), dc, grid, events)
        assert len(action) == 0

        # Exactly at the first scheduled change.
        cmd = self._only_command(ctrl.step(self._make_clock(10.0), dc, grid, events))
        assert cmd.batch_size_by_model["model-a"] == 48

        # Between scheduled changes: nothing new to emit.
        action = ctrl.step(self._make_clock(15.0), dc, grid, events)
        assert len(action) == 0

        # Exactly at the second scheduled change.
        cmd = self._only_command(ctrl.step(self._make_clock(20.0), dc, grid, events))
        assert cmd.batch_size_by_model["model-a"] == 32

    def test_multiple_models(self) -> None:
        """Changes for several models at the same time merge into one command."""
        schedules = {
            "model-a": BatchSizeChange(48).at(10),
            "model-b": BatchSizeChange(64).at(10),
        }
        ctrl = BatchSizeScheduleController(schedules=schedules, dt_s=Fraction(1))
        dc, grid, events = self._make_env()

        cmd = self._only_command(ctrl.step(self._make_clock(10.0), dc, grid, events))
        assert cmd.batch_size_by_model["model-a"] == 48
        assert cmd.batch_size_by_model["model-b"] == 64

    def test_ramp_up_rate_per_model(self) -> None:
        """A non-zero ramp_up_rate is forwarded under the model's name."""
        schedule = BatchSizeChange(32, ramp_up_rate=4.0).at(10)
        ctrl = BatchSizeScheduleController(
            schedules={"model-a": schedule},
            dt_s=Fraction(1),
        )
        dc, grid, events = self._make_env()

        cmd = self._only_command(ctrl.step(self._make_clock(10.0), dc, grid, events))
        assert cmd.ramp_up_rate_by_model == {"model-a": 4.0}

    def test_no_ramp_up_rate_when_zero(self) -> None:
        """The default (zero) ramp_up_rate produces no ramp entry."""
        schedule = BatchSizeChange(32).at(10)
        ctrl = BatchSizeScheduleController(
            schedules={"model-a": schedule},
            dt_s=Fraction(1),
        )
        dc, grid, events = self._make_env()

        cmd = self._only_command(ctrl.step(self._make_clock(10.0), dc, grid, events))
        assert cmd.ramp_up_rate_by_model == {}

    def test_mixed_ramp_rates_per_model(self) -> None:
        """Only models with a non-zero ramp_up_rate appear in ramp_up_rate_by_model."""
        schedules = {
            "model-a": BatchSizeChange(32, ramp_up_rate=4.0).at(10),
            "model-b": BatchSizeChange(64, ramp_up_rate=8.0).at(10),
            "model-c": BatchSizeChange(128).at(10),
        }
        ctrl = BatchSizeScheduleController(schedules=schedules, dt_s=Fraction(1))
        dc, grid, events = self._make_env()

        cmd = self._only_command(ctrl.step(self._make_clock(10.0), dc, grid, events))
        assert cmd.ramp_up_rate_by_model == {"model-a": 4.0, "model-b": 8.0}
        assert "model-c" not in cmd.ramp_up_rate_by_model

    def test_dt_s(self) -> None:
        """The constructor's dt_s is exposed unchanged on the controller."""
        ctrl = BatchSizeScheduleController(schedules={}, dt_s=Fraction(2))
        assert ctrl.dt_s == Fraction(2)

    def test_empty_schedule(self) -> None:
        """With no schedules, step emits nothing."""
        ctrl = BatchSizeScheduleController(schedules={}, dt_s=Fraction(1))
        dc, grid, events = self._make_env()

        action = ctrl.step(self._make_clock(100.0), dc, grid, events)
        assert len(action) == 0