| import contextlib | |
| import io | |
| import unittest | |
| from unittest import mock | |
| import install | |
| class InstallRequirementTests(unittest.TestCase): | |
| def test_audio_separator_below_required_version_is_marked_for_install(self): | |
| with mock.patch("install.check_package", return_value=True), mock.patch( | |
| "install.get_installed_version", | |
| return_value="0.41.1", | |
| create=True, | |
| ), contextlib.redirect_stdout(io.StringIO()): | |
| missing = install.check_all("python") | |
| self.assertIn( | |
| "audio-separator", | |
| {info["pip"] for info in missing}, | |
| ) | |
| def test_audio_separator_at_required_version_is_accepted(self): | |
| with mock.patch("install.check_package", return_value=True), mock.patch( | |
| "install.get_installed_version", | |
| return_value="0.44.1", | |
| create=True, | |
| ), contextlib.redirect_stdout(io.StringIO()): | |
| missing = install.check_all("python") | |
| self.assertNotIn( | |
| "audio-separator", | |
| {info["pip"] for info in missing}, | |
| ) | |
| def test_numpy_2_is_marked_for_downgrade(self): | |
| def version_for_package(_venv_py, distribution_name): | |
| if distribution_name == "numpy": | |
| return "2.2.6" | |
| if distribution_name == "audio-separator": | |
| return "0.44.1" | |
| return None | |
| with mock.patch("install.check_package", return_value=True), mock.patch( | |
| "install.get_installed_version", | |
| side_effect=version_for_package, | |
| create=True, | |
| ), contextlib.redirect_stdout(io.StringIO()): | |
| missing = install.check_all("python") | |
| self.assertIn("numpy<2,>=1.23.0", {info["pip"] for info in missing}) | |
| def test_audio_separator_install_restores_numpy_1_x(self): | |
| audio_separator_info = install.PACKAGES["audio_separator"] | |
| calls = [] | |
| def fake_pip_install(_venv_py, package, **kwargs): | |
| calls.append((package, kwargs)) | |
| return True | |
| with mock.patch("install.check_all", return_value=[audio_separator_info]), mock.patch( | |
| "install.detect_cuda_version", | |
| return_value=None, | |
| ), mock.patch("install.pip_install", side_effect=fake_pip_install), contextlib.redirect_stdout( | |
| io.StringIO() | |
| ): | |
| ok = install.install_all("python", gpu=False) | |
| self.assertTrue(ok) | |
| self.assertEqual(calls[0][0], "audio-separator") | |
| self.assertEqual(calls[0][1]["version_spec"], ">=0.44.1") | |
| self.assertIn(("numpy<2,>=1.23.0", {}), calls) | |
| def test_cuda_13_uses_latest_supported_pytorch_wheel(self): | |
| def fake_run(cmd, **_kwargs): | |
| if cmd == ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"]: | |
| return mock.Mock(returncode=0, stdout="575.51.03\n") | |
| if cmd == ["nvidia-smi"]: | |
| return mock.Mock(returncode=0, stdout="CUDA Version: 13.0\n") | |
| raise AssertionError(f"Unexpected command: {cmd}") | |
| with mock.patch("install.subprocess.run", side_effect=fake_run): | |
| self.assertEqual( | |
| install.detect_cuda_version(), | |
| "https://download.pytorch.org/whl/cu126", | |
| ) | |
| def test_gpu_install_does_not_fall_back_to_cpu_torch(self): | |
| calls = [] | |
| def fake_pip_install(_venv_py, package, **kwargs): | |
| calls.append((package, kwargs)) | |
| return True | |
| with mock.patch( | |
| "install.check_all", | |
| return_value=[install.PACKAGES["torch"], install.PACKAGES["torchaudio"]], | |
| ), mock.patch("install.detect_cuda_version", return_value=None), mock.patch( | |
| "install.pip_install", | |
| side_effect=fake_pip_install, | |
| ), contextlib.redirect_stdout(io.StringIO()): | |
| ok = install.install_all("python", gpu=True) | |
| self.assertFalse(ok) | |
| self.assertEqual(calls, []) | |
| if __name__ == "__main__": | |
| unittest.main() | |