Diffusers
Safetensors
IntrisicWeather-diffusers / convert_forward_renderer.py
BiliSakura's picture
Upload folder using huggingface_hub
c5cfae9 verified
Raw
History Blame Contribute Delete
3.47 kB
#!/usr/bin/env python3
"""Convert GilgameshYX ForwardRenderer into BiliSakura IntrisicWeather-diffusers layout."""
from __future__ import annotations
import json
import shutil
import sys
from pathlib import Path
from diffusers.models.transformers import SD3Transformer2DModel
# Directory containing this script (OUTPUT_ROOT below is set to this).
COLLECTION_ROOT = Path(__file__).resolve().parent
# Local checkout of IntrinsicWeather-diffusers that provides the conversion
# helpers imported below.
# NOTE(review): machine-specific absolute path — adjust for your environment.
INTRINSIC_REPO = Path("/data/projects/IntrinsicWeather-diffusers")
# Put the checkout on sys.path so the imports below (e.g.
# `scripts._conversion_utils`) resolve. Both inserts go to index 0, so the
# repo root ends up ahead of INTRINSIC_REPO / "src" in lookup order.
sys.path.insert(0, str(INTRINSIC_REPO / "src"))
sys.path.insert(0, str(INTRINSIC_REPO))
from scripts._conversion_utils import ( # noqa: E402
expand_sd3_input_projection,
load_torch,
write_scheduler_config,
)
from _collection_setup import install_hub_pipelines # noqa: E402
# Local SD3-medium diffusers checkout; source of the shared pipeline component
# folders copied by copy_sd3_shared_components().
# NOTE(review): machine-specific absolute path.
SD3_PATH = Path(
    "/data/projects/Visual-Generative-Foundation-Model-Collection/models/stabilityai/stable-diffusion-3-medium-diffusers"
)
# Hub repo whose transformer config (subfolder="transformer") defines the
# architecture the ForwardRenderer weights are loaded into. Only the config is
# read from here (the model is built with from_config), not its weights.
SD35_TRANSFORMER_REPO = "stabilityai/stable-diffusion-3.5-medium"
# Directory holding the ForwardRenderer release files read by main():
# pytorch_model.bin (loaded as the full transformer state dict, strict=True)
# and pytorch_lora_weights.safetensors (copied as-is).
CKPT_PATH = Path(
    "/data/projects/Visual-Generative-Foundation-Model-Collection/models/GilgameshYX/ForwardRenderer"
)
# Converted artifacts are written next to this script.
OUTPUT_ROOT = COLLECTION_ROOT
# Subfolder name under OUTPUT_ROOT/transformer/ for this converted transformer.
TRANSFORMER_VARIANT = "forward"
# Pipeline component folders copied verbatim from SD3_PATH into OUTPUT_ROOT.
SHARED_COMPONENTS = (
    "text_encoder",
    "text_encoder_2",
    "text_encoder_3",
    "tokenizer",
    "tokenizer_2",
    "tokenizer_3",
    "vae",
    "scheduler",
)
def copy_sd3_shared_components(sd3_path: Path, output_path: Path) -> None:
for name in SHARED_COMPONENTS:
src = sd3_path / name
dst = output_path / name
if dst.exists():
print(f"Skipping existing shared component: {dst}")
continue
print(f"Copying {name} ...")
shutil.copytree(src, dst)
def main() -> None:
    """Convert the ForwardRenderer checkpoint into the collection's diffusers layout.

    Steps:
      1. Copy the shared SD3 component folders from ``SD3_PATH`` into
         ``OUTPUT_ROOT`` (existing ones are skipped). Then run
         ``write_scheduler_config`` and ``install_hub_pipelines`` on it.
      2. Build an ``SD3Transformer2DModel`` from the ``SD35_TRANSFORMER_REPO``
         transformer config. Pass it through ``expand_sd3_input_projection``
         with 96 input channels, load ``pytorch_model.bin`` with
         ``strict=True``, and save it as safetensors under
         ``OUTPUT_ROOT/transformer/<TRANSFORMER_VARIANT>``.
      3. Copy the LoRA safetensors file into a ``lora`` subfolder of that
         transformer directory.
      4. Write ``conversion_metadata_forward.json`` to ``OUTPUT_ROOT``,
         recording the inputs used.

    Raises:
        FileNotFoundError: if either ForwardRenderer checkpoint file is missing.
            This is checked before anything is copied or written.
    """
    # Single source of truth for the channel count: it is passed to
    # expand_sd3_input_projection and recorded in the metadata.
    in_channels = 96
    transformer_ckpt = CKPT_PATH / "pytorch_model.bin"
    lora_ckpt = CKPT_PATH / "pytorch_lora_weights.safetensors"

    # Fail fast: report a bad CKPT_PATH before the (potentially large)
    # shared-component copy and the config writes below run.
    missing = [p for p in (transformer_ckpt, lora_ckpt) if not p.is_file()]
    if missing:
        raise FileNotFoundError(
            "Missing ForwardRenderer checkpoint file(s): "
            + ", ".join(str(p) for p in missing)
        )

    transformer_dir = OUTPUT_ROOT / "transformer" / TRANSFORMER_VARIANT
    transformer_dir.mkdir(parents=True, exist_ok=True)

    print(f"Ensuring shared SD3 components from {SD3_PATH} ...")
    copy_sd3_shared_components(SD3_PATH, OUTPUT_ROOT)
    # NOTE(review): presumably overwrites the scheduler config copied above
    # with the one this collection expects — confirm in _conversion_utils.
    write_scheduler_config(OUTPUT_ROOT)
    install_hub_pipelines(OUTPUT_ROOT)

    print("Converting forward renderer transformer ...")
    # Only the architecture config is fetched; weights come from the checkpoint.
    transformer = SD3Transformer2DModel.from_config(
        SD3Transformer2DModel.load_config(SD35_TRANSFORMER_REPO, subfolder="transformer")
    )
    # Adapt the input layer to the checkpoint's channel count before loading;
    # strict=True then verifies every parameter name/shape matches.
    transformer = expand_sd3_input_projection(transformer, in_channels=in_channels)
    transformer.load_state_dict(load_torch(transformer_ckpt), strict=True)
    transformer.save_pretrained(transformer_dir.as_posix(), safe_serialization=True)

    print("Saving LoRA weights ...")
    lora_dir = transformer_dir / "lora"
    lora_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(lora_ckpt, lora_dir / lora_ckpt.name)

    conversion_metadata = {
        "task": "forward_renderer",
        "transformer_variant": TRANSFORMER_VARIANT,
        "source_transformer_checkpoint": str(transformer_ckpt.resolve()),
        "source_lora_checkpoint": str(lora_ckpt.resolve()),
        "lora_dir": str(lora_dir.resolve()),
        "sd3_path": str(SD3_PATH.resolve()),
        "sd35_transformer_repo": SD35_TRANSFORMER_REPO,
        "in_channels": in_channels,
    }
    (OUTPUT_ROOT / "conversion_metadata_forward.json").write_text(
        json.dumps(conversion_metadata, indent=2) + "\n",
        encoding="utf-8",
    )
    print(f"Saved transformer to: {transformer_dir}")
    # Derive the hint from TRANSFORMER_VARIANT so it cannot drift from the
    # folder actually written above.
    print(f"Load with: load_forward_pipeline(transformer_subfolder='{TRANSFORMER_VARIANT}')")
# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()