# live / tests /test_controller_generic_contract.py
# github-actions[bot]
# deploy: sync from GitHub 2026-04-18T00:48:45Z
# 96bb363
from __future__ import annotations
import types
from fractions import Fraction
from typing import Any, cast
import pytest
from openg2g.clock import SimulationClock
from openg2g.common import ThreePhase
from openg2g.controller.base import Controller
from openg2g.datacenter.base import DatacenterBackend, DatacenterState, LLMBatchSizeControlledDatacenter
from openg2g.datacenter.command import DatacenterCommand
from openg2g.datacenter.offline import OfflineDatacenterState
from openg2g.events import EventEmitter
from openg2g.grid.base import BusVoltages, GridBackend, GridState, PhaseVoltages
from openg2g.grid.command import GridCommand
from openg2g.grid.opendss import OpenDSSGrid
class _DC(DatacenterBackend[DatacenterState]):
    """Minimal datacenter stub used as a concrete datacenter generic argument.

    It uses a fixed 1-second timestep, reports a constant 1.0 on every phase
    of ``power_w``, and ignores every command it is given.
    """

    def __init__(self) -> None:
        super().__init__()

    @property
    def dt_s(self) -> Fraction:
        """Timestep of exactly one second."""
        return Fraction(1)

    def reset(self) -> None:
        """No-op: this stub keeps no state of its own."""

    def step(self, clock: SimulationClock, events: EventEmitter) -> DatacenterState:
        """Return a state stamped with the clock's time and 1.0 on each phase."""
        unit_load = ThreePhase(a=1.0, b=1.0, c=1.0)
        return DatacenterState(time_s=clock.time_s, power_w=unit_load)

    def apply_control(self, command: DatacenterCommand, events: EventEmitter) -> None:
        """Discard *command*; the stub does not react to control."""
class _Grid(GridBackend[GridState]):
    """Minimal single-bus grid stub used as a concrete grid generic argument.

    It uses a fixed 1-second timestep and exposes one monitored entry,
    ``("671", 1)``. It reports 1.0 on all three phases of bus "671" and
    ignores every command it is given.
    """

    def __init__(self) -> None:
        super().__init__()

    @property
    def dt_s(self) -> Fraction:
        """Timestep of exactly one second."""
        return Fraction(1)

    def reset(self) -> None:
        """No-op: this stub keeps no state of its own."""

    @property
    def v_index(self) -> list[tuple[str, int]]:
        """The single (bus name, index) pair monitored by this stub."""
        return [("671", 1)]

    def step(
        self,
        clock: SimulationClock,
        power_samples_w: list[ThreePhase],
        events: EventEmitter,
    ) -> GridState:
        """Return a flat 1.0 voltage on bus "671" at the clock's time.

        *power_samples_w* is ignored.
        """
        flat = PhaseVoltages(1.0, 1.0, 1.0)
        return GridState(time_s=clock.time_s, voltages=BusVoltages({"671": flat}))

    def apply_control(self, command: GridCommand, events: EventEmitter) -> None:
        """Discard *command*; the stub does not react to control."""

    def voltages_vector(self):
        """Return a length-1 float array holding 1.0."""
        import numpy as np

        return np.ones(1, dtype=float)

    def estimate_sensitivity(self, perturbation_kw: float = 100.0):
        """Return a (1, 3) zero matrix and a length-1 array of 1.0.

        *perturbation_kw* is ignored. Presumably this pair is
        (sensitivity matrix, baseline voltages); confirm against GridBackend.
        """
        import numpy as np

        sensitivity = np.zeros((1, 3), dtype=float)
        baseline = np.ones(1, dtype=float)
        return sensitivity, baseline
def test_controller_requires_explicit_generic_parameters():
    """Subclassing bare ``Controller`` (no ``[Datacenter, Grid]`` arguments) must raise."""
    members = {
        "dt_s": property(lambda self: Fraction(1)),
        "step": lambda self, clock, datacenter, grid, events: [],
    }
    # `cast` does nothing at runtime; it only relaxes static checking of the bare base.
    unparameterized = cast(Any, Controller)
    with pytest.raises(TypeError, match="must explicitly specialize Controller generics"):
        type("_BadMissingGeneric", (unparameterized,), members)
def test_controller_rejects_reversed_generic_order():
    """``Controller[Grid, Datacenter]`` (arguments swapped) must fail at class creation."""
    swapped = Controller.__class_getitem__((OpenDSSGrid, DatacenterBackend))  # type: ignore[unresolved-attribute]
    members = {
        "dt_s": property(lambda self: Fraction(1)),
        "step": lambda self, clock, datacenter, grid, events: [],
    }
    with pytest.raises(TypeError, match="is not a subclass of DatacenterBackend"):
        # types.new_class (rather than type()) resolves the alias via __mro_entries__.
        types.new_class("_BadReversed", (swapped,), exec_body=lambda ns: ns.update(members))
def test_controller_rejects_random_classes_in_generics():
    """An unrelated class in the datacenter slot must be rejected at class creation."""

    class _Random:
        pass

    bogus = Controller.__class_getitem__((_Random, OpenDSSGrid))  # type: ignore[unresolved-attribute]
    members = {
        "dt_s": property(lambda self: Fraction(1)),
        "step": lambda self, clock, datacenter, grid, events: [],
    }
    with pytest.raises(TypeError, match="is not a subclass of DatacenterBackend"):
        types.new_class("_BadRandom", (bogus,), exec_body=lambda ns: ns.update(members))
def test_controller_rejects_non_abc_subclass_for_grid_generic():
    """A valid datacenter paired with a non-GridBackend grid type must be rejected."""

    class _NotGrid:
        pass

    mismatched = Controller.__class_getitem__((_DC, _NotGrid))  # type: ignore[unresolved-attribute]
    members = {
        "dt_s": property(lambda self: Fraction(1)),
        "step": lambda self, clock, datacenter, grid, events: [],
    }
    with pytest.raises(TypeError, match="is not a subclass of GridBackend"):
        types.new_class("_BadGrid", (mismatched,), exec_body=lambda ns: ns.update(members))
def test_controller_inherits_compatibility_from_typed_parent():
    """A subclass that declares no generics reports its parent's compatible types."""

    class _BaseTyped(Controller[_DC, _Grid]):
        @property
        def dt_s(self) -> Fraction:
            return Fraction(1)

        def reset(self) -> None:
            """No-op."""

        def step(
            self,
            clock: SimulationClock,
            datacenter: _DC,
            grid: _Grid,
            events: EventEmitter,
        ) -> list[DatacenterCommand | GridCommand]:
            return []

    class _Child(_BaseTyped):
        pass

    dc_types = _Child.compatible_datacenter_types()
    grid_types = _Child.compatible_grid_types()
    assert dc_types == (_DC,)
    assert grid_types == (_Grid,)
def test_controller_accepts_parameterized_backend_generics():
    """Controller[DatacenterBackend[X], GridBackend[Y]] should work via get_origin fallback."""

    class _Parameterized(
        Controller[LLMBatchSizeControlledDatacenter[OfflineDatacenterState], GridBackend[GridState]]
    ):
        @property
        def dt_s(self) -> Fraction:
            return Fraction(1)

        def reset(self) -> None:
            """No-op."""

        def step(
            self,
            clock: SimulationClock,
            datacenter: LLMBatchSizeControlledDatacenter[OfflineDatacenterState],
            grid: GridBackend[GridState],
            events: EventEmitter,
        ) -> list[DatacenterCommand | GridCommand]:
            return []

    # The reported types are expected to be the unsubscripted origin classes.
    dc_types = _Parameterized.compatible_datacenter_types()
    grid_types = _Parameterized.compatible_grid_types()
    assert dc_types == (LLMBatchSizeControlledDatacenter,)
    assert grid_types == (GridBackend,)