File size: 4,151 Bytes
a9536c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import contextlib
import io
import unittest
from unittest import mock

import install


class InstallRequirementTests(unittest.TestCase):
    """Tests for requirement checks and installation in the ``install`` module.

    Each test patches the ``install`` helpers it depends on
    (``check_package``, ``get_installed_version``, ``check_all``,
    ``detect_cuda_version``, ``pip_install`` or ``subprocess.run``) and
    captures stdout so the installer's output stays out of the test log.
    """

    # Requirement string the tests expect the installer to use when pinning
    # numpy back to the 1.x series.
    NUMPY_1_X_SPEC = "numpy<2,>=1.23.0"

    def _missing_pip_names(self, **version_patch_kwargs):
        """Run ``install.check_all("python")`` with controlled version lookups.

        ``install.check_package`` is patched to return ``True`` and
        ``install.get_installed_version`` is patched with
        ``version_patch_kwargs`` (``return_value=...`` or ``side_effect=...``).
        ``create=True`` lets the patch apply even if the module does not
        define that attribute.

        Returns:
            The set of ``"pip"`` values of the entries ``check_all`` returned.
        """
        with mock.patch("install.check_package", return_value=True), mock.patch(
            "install.get_installed_version",
            create=True,
            **version_patch_kwargs,
        ), contextlib.redirect_stdout(io.StringIO()):
            missing = install.check_all("python")
        return {info["pip"] for info in missing}

    def _run_install_all(self, missing, *, gpu):
        """Run ``install.install_all("python", gpu=gpu)`` with pip stubbed out.

        ``install.check_all`` is patched to return ``missing`` and
        ``install.detect_cuda_version`` to return ``None``. Every
        ``install.pip_install`` call is recorded instead of executed and
        reported as successful.

        Returns:
            ``(ok, calls)``: ``ok`` is ``install_all``'s return value and
            ``calls`` lists one ``(package, kwargs)`` tuple per
            ``pip_install`` call, in call order.
        """
        calls = []

        def fake_pip_install(_venv_py, package, **kwargs):
            calls.append((package, kwargs))
            return True

        with mock.patch("install.check_all", return_value=missing), mock.patch(
            "install.detect_cuda_version",
            return_value=None,
        ), mock.patch(
            "install.pip_install",
            side_effect=fake_pip_install,
        ), contextlib.redirect_stdout(io.StringIO()):
            ok = install.install_all("python", gpu=gpu)
        return ok, calls

    def test_audio_separator_below_required_version_is_marked_for_install(self):
        """audio-separator 0.41.1 is reported as needing installation."""
        # Every distribution reports the same version here; only the
        # audio-separator entry is asserted on.
        self.assertIn(
            "audio-separator",
            self._missing_pip_names(return_value="0.41.1"),
        )

    def test_audio_separator_at_required_version_is_accepted(self):
        """audio-separator 0.44.1 is not reported as missing."""
        self.assertNotIn(
            "audio-separator",
            self._missing_pip_names(return_value="0.44.1"),
        )

    def test_numpy_2_is_marked_for_downgrade(self):
        """An installed numpy 2.x yields the numpy 1.x requirement."""
        versions = {"numpy": "2.2.6", "audio-separator": "0.44.1"}

        def version_for_package(_venv_py, distribution_name):
            # Distributions not listed above report no installed version.
            return versions.get(distribution_name)

        self.assertIn(
            self.NUMPY_1_X_SPEC,
            self._missing_pip_names(side_effect=version_for_package),
        )

    def test_audio_separator_install_restores_numpy_1_x(self):
        """Installing audio-separator is followed by re-pinning numpy to 1.x."""
        ok, calls = self._run_install_all(
            [install.PACKAGES["audio_separator"]],
            gpu=False,
        )

        self.assertTrue(ok)
        # Fail cleanly (instead of erroring with IndexError/KeyError) when
        # pip_install was never called or the version spec was dropped.
        self.assertTrue(calls, "install_all did not call pip_install")
        first_package, first_kwargs = calls[0]
        self.assertEqual(first_package, "audio-separator")
        self.assertEqual(first_kwargs.get("version_spec"), ">=0.44.1")
        self.assertIn((self.NUMPY_1_X_SPEC, {}), calls)

    def test_cuda_13_uses_latest_supported_pytorch_wheel(self):
        """A driver reporting CUDA 13.0 maps to the cu126 PyTorch wheel index."""

        def fake_run(cmd, **_kwargs):
            if cmd == ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"]:
                return mock.Mock(returncode=0, stdout="575.51.03\n")
            if cmd == ["nvidia-smi"]:
                return mock.Mock(returncode=0, stdout="CUDA Version: 13.0\n")
            # Any other command means detect_cuda_version changed its probing.
            raise AssertionError(f"Unexpected command: {cmd}")

        with mock.patch("install.subprocess.run", side_effect=fake_run):
            self.assertEqual(
                install.detect_cuda_version(),
                "https://download.pytorch.org/whl/cu126",
            )

    def test_gpu_install_does_not_fall_back_to_cpu_torch(self):
        """With gpu=True and no CUDA detected, install fails without pip calls."""
        ok, calls = self._run_install_all(
            [install.PACKAGES["torch"], install.PACKAGES["torchaudio"]],
            gpu=True,
        )

        self.assertFalse(ok)
        self.assertEqual(calls, [])


if __name__ == "__main__":
    # Run the test cases defined in this module when executed as a script.
    unittest.main(module="__main__")