| | import pathlib
|
| | from os import path
|
| |
|
| | import torch
|
| | from diffusers import (
|
| | AutoPipelineForText2Image,
|
| | LCMScheduler,
|
| | StableDiffusionPipeline,
|
| | )
|
| |
|
| |
|
def load_lcm_weights(
    pipeline,
    use_local_model,
    lcm_lora_id,
):
    """Attach LCM LoRA weights to *pipeline* under the adapter name "lcm".

    Args:
        pipeline: A diffusers pipeline exposing ``load_lora_weights``.
        use_local_model: Forwarded as ``local_files_only`` so the weights are
            resolved from the local cache only when True.
        lcm_lora_id: Hub repo id (or local path) holding the LoRA weights.
    """
    # Single delegation call; the weight file name is fixed by convention.
    pipeline.load_lora_weights(
        lcm_lora_id,
        local_files_only=use_local_model,
        weight_name="pytorch_lora_weights.safetensors",
        adapter_name="lcm",
    )
|
| |
|
| |
|
def get_lcm_lora_pipeline(
    base_model_id: str,
    lcm_lora_id: str,
    use_local_model: bool,
    torch_data_type: torch.dtype,
    pipeline_args=None,
):
    """Build a text-to-image pipeline and attach LCM/Hyper-SD LoRA weights.

    Args:
        base_model_id: Hub repo id of the base model, or a local path to a
            single ``.safetensors`` checkpoint (single-file path supports
            SD 1.5 models only).
        lcm_lora_id: Hub repo id (or local path) of the LoRA weights.
        use_local_model: If True, resolve all files from the local cache only.
        torch_data_type: Weight dtype for the pipeline (e.g. torch.float16).
        pipeline_args: Optional extra keyword arguments forwarded to the
            pipeline constructor; defaults to none.

    Returns:
        The constructed ``AutoPipelineForText2Image`` with the LoRA loaded;
        when the LoRA id looks like an LCM or Hyper-SD model, the scheduler
        is replaced with ``LCMScheduler``.

    Raises:
        FileNotFoundError: If ``base_model_id`` is a ``.safetensors`` path
            that does not exist on disk.
    """
    # Fix: was `pipeline_args={}` — a shared mutable default argument.
    if pipeline_args is None:
        pipeline_args = {}

    model_path = pathlib.Path(base_model_id)
    if model_path.suffix == ".safetensors":
        # Verify the checkpoint exists before handing it to diffusers so the
        # caller gets a clear error instead of a deep loader traceback.
        if not model_path.exists():
            raise FileNotFoundError(
                f"Model file not found, please check your model path: {base_model_id}"
            )
        print("Using single file Safetensors model (Supported models - SD 1.5 models)")

        # `from_single_file` lives on the concrete SD pipeline; build that
        # first, then convert it to the generic text2image pipeline.
        dummy_pipeline = StableDiffusionPipeline.from_single_file(
            base_model_id,
            torch_dtype=torch_data_type,
            safety_checker=None,
            local_files_only=use_local_model,
            use_safetensors=True,
        )
        pipeline = AutoPipelineForText2Image.from_pipe(
            dummy_pipeline,
            **pipeline_args,
        )
        # Drop the intermediate pipeline's references promptly to free memory.
        del dummy_pipeline
    else:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            base_model_id,
            torch_dtype=torch_data_type,
            local_files_only=use_local_model,
            **pipeline_args,
        )

    load_lcm_weights(
        pipeline,
        use_local_model,
        lcm_lora_id,
    )

    # Heuristic: LCM/Hyper-SD LoRAs are designed for the LCM scheduler.
    lora_id_lower = lcm_lora_id.lower()
    if "lcm" in lora_id_lower or "hypersd" in lora_id_lower:
        print("LCM LoRA model detected so using recommended LCMScheduler")
        pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)

    return pipeline
|
| |
|