Spaces:
No application file
No application file
| import re | |
| import shlex | |
| import subprocess | |
| import sys | |
| import time | |
| from enum import Enum | |
| import webui.args | |
| from autodebug.autodebug import InstallFailException | |
| from setup_tools.os import is_windows | |
| from threading import Thread | |
| valid_last: list[tuple[str, str]] = None | |
class CompareAction(Enum):
    """How an installed version must relate to a target version.

    Members run from "strictly less" (-2) through "equal" (0) to
    "strictly greater" (2).
    """

    LT = -2   # installed <  target
    LEQ = -1  # installed <= target
    EQ = 0    # installed == target
    GEQ = 1   # installed >= target
    GT = 2    # installed >  target
| class Requirement: | |
| def __init__(self): | |
| self.running = False | |
| def install_or_upgrade_if_needed(self): | |
| if not self.is_installed() or not self.is_right_version(): | |
| self.post_install(self.install()) | |
| def post_install(self, install_output: tuple[int, str, str]): | |
| exit_code, stdout, stderr = install_output | |
| if exit_code != 0: | |
| raise InstallFailException(exit_code, stdout, stderr) | |
| def is_right_version(self): | |
| raise NotImplementedError('Not implemented') | |
| def is_installed(self): | |
| raise NotImplementedError('Not implemented') | |
| def install_check(self, package_name: str) -> bool: | |
| return self.get_package_version(package_name) is not False | |
| def loading_thread(status_dict, name): | |
| idx = 0 | |
| load_symbols = ['|', '/', '-', '\\'] | |
| while status_dict['running']: | |
| curr_symbol = load_symbols[idx % len(load_symbols)] | |
| idx += 1 | |
| print(f'\rInstalling {name} {curr_symbol}', end='') | |
| time.sleep(0.25) | |
| print(f'\rInstalled {name}! ' if status_dict[ | |
| 'success'] else f'\rFailed to install {name}. Check AutoDebug output.') | |
| def install_pip(self, command, name=None) -> tuple[int, str, str]: | |
| global valid_last | |
| valid_last = None | |
| if not name: | |
| name = command | |
| status_dict = { | |
| 'running': True | |
| } | |
| verbose = webui.args.args.verbose | |
| if not verbose: | |
| thread = Thread(target=self.loading_thread, args=[status_dict, name], daemon=True) | |
| thread.start() | |
| args = f'"{sys.executable}" -m pip install --upgrade {command}' | |
| args = args if self.is_windows() else shlex.split(args) | |
| result = subprocess.run(args, capture_output=not verbose, text=True) | |
| status_dict['success'] = result.returncode == 0 | |
| status_dict['running'] = False | |
| if not verbose: | |
| while thread.is_alive(): | |
| time.sleep(0.1) | |
| return result.returncode, result.stdout, result.stderr | |
| def is_windows(self) -> bool: | |
| return is_windows() | |
| def install(self) -> tuple[int, str, str]: | |
| raise NotImplementedError('Not implemented') | |
| def pip_freeze(self) -> list[tuple[str, str]]: | |
| global valid_last | |
| if valid_last: | |
| return valid_last | |
| args = f'"{sys.executable}" -m pip freeze' | |
| args = args if self.is_windows() else shlex.split(args) | |
| result = subprocess.run(args, capture_output=True, text=True) | |
| test_str = result.stdout | |
| out_list = [] | |
| matches = re.finditer('^(.*)(?:==| @ )(.+)$', test_str, re.MULTILINE) | |
| for match in matches: | |
| out_list.append((match.group(1), match.group(2))) | |
| valid_last = out_list | |
| return out_list | |
| def get_package_version(self, name: str, freeze: dict[tuple[str, str]] | None = None) -> bool | str: | |
| if freeze is None: | |
| freeze = self.pip_freeze() | |
| for p_name, version in freeze: | |
| if name.casefold() == p_name.casefold(): | |
| return version | |
| return False | |
class SimpleRequirement(Requirement):
    """A requirement identified only by its pip package name.

    Any installed version is accepted. Instances or subclasses must provide
    ``package_name``.
    """

    package_name: str

    def install(self) -> tuple[int, str, str]:
        """Install (or upgrade) ``package_name`` via pip; returns ``(exit_code, stdout, stderr)``."""
        return self.install_pip(self.package_name)

    def is_installed(self):
        """True when ``package_name`` shows up in the package listing used by ``install_check``."""
        pkg = self.package_name
        return self.install_check(pkg)

    def is_right_version(self):
        """No version constraint: every installed version is acceptable."""
        return True
class SimpleRequirementInit(SimpleRequirement):
    """Pip requirement with an optional version constraint.

    Args:
        package_name: distribution name passed to pip.
        compare: required relation between the installed version and
            ``version``. ``None`` disables the version check.
        version: target version string. ``None`` disables the version check
            and installs without a specifier.
    """

    # pip specifier operator for each comparison; unknown values fall back to '=='.
    _SPECIFIER_SYMBOLS = {
        CompareAction.LT: '<',
        CompareAction.LEQ: '<=',
        CompareAction.EQ: '==',
        CompareAction.GEQ: '>=',
        CompareAction.GT: '>',
    }

    def __init__(self, package_name, compare: CompareAction = None, version: str = None):
        super().__init__()
        self.package_name = package_name
        self.compare = compare
        self.version = version

    def is_right_version(self):
        """Return whether the installed version satisfies ``compare``/``version``.

        Returns True when no constraint is configured or ``compare`` is not a
        known ``CompareAction``. Returns False when the package is not
        installed, or when its recorded version cannot be parsed (e.g. a
        direct-URL install). In those cases the caller reinstalls with an
        explicit specifier instead of crashing.
        """
        if self.compare is None or self.version is None:
            return True
        import operator
        from packaging import version
        comparators = {
            CompareAction.LT: operator.lt,
            CompareAction.LEQ: operator.le,
            CompareAction.EQ: operator.eq,
            CompareAction.GEQ: operator.ge,
            CompareAction.GT: operator.gt,
        }
        compare_fn = comparators.get(self.compare)
        if compare_fn is None:
            return True
        installed = self.get_package_version(self.package_name)
        if installed is False:
            # Not installed: version.parse(False) would raise TypeError.
            return False
        try:
            version_obj = version.parse(installed)
        except version.InvalidVersion:
            # e.g. 'name @ git+https://...' freeze entries record a URL, not a version.
            return False
        version_target_obj = version.parse(self.version)
        return compare_fn(version_obj, version_target_obj)

    def install(self) -> tuple[int, str, str]:
        """Install via pip, adding ``<op><version>`` to the package when a version is set.

        Returns:
            ``(exit_code, stdout, stderr)`` from ``install_pip``.
        """
        if self.version is None:
            return self.install_pip(self.package_name)
        symbol = self._SPECIFIER_SYMBOLS.get(self.compare, '==')
        return self.install_pip(f'{self.package_name}{symbol}{self.version}', self.package_name)
class SimpleGitRequirement(SimpleRequirement):
    """Requirement installed by handing ``repo`` to ``pip install``.

    Args:
        package_name: distribution name used to look the package up.
        repo: install argument for pip. Presumably a VCS URL such as
            ``git+https://...``, judging by the class name (TODO confirm).
        check_version: when True, the package only counts as up to date if
            its recorded version string equals ``repo`` exactly.
    """

    def __init__(self, package_name, repo, check_version=False):
        super().__init__()
        self.package_name, self.repo, self.check_version = package_name, repo, check_version

    def is_right_version(self):
        if self.check_version:
            recorded = self.get_package_version(self.package_name)
            return recorded == self.repo
        return True

    def install(self) -> tuple[int, str, str]:
        """Run pip on ``repo``, labelling progress with ``package_name``."""
        return self.install_pip(self.repo, self.package_name)