Spaces:
Running
Running
| import glob | |
| from os import path | |
| from paths import get_file_name, FastStableDiffusionPaths | |
| from pathlib import Path | |
class _lora_info:
    """
    A basic class to keep track of the currently loaded LoRAs and their weights.

    The diffusers function _get_active_adapters()_ returns a list of adapter
    names but not their weights so we need a way to keep track of the current
    LoRA weights to set whenever a new LoRA is loaded.

    Attributes:
        path: Path of the LoRA file, as given in the settings.
        adapter_name: Adapter name derived from the file path with
            get_file_name(); used as the adapter name when the LoRA is loaded
            into the pipeline and when weights are set.
        weight: Current adapter weight passed to the pipeline's set_adapters().
    """

    def __init__(
        self,
        path: str,
        weight: float,
    ):
        self.path = path
        # NOTE(review): get_file_name() can apparently return None (the LoRA
        # scanner checks for it); a None adapter name would not work with
        # set_adapters() — confirm callers only pass valid LoRA file paths.
        self.adapter_name = get_file_name(path)
        self.weight = weight

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(path={self.path!r}, "
            f"adapter_name={self.adapter_name!r}, weight={self.weight!r})"
        )
# Module-level tracking state shared by the LoRA helpers below:
# the _lora_info entries loaded into the tracked pipeline, in load order ...
_loaded_loras = list()
# ... and the pipeline object those entries belong to (None until first load).
_current_pipeline = None
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """
    Loads a LoRA from the LoRA path setting.

    This function loads a LoRA from the LoRA path stored in the settings so
    it's possible to load multiple LoRAs by calling this function more than
    once with a different LoRA path setting; note that if you plan to load
    multiple LoRAs and dynamically change their weights, you might want to
    set the LoRA fuse option to _False_.

    A LoRA is recorded in the active LoRA list only after it has actually been
    loaded into the pipeline. A disabled LoRA setting or a failed load
    therefore never leaves a stale adapter entry behind for
    update_lora_weights() to pass to set_adapters().

    Args:
        pipeline: The pipeline to load the LoRA into.
        lcm_diffusion_setting: Settings object; this function reads
            lora.path, lora.weight, lora.enabled and lora.fuse
            (update_lora_weights() additionally reads use_lcm_lora).

    Raises:
        Exception: If the LoRA path is empty or does not exist.
    """
    lora_path = lcm_diffusion_setting.lora.path
    if not lora_path:
        raise Exception("Empty lora model path")
    if not path.exists(lora_path):
        raise Exception("Lora model path is invalid")

    # If the pipeline has been rebuilt since the last call, remove all
    # references to previously loaded LoRAs and store the new pipeline.
    # Identity check: "is this the same pipeline object", not a custom __eq__.
    global _current_pipeline
    if pipeline is not _current_pipeline:
        reset_active_lora_weights()
        _current_pipeline = pipeline

    if not lcm_diffusion_setting.lora.enabled:
        return

    current_lora = _lora_info(
        lora_path,
        lcm_diffusion_setting.lora.weight,
    )
    print(f"LoRA adapter name : {current_lora.adapter_name}")
    # Load the exact file validated above. LoRAs may live in sub-folders of
    # the LoRA models directory (get_lora_models() scans recursively), so
    # "models root + file name" does not always point at this file.
    lora_file = Path(lora_path)
    pipeline.load_lora_weights(
        str(lora_file.parent),
        weight_name=lora_file.name,
        local_files_only=True,
        adapter_name=current_lora.adapter_name,
    )
    # Track the LoRA only once it is really part of the pipeline; must happen
    # before update_lora_weights(), which builds set_adapters() from this list.
    _loaded_loras.append(current_lora)
    update_lora_weights(
        pipeline,
        lcm_diffusion_setting,
    )
    if lcm_diffusion_setting.lora.fuse:
        pipeline.fuse_lora()
def get_lora_models(root_dir: str):
    """
    Return a mapping of LoRA name -> file path for every ``.safetensors``
    file found recursively under *root_dir*.

    The LoRA name comes from get_file_name(); files for which it returns None
    are skipped. Matches are processed in sorted path order, so when two
    files yield the same name the one sorting last wins deterministically.

    Args:
        root_dir: Directory to scan.

    Returns:
        dict mapping LoRA name to the matching file path string.
    """
    # Escape the directory so characters such as "[", "*" or "?" in a folder
    # name are matched literally instead of being treated as glob wildcards.
    pattern = f"{glob.escape(root_dir)}/**/*.safetensors"
    lora_models_map = {}
    for file_path in sorted(glob.glob(pattern, recursive=True)):
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map
def get_active_lora_weights():
    """
    Report the LoRAs currently tracked for the active pipeline.

    Returns:
        A new list of _(adapter_name, weight)_ tuples, one per loaded LoRA,
        in the order the LoRAs were recorded.
    """
    return [(entry.adapter_name, entry.weight) for entry in _loaded_loras]
def reset_active_lora_weights():
    """
    Clears the global list of active LoRA weights.

    This method clears the list of active LoRA weights but it doesn't actually
    remove the active LoRA weights from the current generation pipeline.
    This method is only meant to be called when rebuilding the generation
    pipeline as it will also clear the _current_pipeline_ variable; setting
    the _current_pipeline_ variable to _None_ is safe here since the active
    LoRA weights list is being reset, but it also helps to remove the pipeline
    reference that might prevent the garbage collector from releasing the
    current pipeline memory.
    """
    global _loaded_loras
    global _current_pipeline
    # Rebinding to a fresh list drops this module's reference to the old list
    # (and the _lora_info entries it holds); no per-item `del` is needed.
    _loaded_loras = []
    _current_pipeline = None
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """
    Updates the LoRA weights for the currently active LoRAs.

    Args:
        pipeline: The currently active pipeline; if it is not the pipeline the
            tracked LoRAs were loaded into, a message is printed and nothing
            is changed.
        lcm_diffusion_setting: The global settings, needed to verify if the
            pipeline is running in LCM-LoRA mode.
        lora_weights: An optional list of updated _(adapter_name, weight)_
            tuples, matched by position against the loaded LoRAs. Entries
            whose adapter name does not match the loaded LoRA at that position,
            or that have no loaded LoRA at that position, are reported and
            skipped.
    """
    # Identity check: the weights belong to one specific pipeline object.
    if pipeline is not _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return
    if lora_weights:
        for idx, entry in enumerate(lora_weights):
            adapter_name, weight = entry[0], entry[1]
            # Bounds check: a caller may pass more entries than there are
            # loaded LoRAs; indexing past the end used to raise IndexError.
            in_range = idx < len(_loaded_loras)
            if not in_range or _loaded_loras[idx].adapter_name != adapter_name:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = weight
    adapter_names = []
    adapter_weights = []
    # In LCM-LoRA mode also keep the "lcm" adapter active at weight 1.0
    # (presumably loaded under that name elsewhere — it is not loaded here).
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora in _loaded_loras:
        adapter_names.append(lora.adapter_name)
        adapter_weights.append(lora.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")