| | |
| | |
| |
|
| | import os |
| | from typing import List, Optional |
| |
|
| | import pytest |
| | |
| |
|
| | |
# Candidate world sizes (1 and 2 processes).
# NOTE(review): not referenced in the visible part of this file -- presumably
# consumed by fixtures or parametrization elsewhere; confirm before removing.
WORLD_SIZE_OPTIONS = (1, 2)
| |
|
| | |
| | |
| |
|
| | |
| |
|
| | |
# Modules that pytest imports as plugins when this conftest is loaded.
pytest_plugins = [
    # Dotted module path listed for plugin loading.
    'tests.fixtures.fixtures',
]
| |
|
| |
|
def _get_world_size(item: pytest.Item):
    """Look up the ``world_size`` marker that applies to a collected test.

    Args:
        item: The collected test item to inspect.

    Returns:
        The first positional argument of the closest ``world_size`` marker,
        or ``1`` when the test carries no such marker.
    """
    # Stand-in marker used when the test is not annotated with world_size.
    fallback_marker = pytest.mark.world_size(1).mark
    closest = item.get_closest_marker('world_size', default=fallback_marker)
    return closest.args[0]
| |
|
| |
|
def _get_option(
    config: pytest.Config,
    name: str,
    default: Optional[str] = None,
) -> str:
    """Resolve a string setting from the command line, the ini file, or a default.

    Lookup order: the command-line option ``name`` first, then the ini value of
    the same name, then ``default``.

    Args:
        config: The pytest configuration object to query.
        name: Name shared by the command-line option and the ini setting.
        default: Value to use when neither source provides one.

    Returns:
        The resolved string value.

    Calls ``pytest.fail`` when the setting is missing everywhere and no
    ``default`` was given.
    """
    cli_value = config.getoption(name)
    if cli_value is not None:
        assert isinstance(cli_value, str)
        return cli_value

    ini_value = config.getini(name)
    # An empty list from the ini lookup is treated the same as "not set".
    if ini_value == [] or ini_value is None:
        if default is None:
            pytest.fail(f'Config option {name} is not specified but is required',)
        ini_value = default

    assert isinstance(ini_value, str)
    return ini_value
| |
|
| |
|
def _add_option(
    parser: pytest.Parser,
    name: str,
    help: str,
    choices: Optional[list[str]] = None,
):
    """Register ``name`` as both a ``--name`` command-line flag and an ini setting.

    Both registrations default to ``None`` and share the same help text.

    Args:
        parser: The pytest parser to register the option on.
        name: Option name, without the leading ``--``.
        help: Help text attached to both the flag and the ini setting.
        choices: Optional list of values accepted by the command-line flag.
    """
    cli_flag = f'--{name}'
    # Command-line form of the setting.
    parser.addoption(cli_flag, default=None, type=str, choices=choices, help=help)
    # Ini-file form of the same setting.
    parser.addini(name=name, help=help, type='string', default=None)
| |
|
| |
|
def pytest_collection_modifyitems(
    config: pytest.Config,
    items: List[pytest.Item],
) -> None:
    """Keep only the tests whose world_size matches the ``WORLD_SIZE`` env var.

    ``WORLD_SIZE`` defaults to ``1`` when unset. Tests with a different
    world_size are reported through the ``pytest_deselected`` hook and removed
    from ``items`` in place.

    Args:
        config: The pytest configuration object, used to reach the hook relay.
        items: Collected test items; mutated in place.
    """
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    print(f'world_size={world_size}')

    # Split the collected tests into matching and non-matching groups,
    # preserving collection order within each group.
    kept = []
    dropped = []
    for test_item in items:
        bucket = kept if _get_world_size(test_item) == world_size else dropped
        bucket.append(test_item)

    if not dropped:
        # Nothing was filtered out, so ``items`` already holds the right tests.
        return

    config.hook.pytest_deselected(items=dropped)
    items[:] = kept
| |
|
| |
|
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the ``seed`` setting as a command-line flag and ini option.

    Args:
        parser: The pytest parser that receives the new option.
    """
    seed_help = """\
        Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked
        before each test."""
    _add_option(parser, 'seed', help=seed_help)
| |
|
| |
|
def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
    """Report success when the session ends without having collected any test.

    pytest exits with ``ExitCode.NO_TESTS_COLLECTED`` (5) when no test ran.
    This hook rewrites that status to ``ExitCode.OK`` (0); every other exit
    status is left untouched.

    Args:
        session: The pytest session whose exit status may be overridden.
        exitstatus: The status pytest is about to exit with.
    """
    # NOTE(review): presumably needed because the world-size filter in this
    # conftest can deselect every test for a given WORLD_SIZE -- confirm.
    if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
        session.exitstatus = pytest.ExitCode.OK
| |
|