| """Thin wrapper around shared gpu_utils helpers.""" |
|
|
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
|
|
# Repository root, taken as the fourth ancestor directory of this file.
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
# Location of the shared helpers this wrapper re-exports.
_SHARED_GPU_UTILS = _PROJECT_ROOT / "utils" / "gpu_utils.py"
# Name the shared module is registered under in sys.modules.
_SHARED_MODULE_NAME = "wavegen_shared_gpu_utils"

# spec_from_file_location() does not check that the file exists: for a ".py"
# path it always returns a spec backed by a SourceFileLoader. Without this
# explicit check a missing file would surface as a bare FileNotFoundError
# from exec_module() rather than the descriptive error below.
if not _SHARED_GPU_UTILS.is_file():
    raise ModuleNotFoundError(
        f"Shared gpu_utils module not found at {_SHARED_GPU_UTILS}"
    )

_spec = spec_from_file_location(_SHARED_MODULE_NAME, _SHARED_GPU_UTILS)
if _spec is None or _spec.loader is None:
    raise ModuleNotFoundError(
        f"Shared gpu_utils module not found at {_SHARED_GPU_UTILS}"
    )

_module = module_from_spec(_spec)
# Register the module before executing it, as the importlib "import a source
# file directly" recipe does. Code inside the shared module that looks itself
# up via sys.modules (e.g. pickle locating its classes, dataclasses resolving
# string annotations) needs the entry to exist.
sys.modules[_SHARED_MODULE_NAME] = _module
try:
    _spec.loader.exec_module(_module)
except BaseException:
    # Mirror the import system: do not leave a half-initialised module cached.
    sys.modules.pop(_SHARED_MODULE_NAME, None)
    raise
|
|
# Public surface of this wrapper: exactly the names pulled from the shared module.
__all__ = [
    "DEFAULT_THRESHOLD_MB",
    "query_gpu_memory",
    "select_gpus",
]

# Bind the shared module's objects directly as attributes of this module,
# so callers can import them from here without knowing where they live.
DEFAULT_THRESHOLD_MB, query_gpu_memory, select_gpus = (
    _module.DEFAULT_THRESHOLD_MB,
    _module.query_gpu_memory,
    _module.select_gpus,
)
|
|