# ImageGen / core/settings.py
# Uploaded by RioShiina using huggingface_hub (commit e62785a, verified).
import yaml
import os
from collections import OrderedDict
# Local directories for downloaded model assets and user images.
# These are plain relative paths (not joined with _PROJECT_ROOT here); presumably
# they are resolved against the working directory by callers -- verify.

# Whole checkpoints and separately-distributed model components.
CHECKPOINT_DIR = "models/checkpoints"
DIFFUSION_MODELS_DIR = "models/diffusion_models"
TEXT_ENCODERS_DIR = "models/text_encoders"
VAE_DIR = "models/vae"

# Auxiliary model files.
LORA_DIR = "models/loras"
EMBEDDING_DIR = "models/embeddings"
CONTROLNET_DIR = "models/controlnet"
MODEL_PATCHES_DIR = "models/model_patches"
STYLE_MODELS_DIR = "models/style_models"
CLIP_VISION_DIR = "models/clip_vision"
IPADAPTER_DIR = "models/ipadapter"
IPADAPTER_FLUX_DIR = "models/ipadapter-flux"

# Image input / output folders.
INPUT_DIR = "input"
OUTPUT_DIR = "output"
# Category name (as used for the per-category lists in file_list.yaml) -> local
# directory the files of that category are stored in. Key order is preserved
# for anyone iterating the mapping.
CATEGORY_TO_DIR_MAP = {
    "diffusion_models": DIFFUSION_MODELS_DIR,
    "text_encoders": TEXT_ENCODERS_DIR,
    "vae": VAE_DIR,
    "checkpoints": CHECKPOINT_DIR,
    "loras": LORA_DIR,
    "controlnet": CONTROLNET_DIR,
    "model_patches": MODEL_PATCHES_DIR,
    "embeddings": EMBEDDING_DIR,
    "style_models": STYLE_MODELS_DIR,
    "clip_vision": CLIP_VISION_DIR,
    "ipadapter": IPADAPTER_DIR,
    "ipadapter-flux": IPADAPTER_FLUX_DIR,
}
# Project root: the parent of the directory that contains this file.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# All YAML configuration files live in <project root>/yaml.
_YAML_DIR = os.path.join(_PROJECT_ROOT, 'yaml')
_MODEL_LIST_PATH = os.path.join(_YAML_DIR, 'model_list.yaml')
_FILE_LIST_PATH = os.path.join(_YAML_DIR, 'file_list.yaml')
_IPADAPTER_LIST_PATH = os.path.join(_YAML_DIR, 'ipadapter.yaml')
_CONSTANTS_PATH = os.path.join(_YAML_DIR, 'constants.yaml')
_MODEL_ARCHITECTURES_PATH = os.path.join(_YAML_DIR, 'model_architectures.yaml')
_IMAGE_GEN_FEATURES_PATH = os.path.join(_YAML_DIR, 'image_gen_features.yaml')
_MODEL_DEFAULTS_PATH = os.path.join(_YAML_DIR, 'model_defaults.yaml')
def load_constants_from_yaml(filepath=_CONSTANTS_PATH):
    """Load the constants YAML file (limits such as MAX_LORAS, RESOLUTION_MAP, ...).

    Args:
        filepath: Path to the YAML file; defaults to ``yaml/constants.yaml``
            under the project root.

    Returns:
        The parsed YAML document, or ``{}`` when the file is missing or empty.
        ``yaml.safe_load`` yields ``None`` for an empty document; that is
        normalised to ``{}`` so callers can always call ``.get`` on the result.

    Raises:
        OSError if the file exists but cannot be opened, and yaml.YAMLError if
        it cannot be parsed.
    """
    if not os.path.exists(filepath):
        print(f"Warning: Constants file not found at {filepath}. Using fallback values.")
        return {}
    with open(filepath, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    return {} if data is None else data
def load_architectures_config(filepath=_MODEL_ARCHITECTURES_PATH):
    """Load the model-architectures YAML file.

    Args:
        filepath: Path to the YAML file; defaults to
            ``yaml/model_architectures.yaml`` under the project root.

    Returns:
        The parsed YAML document, or ``{}`` when the file is missing or empty
        (an empty document would otherwise come back from ``yaml.safe_load``
        as ``None``).

    Raises:
        OSError if the file exists but cannot be opened, and yaml.YAMLError if
        it cannot be parsed.
    """
    if not os.path.exists(filepath):
        print(f"Warning: Architectures file not found at {filepath}.")
        return {}
    with open(filepath, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    return {} if data is None else data
def load_features_config(filepath=_IMAGE_GEN_FEATURES_PATH):
    """Load the image-generation features YAML file.

    Args:
        filepath: Path to the YAML file; defaults to
            ``yaml/image_gen_features.yaml`` under the project root.

    Returns:
        The parsed YAML document, or ``{}`` when the file is missing or empty
        (an empty document would otherwise come back from ``yaml.safe_load``
        as ``None``).

    Raises:
        OSError if the file exists but cannot be opened, and yaml.YAMLError if
        it cannot be parsed.
    """
    if not os.path.exists(filepath):
        print(f"Warning: Features file not found at {filepath}.")
        return {}
    with open(filepath, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    return {} if data is None else data
def load_model_defaults(filepath=_MODEL_DEFAULTS_PATH):
    """Load the per-model defaults YAML file.

    Args:
        filepath: Path to the YAML file; defaults to
            ``yaml/model_defaults.yaml`` under the project root.

    Returns:
        The parsed YAML document, or ``{}`` when the file is missing or empty
        (an empty document would otherwise come back from ``yaml.safe_load``
        as ``None``).

    Raises:
        OSError if the file exists but cannot be opened, and yaml.YAMLError if
        it cannot be parsed.
    """
    if not os.path.exists(filepath):
        print(f"Warning: Model defaults file not found at {filepath}.")
        return {}
    with open(filepath, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    return {} if data is None else data
def load_file_download_map(filepath=_FILE_LIST_PATH):
    """Build a filename -> download-entry map from the file list YAML.

    The expected layout is a top-level ``file`` mapping whose values are
    lists of entries per category::

        file:
          <category>:
            - {filename: ..., repo_id: ..., ...}

    Each usable entry (a mapping with a ``filename`` key) is tagged in place
    with ``category`` and stored under its filename. Non-list categories and
    non-mapping entries are skipped. If two entries share a filename, the
    later one wins.

    Args:
        filepath: Path to the YAML file; defaults to ``yaml/file_list.yaml``
            under the project root.

    Returns:
        dict mapping filename -> entry dict (including the added ``category``).

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"The file list (for downloads) was not found at: {filepath}")
    with open(filepath, 'r', encoding='utf-8') as f:
        # An empty document parses to None; treat it as "no files".
        file_list_data = yaml.safe_load(f) or {}
    download_info_map = {}
    # `file:` present but null would otherwise be None here.
    for category, files in (file_list_data.get('file') or {}).items():
        if not isinstance(files, list):
            continue
        for file_info in files:
            # A bare string entry would pass a substring `'filename' in ...` test
            # and then fail on item assignment, so only accept mappings.
            if isinstance(file_info, dict) and 'filename' in file_info:
                file_info['category'] = category
                download_info_map[file_info['filename']] = file_info
    return download_info_map
def load_models_from_yaml(model_list_filepath=_MODEL_LIST_PATH, download_map=None):
    """Load the named checkpoint models from the model list YAML.

    Only the top-level ``Checkpoint`` / ``Checkpoints`` sections are read;
    other sections are ignored. The expected layout is::

        Checkpoints:
          <arch>:
            latent_type: ...        # optional, defaults to 'latent'
            models:
              - {display_name: ..., path: ... | components: ..., category: ...}

    Each model becomes the tuple
    ``(repo_id, path_or_components, arch, latent_type, category)``. ``repo_id``
    is looked up in ``download_map`` only when the path is a string filename,
    and is ``''`` otherwise or when the filename is unknown.

    Args:
        model_list_filepath: Path to the YAML file; defaults to
            ``yaml/model_list.yaml`` under the project root.
        download_map: filename -> download entry map, as produced by
            ``load_file_download_map``. Required.

    Returns:
        dict with keys ``"MODEL_MAP_CHECKPOINT"`` and ``"ALL_MODEL_MAP"``, each an
        ``OrderedDict`` of display_name -> model tuple in file order.

    Raises:
        FileNotFoundError: if ``model_list_filepath`` does not exist.
        ValueError: if ``download_map`` is not provided.
    """
    if not os.path.exists(model_list_filepath):
        raise FileNotFoundError(f"The model list file was not found at: {model_list_filepath}")
    if download_map is None:
        raise ValueError("download_map must be provided to load_models_from_yaml")
    with open(model_list_filepath, 'r', encoding='utf-8') as f:
        # An empty document parses to None; treat it as "no models".
        model_data = yaml.safe_load(f) or {}
    model_maps = {
        "MODEL_MAP_CHECKPOINT": OrderedDict(),
        "ALL_MODEL_MAP": OrderedDict(),
    }
    category_map_names = {
        "Checkpoint": "MODEL_MAP_CHECKPOINT",
        "Checkpoints": "MODEL_MAP_CHECKPOINT",
    }
    for category, architectures in model_data.items():
        map_name = category_map_names.get(category)
        if map_name is None or not isinstance(architectures, dict):
            continue
        for arch, arch_data in architectures.items():
            if not isinstance(arch_data, dict):
                continue
            latent_type = arch_data.get('latent_type', 'latent')
            models = arch_data.get('models', [])
            if not isinstance(models, list):
                continue
            for model in models:
                # One malformed entry must not abort the whole load (the caller
                # would then fall back to empty model maps), so skip it loudly.
                if not isinstance(model, dict) or 'display_name' not in model:
                    print(f"Warning: Skipping malformed model entry under '{category}/{arch}': {model!r}")
                    continue
                display_name = model['display_name']
                path_or_components = model.get('path') or model.get('components')
                mod_category = model.get('category', None)
                repo_id = ''
                # Only single-file models can be looked up by filename.
                if isinstance(path_or_components, str):
                    download_info = download_map.get(path_or_components, {})
                    repo_id = download_info.get('repo_id', '')
                model_tuple = (
                    repo_id,
                    path_or_components,
                    arch,
                    latent_type,
                    mod_category,
                )
                model_maps[map_name][display_name] = model_tuple
                model_maps["ALL_MODEL_MAP"][display_name] = model_tuple
    return model_maps
# file_list.yaml category -> model-type tag stored (at tuple index 2) for files
# that are not registered as named checkpoints in model_list.yaml.
category_to_model_type = {
    "diffusion_models": "UNET",
    "text_encoders": "TEXT_ENCODER",
    "vae": "VAE",
    "checkpoints": "SDXL",
    "loras": "LORA",
    "controlnet": "CONTROLNET",
    "model_patches": "MODEL_PATCH",
    "style_models": "STYLE",
    "clip_vision": "CLIP_VISION",
    "ipadapter": "IPADAPTER",
    "ipadapter-flux": "IPADAPTER_FLUX",
}
try:
    ALL_FILE_DOWNLOAD_MAP = load_file_download_map()
    loaded_maps = load_models_from_yaml(download_map=ALL_FILE_DOWNLOAD_MAP)
    MODEL_MAP_CHECKPOINT = loaded_maps["MODEL_MAP_CHECKPOINT"]
    ALL_MODEL_MAP = loaded_maps["ALL_MODEL_MAP"]
    # Also register every downloadable file under its own filename, unless a
    # display name already uses that key. Tuple layout matches the checkpoint
    # tuples: (repo_id, path, arch-or-type, latent_type, category).
    for filename, file_info in ALL_FILE_DOWNLOAD_MAP.items():
        if filename in ALL_MODEL_MAP:
            continue
        model_type = category_to_model_type.get(file_info.get('category'), 'UNKNOWN')
        ALL_MODEL_MAP[filename] = (file_info.get('repo_id', ''), filename, model_type, None, None)
    MODEL_TYPE_MAP = {name: info[2] for name, info in ALL_MODEL_MAP.items()}
    # arch -> distinct, non-empty model categories seen for that arch, in file order.
    # Checkpoint tuples are always 5 elements (see load_models_from_yaml).
    ARCH_CATEGORIES_MAP = {}
    for info in MODEL_MAP_CHECKPOINT.values():
        arch, cat = info[2], info[4]
        arch_categories = ARCH_CATEGORIES_MAP.setdefault(arch, [])
        if cat and cat not in arch_categories:
            arch_categories.append(cat)
except Exception as e:
    # Deliberate best-effort: keep the app importable with empty model maps.
    print(f"FATAL: Could not load model configuration from YAML. Error: {e}")
    ALL_FILE_DOWNLOAD_MAP = {}
    # Same container type as the success path (load_models_from_yaml).
    MODEL_MAP_CHECKPOINT, ALL_MODEL_MAP = OrderedDict(), OrderedDict()
    MODEL_TYPE_MAP = {}
    ARCH_CATEGORIES_MAP = {}
try:
_constants = load_constants_from_yaml()
MAX_LORAS = _constants.get('MAX_LORAS', 5)
MAX_EMBEDDINGS = _constants.get('MAX_EMBEDDINGS', 5)
MAX_CONDITIONINGS = _constants.get('MAX_CONDITIONINGS', 10)
MAX_CONTROLNETS = _constants.get('MAX_CONTROLNETS', 5)
MAX_IPADAPTERS = _constants.get('MAX_IPADAPTERS', 5)
LORA_SOURCE_CHOICES = _constants.get('LORA_SOURCE_CHOICES', ["Civitai", "File"])
RESOLUTION_MAP = _constants.get('RESOLUTION_MAP', {})
ARCHITECTURES_CONFIG = load_architectures_config()
FEATURES_CONFIG = load_features_config()
MODEL_DEFAULTS_CONFIG = load_model_defaults()
except Exception as e:
print(f"FATAL: Could not load constants from YAML. Error: {e}")
MAX_LORAS, MAX_EMBEDDINGS, MAX_CONDITIONINGS, MAX_CONTROLNETS, MAX_IPADAPTERS = 5, 5, 10, 5, 5
LORA_SOURCE_CHOICES = ["Civitai", "File"]
RESOLUTION_MAP = {}
ARCHITECTURES_CONFIG = {}
FEATURES_CONFIG = {}
MODEL_DEFAULTS_CONFIG = {}