| """Tests for BatchSizeChange, BatchSizeSchedule, and BatchSizeScheduleController.""" |
|
|
| from __future__ import annotations |
|
|
| from fractions import Fraction |
| from unittest.mock import MagicMock |
|
|
| import pytest |
|
|
| from openg2g.controller.batch_size_schedule import BatchSizeChange, BatchSizeSchedule, BatchSizeScheduleController |
| from openg2g.datacenter.command import SetBatchSize |
|
|
|
|
class TestBatchSizeChange:
    """Construction, validation, and ``.at()`` scheduling of a single BatchSizeChange."""

    def test_basic_construction(self) -> None:
        """A change built from only a batch size carries a ramp-up rate of 0.0."""
        change = BatchSizeChange(batch_size=64)
        assert change.batch_size == 64
        assert change.ramp_up_rate == 0.0

    def test_with_ramp(self) -> None:
        """An explicit ramp-up rate is stored alongside the batch size."""
        change = BatchSizeChange(batch_size=32, ramp_up_rate=4.0)
        assert change.batch_size == 32
        assert change.ramp_up_rate == 4.0

    def test_invalid_batch_size(self) -> None:
        """Zero and negative batch sizes are rejected at construction time."""
        for bad_size in (0, -1):
            with pytest.raises(ValueError, match="batch_size must be positive"):
                BatchSizeChange(batch_size=bad_size)

    def test_invalid_ramp_up_rate(self) -> None:
        """A negative ramp-up rate is rejected at construction time."""
        with pytest.raises(ValueError, match="ramp_up_rate must be >= 0"):
            BatchSizeChange(batch_size=32, ramp_up_rate=-1.0)

    def test_at_creates_schedule(self) -> None:
        """``.at(t)`` wraps the change in a one-entry schedule keyed by ``t``."""
        schedule = BatchSizeChange(48).at(40.0)
        assert isinstance(schedule, BatchSizeSchedule)
        assert len(schedule) == 1
        # Entries iterate as (timestamp, change) pairs.
        only_entry = next(iter(schedule))
        assert only_entry[0] == 40.0
        assert only_entry[1].batch_size == 48
|
|
|
|
class TestBatchSizeSchedule:
    """Composition with ``|``, ordering, truthiness, repr, and duplicate detection."""

    def test_pipe_composition(self) -> None:
        """Chaining ``|`` accumulates every entry, ordered by timestamp."""
        schedule = (
            BatchSizeChange(48).at(40)
            | BatchSizeChange(32).at(60)
            | BatchSizeChange(64).at(80)
        )
        assert len(schedule) == 3
        timestamps = [entry[0] for entry in schedule]
        assert timestamps == [40, 60, 80]

    def test_sorted_by_time(self) -> None:
        """Entries come back sorted by time regardless of composition order."""
        schedule = BatchSizeChange(32).at(60) | BatchSizeChange(48).at(40)
        earlier, later = list(schedule)[:2]
        assert earlier[0] == 40
        assert later[0] == 60

    def test_bool(self) -> None:
        """A schedule is truthy when non-empty and falsy when empty."""
        assert BatchSizeChange(32).at(0)
        assert not BatchSizeSchedule(())

    def test_repr(self) -> None:
        """repr shows each change, including a non-default ramp-up rate."""
        text = repr(BatchSizeChange(48).at(40) | BatchSizeChange(32, ramp_up_rate=4).at(60))
        for fragment in ("BatchSizeChange(48)", "ramp_up_rate=4"):
            assert fragment in text

    def test_duplicate_timestamps_raises(self) -> None:
        """Combining two entries at the same timestamp is an error."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            BatchSizeChange(48).at(40) | BatchSizeChange(32).at(40)

    def test_duplicate_timestamps_three_entries(self) -> None:
        """A duplicate is detected even when it is not adjacent in the chain."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            BatchSizeChange(48).at(40) | BatchSizeChange(32).at(60) | BatchSizeChange(64).at(40)

    def test_duplicate_timestamps_direct_construction(self) -> None:
        """The constructor itself rejects duplicate timestamps."""
        with pytest.raises(ValueError, match="duplicate timestamps"):
            BatchSizeSchedule(((40.0, BatchSizeChange(48)), (40.0, BatchSizeChange(32))))
|
|
|
|
class TestBatchSizeScheduleController:
    """Tests for BatchSizeScheduleController turning schedules into SetBatchSize commands."""

    def setup_method(self) -> None:
        # Datacenter, grid, and event-sink stand-ins shared by every step of a
        # single test. The assertions below only inspect the returned action,
        # so these collaborators just need to exist; reusing one set per test
        # matches the original per-test setup.
        self.dc = MagicMock()
        self.grid = MagicMock()
        self.events = MagicMock()

    def _make_clock(self, time_s: float) -> MagicMock:
        """Return a clock stand-in whose ``time_s`` attribute is *time_s*."""
        clock = MagicMock()
        clock.time_s = time_s
        return clock

    def _step(self, ctrl: BatchSizeScheduleController, time_s: float):
        """Run one controller step at *time_s* with this test's shared mocks."""
        return ctrl.step(self._make_clock(time_s), self.dc, self.grid, self.events)

    def _single_command(self, action) -> SetBatchSize:
        """Assert *action* holds exactly one SetBatchSize command and return it.

        Checking the length first turns "nothing emitted" into a clear
        assertion failure instead of an IndexError, and catches stray extra
        commands that indexing ``action[0]`` alone would silently ignore.
        """
        assert len(action) == 1
        cmd = action[0]
        assert isinstance(cmd, SetBatchSize)
        return cmd

    def test_emits_at_scheduled_time(self) -> None:
        """Each change is emitted at its scheduled timestamp and not in between."""
        schedule = BatchSizeChange(48).at(10) | BatchSizeChange(32).at(20)
        ctrl = BatchSizeScheduleController(
            schedules={"model-a": schedule},
            dt_s=Fraction(1),
        )

        # Before the first scheduled change: nothing to emit.
        assert len(self._step(ctrl, 5.0)) == 0

        # At the first change: one command carrying the new batch size.
        cmd = self._single_command(self._step(ctrl, 10.0))
        assert cmd.batch_size_by_model["model-a"] == 48

        # Between the two changes: nothing new to emit.
        assert len(self._step(ctrl, 15.0)) == 0

        # At the second change.
        cmd = self._single_command(self._step(ctrl, 20.0))
        assert cmd.batch_size_by_model["model-a"] == 32

    def test_multiple_models(self) -> None:
        """Changes for several models at the same time share one command."""
        schedules = {
            "model-a": BatchSizeChange(48).at(10),
            "model-b": BatchSizeChange(64).at(10),
        }
        ctrl = BatchSizeScheduleController(schedules=schedules, dt_s=Fraction(1))

        cmd = self._single_command(self._step(ctrl, 10.0))
        assert cmd.batch_size_by_model["model-a"] == 48
        assert cmd.batch_size_by_model["model-b"] == 64

    def test_ramp_up_rate_per_model(self) -> None:
        """A non-zero ramp-up rate is reported for its model."""
        schedule = BatchSizeChange(32, ramp_up_rate=4.0).at(10)
        ctrl = BatchSizeScheduleController(
            schedules={"model-a": schedule},
            dt_s=Fraction(1),
        )

        cmd = self._single_command(self._step(ctrl, 10.0))
        assert cmd.ramp_up_rate_by_model == {"model-a": 4.0}

    def test_no_ramp_up_rate_when_zero(self) -> None:
        """A change with the default (zero) ramp-up rate reports no rate at all."""
        schedule = BatchSizeChange(32).at(10)
        ctrl = BatchSizeScheduleController(
            schedules={"model-a": schedule},
            dt_s=Fraction(1),
        )

        cmd = self._single_command(self._step(ctrl, 10.0))
        assert cmd.ramp_up_rate_by_model == {}

    def test_mixed_ramp_rates_per_model(self) -> None:
        """Only models with a non-zero ramp-up rate appear in the rate mapping."""
        schedules = {
            "model-a": BatchSizeChange(32, ramp_up_rate=4.0).at(10),
            "model-b": BatchSizeChange(64, ramp_up_rate=8.0).at(10),
            "model-c": BatchSizeChange(128).at(10),
        }
        ctrl = BatchSizeScheduleController(schedules=schedules, dt_s=Fraction(1))

        cmd = self._single_command(self._step(ctrl, 10.0))
        assert cmd.ramp_up_rate_by_model == {"model-a": 4.0, "model-b": 8.0}
        assert "model-c" not in cmd.ramp_up_rate_by_model

    def test_dt_s(self) -> None:
        """The configured step period is exposed unchanged."""
        ctrl = BatchSizeScheduleController(schedules={}, dt_s=Fraction(2))
        assert ctrl.dt_s == Fraction(2)

    def test_empty_schedule(self) -> None:
        """With no schedules the controller never emits anything."""
        ctrl = BatchSizeScheduleController(schedules={}, dt_s=Fraction(1))
        assert len(self._step(ctrl, 100.0)) == 0
|
|