| import logging |
| import os |
| import shutil |
| from app.logger import log_startup_warning |
| from utils.install_util import get_missing_requirements_message |
| from comfy.cli_args import args |
|
|
# Module state: flipped to True only when every optional import below succeeds.
_DB_AVAILABLE = False
# Session factory; stays None until init_db() installs a sessionmaker.
Session = None


try:
    from alembic import command
    from alembic.config import Config
    from alembic.runtime.migration import MigrationContext
    from alembic.script import ScriptDirectory
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
except ImportError as e:
    # Missing database packages are reported at startup but do not abort the
    # import of this module; callers check dependencies_available() instead.
    log_startup_warning(
        f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
    )
else:
    _DB_AVAILABLE = True
|
|
|
|
def dependencies_available():
    """
    Report whether the database dependencies were imported successfully.

    Temporary helper: it simply exposes the module-level ``_DB_AVAILABLE``
    flag, which is set to True only after the guarded imports complete.
    """
    imports_ok = _DB_AVAILABLE
    return imports_ok
|
|
|
|
def can_create_session():
    """
    Tell whether a database session can currently be created.

    Temporary helper. During the initial release, missing dependencies or
    environmental issues may prevent the database from being set up, in which
    case this returns False. A session is possible only when the dependencies
    imported and the module-level ``Session`` factory has been assigned.
    """
    if not dependencies_available():
        return False
    return Session is not None
|
|
|
|
def get_alembic_config():
    """
    Build the Alembic ``Config`` used to run migrations.

    ``alembic.ini`` and the ``alembic_db`` scripts directory are resolved two
    directories above this module's directory. The ``script_location`` and
    ``sqlalchemy.url`` main options are then overridden programmatically, the
    latter from ``args.database_url``.

    Returns:
        The configured Alembic ``Config`` object.
    """
    root_path = os.path.join(os.path.dirname(__file__), "../..")
    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))

    config = Config(config_path)
    # set_main_option() routes values through ConfigParser interpolation, so a
    # literal "%" (e.g. in an install path or a URL) must be doubled; otherwise
    # it is rejected as invalid interpolation syntax. Reading the option back
    # yields the single "%" again.
    config.set_main_option("script_location", scripts_path.replace("%", "%%"))
    config.set_main_option(
        "sqlalchemy.url", args.database_url.replace("%", "%%")
    )

    return config
|
|
|
|
def get_db_path():
    """
    Return the file path portion of the configured SQLite database URL.

    Reads ``args.database_url`` and returns everything after the
    ``sqlite:///`` prefix.

    Raises:
        ValueError: If the URL does not start with ``sqlite:///``.
    """
    sqlite_prefix = "sqlite:///"
    url = args.database_url
    if not url.startswith(sqlite_prefix):
        raise ValueError(f"Unsupported database URL '{url}'.")
    # Slice the prefix off rather than splitting on "///": splitting would
    # truncate any path that itself contains "///".
    return url[len(sqlite_prefix):]
|
|
|
|
def init_db():
    """
    Bring the database schema up to date and install the session factory.

    Compares the revision stored in the database with the head revision of
    the Alembic script directory and runs an upgrade when they differ. If the
    database file existed beforehand, it is copied to ``<db_path>.bkp`` before
    the upgrade and copied back (and the backup removed) if the upgrade
    raises. Finally the module-level ``Session`` is set to a ``sessionmaker``
    bound to the engine.

    Raises:
        ValueError: Propagated from ``get_db_path()`` for a non-``sqlite:///``
            URL.
        Exception: Any error raised by ``command.upgrade`` is logged and
            re-raised after the backup (if any) has been restored.
    """
    global Session

    db_url = args.database_url
    logging.debug(f"Database URL: {db_url}")
    db_path = get_db_path()
    # Record whether the file existed before the engine touches it; only a
    # pre-existing database gets backed up.
    db_exists = os.path.exists(db_path)

    config = get_alembic_config()

    engine = create_engine(db_url)
    # The connection is needed only to read the stored revision. Close it
    # before the file is copied or migrated instead of leaking it for the
    # lifetime of the process.
    with engine.connect() as conn:
        context = MigrationContext.configure(conn)
        current_rev = context.get_current_revision()

    script = ScriptDirectory.from_config(config)
    target_rev = script.get_current_head()

    if target_rev is None:
        logging.warning("No target revision found.")
    elif current_rev != target_rev:
        # Back up the database before upgrading (only if it already existed).
        backup_path = db_path + ".bkp" if db_exists else None
        if backup_path:
            shutil.copy(db_path, backup_path)

        try:
            command.upgrade(config, target_rev)
            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
        except Exception:
            if backup_path:
                # Restore the pre-upgrade database and drop the backup copy.
                shutil.copy(backup_path, db_path)
                os.remove(backup_path)
            logging.exception("Error upgrading database: ")
            # Bare raise keeps the original traceback intact.
            raise

    Session = sessionmaker(bind=engine)
|
|
|
|
def create_session():
    """
    Create a new database session from the module-level ``Session`` factory.

    ``Session`` is None until ``init_db()`` has assigned it, so callers are
    expected to check ``can_create_session()`` first.
    """
    new_session = Session()
    return new_session
|
|