Spaces:
No application file
No application file
| import json | |
| import os.path | |
| import shlex | |
| import subprocess | |
| from enum import Enum | |
| from setup_tools.os import is_windows | |
# JSON file that stores, per extension name, whether the extension is enabled.
extension_states = os.path.join('data', 'extensions.json')
# Folder that is scanned for extension sub-folders.
ext_folder = 'extensions'
def git_ready():
    """Report whether the ``git`` executable can be invoked.

    Returns:
        bool: True when ``git --version`` runs and exits with status 0;
        False when it exits with another status or cannot be started at all
        (for example because git is not installed).
    """
    try:
        # An argument list is accepted on every platform without a shell, so
        # no per-OS handling of the command line is needed.
        result = subprocess.run(['git', '--version'], capture_output=True)
    except OSError:
        # Raised when the executable is missing or cannot be executed.
        return False
    return result.returncode == 0
class UpdateStatus(Enum):
    """Result of checking an extension folder for available git updates."""

    # A git command failed while checking for updates.
    no_git = -1
    # The extension folder has no ``.git`` directory.
    unmanaged = 0
    # git reports that the branch is up to date.
    updated = 1
    # git suggests pulling, or the status output was not recognised.
    outdated = 2
class Extension:
    """One extension folder located under ``ext_folder``.

    The constructor computes the paths of the files an extension may ship
    (``main.py``, ``requirements.py``, ``style.py``, ``scripts/script.js``,
    ``.git``) and loads the mandatory ``extension.json`` metadata into
    ``self.info``.
    """

    def __init__(self, ext_name, load_states):
        """Set up paths and metadata for the extension named ``ext_name``.

        Args:
            ext_name: Name of the extension's sub-folder inside ``ext_folder``.
            load_states: Mapping of extension name to its saved enabled flag.
                Extensions that are absent from the mapping start enabled.

        Raises:
            FileNotFoundError: If the folder contains no ``extension.json``.
        """
        self.enabled = load_states.get(ext_name, True)
        self.extname = ext_name
        self.path = os.path.join(ext_folder, ext_name)
        self.main_file = os.path.join(self.path, 'main.py')
        self.req_file = os.path.join(self.path, 'requirements.py')  # Optional
        self.style_file = os.path.join(self.path, 'style.py')
        self.js_file = os.path.join(self.path, 'scripts', 'script.js')
        self.git_dir = os.path.join(self.path, '.git')
        # Not used inside this class; presumably assigned by UI code - TODO confirm.
        self.update_el = None

        extinfo = os.path.join(self.path, 'extension.json')
        if not os.path.isfile(extinfo):
            raise FileNotFoundError(f'No extension.json file for {ext_name} extension.')
        with open(extinfo, 'r', encoding='utf8') as info_file:
            self.info = json.load(info_file)
        # Fill in placeholders so every metadata key is always present.
        for key in ('name', 'description', 'author'):
            self.info.setdefault(key, 'Not provided')
        self.info.setdefault('tags', [])

    @staticmethod
    def _import_file(file_path):
        """Import the module stored at the relative ``.py`` path and return it."""
        module_name = os.path.splitext(file_path)[0].replace(os.path.sep, '.')
        # A non-empty fromlist makes __import__ return the innermost module
        # instead of the top-level package.
        return __import__(module_name, fromlist=[''])

    def _run_git(self, *args):
        """Run ``git <args>`` inside the extension folder.

        Returns:
            The ``subprocess.CompletedProcess``, or None when the command
            could not be started (for example because git is not installed).
        """
        try:
            # An argument list works on every platform without a shell.
            return subprocess.run(['git', *args], capture_output=True, cwd=self.path)
        except OSError:
            return None

    def activate(self):
        """Import the extension's ``main.py`` if it is enabled and the file exists."""
        if self.enabled and os.path.isfile(self.main_file):
            self._import_file(self.main_file)

    def get_style_rules(self):
        """Import the extension's ``style.py`` if it is enabled and the file exists.

        The module is imported for its side effects only; nothing is returned.
        """
        if self.enabled and os.path.isfile(self.style_file):
            self._import_file(self.style_file)

    def get_requirements(self):
        """Return the result of ``requirements()`` from the extension's ``requirements.py``.

        Returns an empty list when the extension is disabled or ships no
        ``requirements.py``.
        """
        if self.enabled and os.path.isfile(self.req_file):
            return self._import_file(self.req_file).requirements()
        return []

    def get_javascript(self) -> str | bool:
        """Return the path of the extension's script file, or False.

        False is returned when the extension is disabled or the file is missing.
        """
        if self.enabled and os.path.isfile(self.js_file):
            return self.js_file
        return False

    def set_enabled(self, new):
        """Set the enabled flag, persist all load states and return the new value.

        Returns:
            ``gradio.update(value=new)`` when that call succeeds, otherwise
            ``new`` itself.
        """
        self.enabled = new
        set_load_states()
        try:
            import gradio
            return gradio.update(value=new)
        except Exception:
            # Best effort: fall back to the plain value when gradio cannot be
            # imported or the call fails, without swallowing KeyboardInterrupt
            # or SystemExit.
            return new

    def check_updates(self) -> UpdateStatus:
        """Fetch from the remote and classify the extension's update state.

        Returns:
            ``UpdateStatus.unmanaged`` when the folder has no ``.git`` directory,
            ``UpdateStatus.no_git`` when ``git fetch`` or ``git status`` fails
            or cannot be started, ``UpdateStatus.updated`` when git reports the
            branch is up to date, and ``UpdateStatus.outdated`` otherwise.
        """
        if not os.path.isdir(self.git_dir):
            return UpdateStatus.unmanaged

        fetched = self._run_git('fetch')
        if fetched is None or fetched.returncode != 0:
            return UpdateStatus.no_git

        status = self._run_git('status', '-uno')
        if status is None or status.returncode != 0:
            return UpdateStatus.no_git

        # NOTE(review): the checks below match git's English messages; a
        # localized git would fall through to `outdated` - confirm if relevant.
        out_string = status.stdout.decode(errors='replace')
        if 'git pull' in out_string:  # Included in message from git if not up to date
            return UpdateStatus.outdated
        if 'Your branch is up to date' in out_string:
            return UpdateStatus.updated
        return UpdateStatus.outdated

    def update(self):
        """Run ``git pull`` in the extension folder if it is a git checkout.

        Failures are reported with a printed message rather than an exception.
        """
        if not os.path.isdir(self.git_dir):
            return
        output = self._run_git('pull')
        if output is None or output.returncode != 0:
            print(f'Something went wrong during git pull for {self.extname}')
def get_valid_extensions(folder=None):
    """List the sub-folders that contain an ``extension.json`` file.

    Args:
        folder: Directory to scan. Defaults to the module-level ``ext_folder``.

    Returns:
        list[str]: Names of the entries in ``folder`` that hold an
        ``extension.json`` file, in ``os.listdir`` order. An empty list is
        returned when ``folder`` does not exist.
    """
    root = ext_folder if folder is None else folder
    try:
        entries = os.listdir(root)
    except FileNotFoundError:
        # No extensions folder simply means no extensions.
        return []
    # A file can only exist below `entry` if `entry` is itself a directory, so
    # no separate directory check is needed.
    return [entry for entry in entries
            if os.path.isfile(os.path.join(root, entry, 'extension.json'))]
# Registry of Extension objects keyed by extension name; filled by
# init_extensions() and read by the module-level helper functions.
states: dict[str, Extension] = dict()
def set_load_states():
    """Write the enabled flag of every registered extension to ``extension_states``.

    The file holds a JSON object that maps each extension name to its flag.
    """
    s = {name: ext.enabled for name, ext in states.items()}
    folder = os.path.dirname(extension_states)
    if folder:
        # Make sure the target folder exists so that saving cannot fail on a
        # fresh installation.
        os.makedirs(folder, exist_ok=True)
    # The context manager guarantees the file is flushed and closed.
    with open(extension_states, 'w', encoding='utf8') as state_file:
        json.dump(s, state_file)
def get_load_states():
    """Read the saved enabled flags from ``extension_states``.

    Returns:
        The decoded JSON content of the file, or an empty dict when the file
        does not exist.
    """
    if os.path.isfile(extension_states):
        # The context manager closes the file handle after reading.
        with open(extension_states, 'r', encoding='utf8') as state_file:
            return json.load(state_file)
    return {}
# Names of the callbacks that init_extensions() registers before any
# extension object is created.
register_callbacks = [
    'webui.init',
    'webui.settings',
    'webui.tabs',
    'webui.tabs.utils',
    'webui.tts.list',
]
def init_extensions():
    """Register the default callbacks and populate ``states``.

    Every name in ``register_callbacks`` is registered first. Afterwards an
    ``Extension`` object is created for each valid extension folder, using the
    saved load states, and stored in ``states`` under the folder name.
    """
    from webui.extensionlib.callbacks import register_new as register

    # Register default callbacks
    for callback_name in register_callbacks:
        register(callback_name)

    # Create an Extension object for every valid extension folder
    saved_states = get_load_states()
    found = get_valid_extensions()
    print(f'Found extensions: {", ".join(found)}')
    for ext_name in found:
        states[ext_name] = Extension(ext_name, saved_states)
def get_scripts() -> list[str]:
    """Collect the script paths of all registered extensions.

    Extensions whose ``get_javascript()`` result is falsy are left out.
    """
    candidates = (ext.get_javascript() for ext in states.values())
    return [script for script in candidates if script]
def get_requirements():
    """Concatenate the requirement lists of all registered extensions.

    Extensions whose ``get_requirements()`` result is empty contribute nothing.
    """
    per_extension = [ext.get_requirements() for ext in states.values()]
    collected = []
    for reqs in per_extension:
        if reqs:
            collected.extend(reqs)
    return collected