File size: 7,099 Bytes
96bb363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""Tests for BatchSizeChange, BatchSizeSchedule, and BatchSizeScheduleController."""

from __future__ import annotations

from fractions import Fraction
from unittest.mock import MagicMock

import pytest

from openg2g.controller.batch_size_schedule import BatchSizeChange, BatchSizeSchedule, BatchSizeScheduleController
from openg2g.datacenter.command import SetBatchSize


class TestBatchSizeChange:
    """Construction, validation, and ``.at()`` behaviour of ``BatchSizeChange``."""

    def test_basic_construction(self) -> None:
        change = BatchSizeChange(batch_size=64)
        # No ramp was requested, so the rate should read back as 0.0.
        assert (change.batch_size, change.ramp_up_rate) == (64, 0.0)

    def test_with_ramp(self) -> None:
        change = BatchSizeChange(batch_size=32, ramp_up_rate=4.0)
        assert (change.batch_size, change.ramp_up_rate) == (32, 4.0)

    def test_invalid_batch_size(self) -> None:
        # Both zero and negative sizes must be rejected at construction time.
        for bad_size in (0, -1):
            with pytest.raises(ValueError, match="batch_size must be positive"):
                BatchSizeChange(batch_size=bad_size)

    def test_invalid_ramp_up_rate(self) -> None:
        with pytest.raises(ValueError, match="ramp_up_rate must be >= 0"):
            BatchSizeChange(batch_size=32, ramp_up_rate=-1.0)

    def test_at_creates_schedule(self) -> None:
        schedule = BatchSizeChange(48).at(40.0)
        assert isinstance(schedule, BatchSizeSchedule)
        assert len(schedule) == 1
        # Exactly one entry exists (checked above); unpack it for inspection.
        (only_entry,) = list(schedule)
        when, change = only_entry[0], only_entry[1]
        assert when == 40.0
        assert change.batch_size == 48


class TestBatchSizeSchedule:
    """Composition, ordering, and validation of ``BatchSizeSchedule``."""

    def test_pipe_composition(self) -> None:
        s = BatchSizeChange(48).at(40) | BatchSizeChange(32).at(60) | BatchSizeChange(64).at(80)
        assert len(s) == 3
        entries = list(s)
        assert [entry[0] for entry in entries] == [40, 60, 80]
        # Each change must stay paired with the timestamp it was scheduled at.
        assert [entry[1].batch_size for entry in entries] == [48, 32, 64]

    def test_sorted_by_time(self) -> None:
        # Composed out of order; iteration must yield entries by ascending time.
        s = BatchSizeChange(32).at(60) | BatchSizeChange(48).at(40)
        entries = list(s)
        assert [entry[0] for entry in entries] == [40, 60]
        # Sorting must move each change together with its timestamp, not
        # reorder the timestamps independently of their payloads.
        assert [entry[1].batch_size for entry in entries] == [48, 32]

    def test_bool(self) -> None:
        assert bool(BatchSizeChange(32).at(0))
        assert not bool(BatchSizeSchedule(()))

    def test_repr(self) -> None:
        s = BatchSizeChange(48).at(40) | BatchSizeChange(32, ramp_up_rate=4).at(60)
        r = repr(s)
        assert "BatchSizeChange(48)" in r
        assert "ramp_up_rate=4" in r

    def test_duplicate_timestamps_raises(self) -> None:
        with pytest.raises(ValueError, match="duplicate timestamps"):
            BatchSizeChange(48).at(40) | BatchSizeChange(32).at(40)

    def test_duplicate_timestamps_three_entries(self) -> None:
        # The duplicate is introduced by the third operand, after a valid pair.
        with pytest.raises(ValueError, match="duplicate timestamps"):
            BatchSizeChange(48).at(40) | BatchSizeChange(32).at(60) | BatchSizeChange(64).at(40)

    def test_duplicate_timestamps_direct_construction(self) -> None:
        # Validation must also apply when bypassing the ``|`` operator.
        with pytest.raises(ValueError, match="duplicate timestamps"):
            BatchSizeSchedule(((40.0, BatchSizeChange(48)), (40.0, BatchSizeChange(32))))


class TestBatchSizeScheduleController:
    """Step-by-step behaviour of ``BatchSizeScheduleController``.

    The controller is driven with a mocked clock exposing ``time_s``; the
    ``dc``, ``grid``, and ``events`` arguments to ``step`` are plain
    ``MagicMock`` placeholders.
    """

    def _make_clock(self, time_s: float) -> MagicMock:
        """Return a mock clock whose ``time_s`` attribute is *time_s*."""
        clock = MagicMock()
        clock.time_s = time_s
        return clock

    def _make_context(self) -> tuple[MagicMock, MagicMock, MagicMock]:
        """Return fresh ``(dc, grid, events)`` placeholder mocks for ``step``."""
        return MagicMock(), MagicMock(), MagicMock()

    def test_emits_at_scheduled_time(self) -> None:
        schedule = BatchSizeChange(48).at(10) | BatchSizeChange(32).at(20)
        ctrl = BatchSizeScheduleController(
            schedules={"model-a": schedule},
            dt_s=Fraction(1),
        )
        dc, grid, events = self._make_context()

        # Before first event
        action = ctrl.step(self._make_clock(5.0), dc, grid, events)
        assert len(action) == 0

        # At first event
        action = ctrl.step(self._make_clock(10.0), dc, grid, events)
        assert len(action) == 1
        cmd = action[0]
        assert isinstance(cmd, SetBatchSize)
        assert cmd.batch_size_by_model["model-a"] == 48

        # Between events
        action = ctrl.step(self._make_clock(15.0), dc, grid, events)
        assert len(action) == 0

        # At second event
        action = ctrl.step(self._make_clock(20.0), dc, grid, events)
        assert len(action) == 1
        assert isinstance(action[0], SetBatchSize)
        assert action[0].batch_size_by_model["model-a"] == 32

    def test_multiple_models(self) -> None:
        schedules = {
            "model-a": BatchSizeChange(48).at(10),
            "model-b": BatchSizeChange(64).at(10),
        }
        ctrl = BatchSizeScheduleController(schedules=schedules, dt_s=Fraction(1))
        dc, grid, events = self._make_context()

        # Changes due at the same instant are merged into a single command.
        action = ctrl.step(self._make_clock(10.0), dc, grid, events)
        assert len(action) == 1
        cmd = action[0]
        assert isinstance(cmd, SetBatchSize)
        assert cmd.batch_size_by_model["model-a"] == 48
        assert cmd.batch_size_by_model["model-b"] == 64

    def test_ramp_up_rate_per_model(self) -> None:
        schedule = BatchSizeChange(32, ramp_up_rate=4.0).at(10)
        ctrl = BatchSizeScheduleController(
            schedules={"model-a": schedule},
            dt_s=Fraction(1),
        )
        dc, grid, events = self._make_context()

        action = ctrl.step(self._make_clock(10.0), dc, grid, events)
        # Guard before indexing so extra/missing commands fail clearly.
        assert len(action) == 1
        cmd = action[0]
        assert isinstance(cmd, SetBatchSize)
        assert cmd.ramp_up_rate_by_model == {"model-a": 4.0}

    def test_no_ramp_up_rate_when_zero(self) -> None:
        schedule = BatchSizeChange(32).at(10)
        ctrl = BatchSizeScheduleController(
            schedules={"model-a": schedule},
            dt_s=Fraction(1),
        )
        dc, grid, events = self._make_context()

        action = ctrl.step(self._make_clock(10.0), dc, grid, events)
        assert len(action) == 1
        cmd = action[0]
        assert isinstance(cmd, SetBatchSize)
        # A zero (default) ramp rate is omitted from the mapping entirely.
        assert cmd.ramp_up_rate_by_model == {}

    def test_mixed_ramp_rates_per_model(self) -> None:
        schedules = {
            "model-a": BatchSizeChange(32, ramp_up_rate=4.0).at(10),
            "model-b": BatchSizeChange(64, ramp_up_rate=8.0).at(10),
            "model-c": BatchSizeChange(128).at(10),
        }
        ctrl = BatchSizeScheduleController(schedules=schedules, dt_s=Fraction(1))
        dc, grid, events = self._make_context()

        action = ctrl.step(self._make_clock(10.0), dc, grid, events)
        assert len(action) == 1
        cmd = action[0]
        assert isinstance(cmd, SetBatchSize)
        assert cmd.ramp_up_rate_by_model == {"model-a": 4.0, "model-b": 8.0}
        assert "model-c" not in cmd.ramp_up_rate_by_model

    def test_dt_s(self) -> None:
        ctrl = BatchSizeScheduleController(schedules={}, dt_s=Fraction(2))
        assert ctrl.dt_s == Fraction(2)

    def test_empty_schedule(self) -> None:
        ctrl = BatchSizeScheduleController(schedules={}, dt_s=Fraction(1))
        dc, grid, events = self._make_context()

        action = ctrl.step(self._make_clock(100.0), dc, grid, events)
        assert len(action) == 0