"""
Convert a Flux model from Diffusers (folder or single-file) into the original
single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI.

Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file)
Output: /path/to/flux1-your-model.safetensors (transformer only)

Usage:
  python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors
  python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors
  # optional quantization:
  #   --fp8         (float8_e4m3fn, simple stochastic rounding)
  #   --fp8-scaled  (scaled float8 for 2D weights; adds .scale_weight tensors)
"""
| |
|
import argparse
import datetime
import json
from collections import OrderedDict
from pathlib import Path

import torch
import safetensors.torch
from safetensors import safe_open
from tqdm import tqdm
| |
|
| |
|
def parse_args(argv=None):
    """Parse command-line arguments for the converter.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (so the original zero-argument call still works).

    Returns:
        argparse.Namespace with ``diffusers_path``, ``output_path``,
        ``fp8`` and ``fp8_scaled`` attributes.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("diffusers_path", type=str,
                    help="Path to Diffusers checkpoint folder OR a single .safetensors file.")
    ap.add_argument("output_path", type=str,
                    help="Output .safetensors path for the Flux transformer.")
    # The two quantization modes are incompatible; previously --fp8 silently
    # took priority when both flags were given. Make argparse reject that.
    quant = ap.add_mutually_exclusive_group()
    quant.add_argument("--fp8", action="store_true",
                       help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).")
    quant.add_argument("--fp8-scaled", action="store_true",
                       help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.")
    return ap.parse_args(argv)
| |
|
| |
|
| | |
# Mapping from original (BFL/ComfyUI) Flux transformer key names to the
# Diffusers key(s) they are built from.
#
#   * "()" in a key is a placeholder for the block index.
#   * When one original key maps to several Diffusers tensors, they are
#     concatenated along dim 0 by the converter (fused QKV / QKV+MLP).
#   * Per-block Diffusers keys are looked up under the
#     "transformer_blocks.N." / "single_transformer_blocks.N." prefixes.
DIFFUSERS_MAP = {

    # Timestep embedder MLP.
    "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
    "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
    "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
    "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],

    # Pooled text-vector embedder MLP.
    "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
    "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
    "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
    "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],

    # Guidance embedder MLP (absent on models without guidance; the converter
    # skips entries whose source keys are missing).
    "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
    "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
    "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
    "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],

    # Input projections for text (context) and image (latent) streams.
    "txt_in.weight": ["context_embedder.weight"],
    "txt_in.bias": ["context_embedder.bias"],
    "img_in.weight": ["x_embedder.weight"],
    "img_in.bias": ["x_embedder.bias"],

    # ---- Dual-stream blocks: modulation (adaLN) linears. ----
    "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
    "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
    "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
    "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],

    # Dual-stream blocks: fused QKV (concat of Q, K, V along dim 0).
    "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
    "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
    "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
    "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],

    # Dual-stream blocks: Q/K norms.
    "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
    "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
    "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
    "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],

    # Dual-stream blocks: image-stream MLP.
    "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
    "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
    "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
    "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],

    # Dual-stream blocks: text-stream MLP.
    "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
    "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
    "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
    "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],

    # Dual-stream blocks: attention output projections.
    "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
    "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
    "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
    "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],

    # ---- Single-stream blocks: linear1 fuses Q, K, V and the MLP input. ----
    "single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
    "single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
    "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
    "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
    "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
    "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
    "single_blocks.().linear2.weight": ["proj_out.weight"],
    "single_blocks.().linear2.bias": ["proj_out.bias"],

    # Final output layer (top-level "proj_out.*" in Diffusers, no block prefix).
    "final_layer.linear.weight": ["proj_out.weight"],
    "final_layer.linear.bias": ["proj_out.bias"],
    # NOTE: these two additionally get their [shift, scale] halves swapped to
    # [scale, shift] after conversion (see main()).
    "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
    "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
| |
|
| |
|
class DiffusersSource:
    """
    Uniform read-only interface over the three on-disk layouts of a Diffusers
    Flux transformer checkpoint:

      1) Folder with an index JSON plus sharded .safetensors files
      2) Folder with exactly one .safetensors file (no index)
      3) Single .safetensors file

    Stored keys may carry a "model." prefix; lookups transparently try each
    prefix in POSSIBLE_PREFIXES. Provides .has(key), .get(key) -> Tensor, and
    .base_keys (stored keys with any known prefix stripped, for scanning).
    """

    # Key prefixes observed in checkpoints; "" first so exact matches win.
    # _strip_prefix() derives its behavior from this list (no magic lengths).
    POSSIBLE_PREFIXES = ["", "model."]

    def __init__(self, path: Path):
        """Open `path`: a Diffusers root (containing a `transformer/`
        subfolder), a transformer folder itself, or a single .safetensors
        file. Raises FileNotFoundError for anything else."""
        p = Path(path)
        if p.is_dir():
            # A full Diffusers root keeps the transformer in a subfolder.
            if (p / "transformer").is_dir():
                p = p / "transformer"
            self._init_from_dir(p)
        elif p.is_file() and p.suffix == ".safetensors":
            self._init_from_single_file(p)
        else:
            raise FileNotFoundError(f"Invalid path: {p}")

    @classmethod
    def _strip_prefix(cls, k: str) -> str:
        """Strip the first matching known prefix from `k` (e.g. "model.")."""
        for pref in cls.POSSIBLE_PREFIXES:
            if pref and k.startswith(pref):
                return k[len(pref):]
        return k

    def _resolve(self, want: str):
        """Return the actual stored key matching `want` by trying each known
        prefix, or None when absent."""
        for pref in self.POSSIBLE_PREFIXES:
            candidate = pref + want
            if candidate in self._all_keys:
                return candidate
        return None

    def has(self, want: str) -> bool:
        """True when `want` exists under any known prefix."""
        return self._resolve(want) is not None

    def get(self, want: str) -> torch.Tensor:
        """Load the tensor for `want` on CPU; raise KeyError when missing."""
        real_key = self._resolve(want)
        if real_key is None:
            raise KeyError(f"Missing key: {want}")
        return self._get_by_real_key(real_key).to("cpu")

    @property
    def base_keys(self):
        """All stored keys with any known prefix stripped (for scanning)."""
        return [self._strip_prefix(k) for k in self._all_keys]

    def _init_from_single_file(self, file_path: Path):
        """Wire reads directly to one .safetensors file."""
        self._mode = "single"
        self._file = file_path
        # Handle stays open for the object's lifetime (script-scoped usage).
        self._handle = safe_open(file_path, framework="pt", device="cpu")
        self._all_keys = list(self._handle.keys())

        def _get_by_real_key(real_key: str):
            return self._handle.get_tensor(real_key)

        self._get_by_real_key = _get_by_real_key

    def _init_from_dir(self, dpath: Path):
        """Wire reads to a folder: sharded (index JSON) or single-file."""
        index_json = dpath / "diffusion_pytorch_model.safetensors.index.json"
        if index_json.exists():
            with open(index_json, "r", encoding="utf-8") as f:
                index = json.load(f)
            weight_map = index["weight_map"]
            self._mode = "sharded"
            self._dpath = dpath
            self._weight_map = {k: dpath / v for k, v in weight_map.items()}
            self._all_keys = list(self._weight_map.keys())
            # Shard handles are opened lazily and cached for reuse.
            self._open_handles = {}

            def _get_by_real_key(real_key: str):
                fpath = self._weight_map[real_key]
                h = self._open_handles.get(fpath)
                if h is None:
                    h = safe_open(fpath, framework="pt", device="cpu")
                    self._open_handles[fpath] = h
                return h.get_tensor(real_key)

            self._get_by_real_key = _get_by_real_key
            return

        # No index: the folder must hold exactly one .safetensors file.
        files = sorted(dpath.glob("*.safetensors"))
        if len(files) != 1:
            raise FileNotFoundError(
                f"No index found and {dpath} does not contain exactly one .safetensors file."
            )
        self._init_from_single_file(files[0])
| |
|
| |
|
def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor:
    """Reorder an adaLN modulation vector from Diffusers' [shift, scale]
    layout to the original Flux [scale, shift] layout."""
    shift, scale = vec.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)


def count_blocks(base_keys):
    """Scan checkpoint key names and return (num_dual, num_single): the
    number of dual-stream and single-stream transformer blocks, inferred
    from the highest block index seen."""
    num_dual = 0
    num_single = 0
    for k in base_keys:
        if k.startswith("transformer_blocks."):
            try:
                num_dual = max(num_dual, int(k.split(".")[1]) + 1)
            except ValueError:
                pass  # non-numeric second segment: not a block tensor
        elif k.startswith("single_transformer_blocks."):
            try:
                num_single = max(num_single, int(k.split(".")[1]) + 1)
            except ValueError:
                pass
    return num_dual, num_single


def convert_entry(src, out, okey, dkeys):
    """Copy one DIFFUSERS_MAP entry from `src` into `out[okey]`.

    Multiple source tensors are concatenated along dim 0 (fused QKV / MLP).
    Entries with any missing source key are skipped silently — e.g. the
    guidance embedder on models trained without guidance."""
    if not all(src.has(k) for k in dkeys):
        return
    if len(dkeys) == 1:
        out[okey] = src.get(dkeys[0])
    else:
        out[okey] = torch.cat([src.get(k) for k in dkeys], dim=0)


def convert_blocks(src, out, okey_prefix, dkey_prefix, num_blocks):
    """Convert every per-block DIFFUSERS_MAP entry whose original key starts
    with `okey_prefix`, for block indices 0..num_blocks-1. `dkey_prefix` is
    the Diffusers-side block prefix ("transformer_blocks" or
    "single_transformer_blocks")."""
    for b in range(num_blocks):
        prefix = f"{dkey_prefix}.{b}."
        for okey, dvals in DIFFUSERS_MAP.items():
            if okey.startswith(okey_prefix):
                convert_entry(src, out, okey.replace("()", str(b)),
                              [prefix + v for v in dvals])


def quantize_fp8(tensors, scaled: bool):
    """Quantize every tensor in `tensors` (in place) to float8_e4m3fn.

    scaled=False (--fp8): stochastic rounding of every tensor.
    scaled=True (--fp8-scaled): per-tensor absmax scaling for 2D ``.weight``
    tensors, adding companion ``.scale_weight`` tensors and an empty
    ``scaled_fp8`` marker; all other tensors are clamped and cast.
    """
    dtype = torch.float8_e4m3fn
    minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max

    def stochastic_round_to(t):
        # NOTE(review): this rounds onto a fixed 1/256 grid before the fp8
        # cast, which is not true fp8 stochastic rounding (fp8 spacing is
        # non-uniform). Kept as-is; the flag is documented as experimental.
        t = t.float().clamp(minv, maxv)
        lower = torch.floor(t * 256) / 256
        upper = torch.ceil(t * 256) / 256
        prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t))
        rnd = torch.rand_like(t)
        return torch.where(rnd < prob, upper, lower).to(dtype)

    def scale_to_8bit(weight, target_max=416.0):
        # target_max leaves headroom below the e4m3fn maximum (448).
        absmax = weight.abs().max()
        scale = absmax / target_max if absmax > 0 else torch.tensor(1.0)
        return (weight / scale).clamp(minv, maxv).to(dtype), scale

    scales = {}
    for k in tqdm(list(tensors.keys()), desc="Quantizing to fp8"):
        t = tensors[k]
        if not scaled:
            tensors[k] = stochastic_round_to(t)
        elif k.endswith(".weight") and t.dim() == 2:
            qt, s = scale_to_8bit(t)
            tensors[k] = qt
            scales[k[: -len(".weight")] + ".scale_weight"] = s
        else:
            tensors[k] = t.clamp(minv, maxv).to(dtype)
    if scaled:
        tensors.update(scales)
        # Empty marker tensor so loaders can detect the scaled-fp8 format.
        tensors["scaled_fp8"] = torch.tensor([], dtype=dtype)


def main():
    """Convert a Diffusers Flux checkpoint into a single-file transformer."""
    args = parse_args()
    if args.fp8 and args.fp8_scaled:
        # Previously --fp8 silently won this conflict; fail loudly instead.
        raise SystemExit("--fp8 and --fp8-scaled are mutually exclusive.")
    src = DiffusersSource(Path(args.diffusers_path))

    num_dual, num_single = count_blocks(src.base_keys)
    print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks")

    orig = {}
    convert_blocks(src, orig, "double_blocks.", "transformer_blocks", num_dual)
    convert_blocks(src, orig, "single_blocks.", "single_transformer_blocks", num_single)

    # Non-block entries (embedders, final layer) use their keys verbatim.
    for okey, dvals in DIFFUSERS_MAP.items():
        if not okey.startswith(("double_blocks.", "single_blocks.")):
            convert_entry(src, orig, okey, dvals)

    # Diffusers stores the final adaLN halves as [shift, scale]; Flux
    # expects [scale, shift].
    for key in ("final_layer.adaLN_modulation.1.weight",
                "final_layer.adaLN_modulation.1.bias"):
        if key in orig:
            orig[key] = swap_scale_shift(orig[key])

    if args.fp8 or args.fp8_scaled:
        quantize_fp8(orig, scaled=args.fp8_scaled)
    else:
        # Default export dtype for original-format Flux checkpoints.
        for k in list(orig.keys()):
            orig[k] = orig[k].to(torch.bfloat16).cpu()

    out_path = Path(args.output_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    meta = OrderedDict()
    meta["format"] = "pt"
    meta["modelspec.date"] = datetime.date.today().strftime("%Y-%m-%d")
    print(f"Saving transformer to: {out_path}")
    safetensors.torch.save_file(orig, str(out_path), metadata=meta)
    print("Done.")


if __name__ == "__main__":
    main()
| |
|