Buckets:

Mercity/FluxDistill/scripts/01_inspect_model.py
Pranav2748's picture
download
raw
2.41 kB
"""Introspect the klein-4B transformer: module tree, block containers, single-block
forward signature, and exact per-block param counts. Read real names before surgery."""
import inspect
import sys
import torch
import diffusers
print("diffusers", diffusers.__version__)
# List every Flux2* export first, so a diffusers version that lacks the class
# is obvious from the output before the import below fails.
flux2_classes = [n for n in dir(diffusers) if "Flux2" in n]
print("Flux2 classes:", flux2_classes)
from diffusers import Flux2Transformer2DModel

# Transformer subfolder of the local checkpoint. An optional first CLI argument
# overrides it, e.g. `python 01_inspect_model.py path/to/transformer`.
LOCAL = sys.argv[1] if len(sys.argv) > 1 else "models/klein-4b/transformer"
print(f"\nloading transformer from {LOCAL} (cpu, bf16)...")
tf = Flux2Transformer2DModel.from_pretrained(LOCAL, torch_dtype=torch.bfloat16)
# Put every submodule in eval mode; this script only inspects and never trains.
tf.eval()
def count(m):
    """Return the total number of scalar elements across all parameters of module *m*.

    Uses ``m.parameters()``, so parameters of every nested submodule are included.
    A module with no parameters yields 0.
    """
    n_elements = 0
    for param in m.parameters():
        n_elements += param.numel()
    return n_elements
print("\n=== top-level children ===")
for name, child in tf.named_children():
    print(f" {name:30s} {type(child).__name__:30s} params={count(child)/1e6:8.2f}M")

# Candidate block containers (double / single): every direct ModuleList child of
# the transformer, kept in registration order. Nothing is ranked by length.
# Empty containers are reported but left out of `modulelists`, because the later
# sections index `c[0]` and would raise IndexError on them.
all_modulelists = [(n, c) for n, c in tf.named_children() if isinstance(c, torch.nn.ModuleList)]
modulelists = [(n, c) for n, c in all_modulelists if len(c) > 0]
print("\n=== ModuleList children (block containers) ===")
for n, c in all_modulelists:
    if len(c) == 0:
        print(f" {n:30s} len={len(c):3d} (empty; skipped)")
        continue
    # params/block is measured on block 0 only. This assumes the blocks in one
    # container are equally sized -- TODO confirm for this model.
    print(f" {n:30s} len={len(c):3d} block_type={type(c[0]).__name__} params/block={count(c[0])/1e6:.2f}M")
# Forward signature of the first block in each container. A surrogate that
# replaces a block must accept exactly these arguments.
print("\n=== block forward signatures ===")
for container_name, container in modulelists:
    first_block = container[0]
    print(f" {container_name}[0].forward{inspect.signature(first_block.forward)}")
# Internal structure of the first block of the first ModuleList whose name
# contains "single" (case-insensitive). `single_name` records which container
# was used, or stays None when no such container exists.
print("\n=== single block submodules (first single block) ===")
single_name = None
for n, c in modulelists:
    if "single" in n.lower():
        single_name = n
        for sn, sc in c[0].named_children():
            print(f" {sn:24s} {type(sc).__name__:28s} params={count(sc)/1e6:.3f}M")
        break
else:
    # Say so explicitly: an empty section would look like a block with no children.
    print(" (no ModuleList with 'single' in its name; nothing to show)")
# Whole-model totals, split the way the planned optimizer groups would be:
# 2-D tensors for Muon, every other tensor (e.g. 1-D vectors) for AdamW.
# NOTE(review): the 2-D figure also counts any 2-D embedder / output weights,
# which the section below treats as AdamW candidates -- confirm the intended grouping.
print("\n=== param totals ===")
total = count(tf)
print(f" transformer total: {total/1e9:.4f}B")
twod = 0
for param in tf.parameters():
    if param.ndim == 2:
        twod += param.numel()
rest = total - twod
print(f" 2D params (Muon): {twod/1e9:.4f}B ({100*twod/total:.1f}%)")
print(f" non-2D (AdamW): {rest/1e9:.4f}B ({100*rest/total:.1f}%)")
# Param-group preview: everything outside the ModuleList block containers
# (embedders / proj_out / norms) is a candidate for the AdamW group.
print("\n=== non-block top-level params (AdamW candidates) ===")
for name, child in tf.named_children():
    if not isinstance(child, torch.nn.ModuleList):
        print(f" {name:30s} params={count(child)/1e6:8.3f}M")
# Parameters registered directly on the transformer (not inside any child module)
# never show up in named_children(), so list them explicitly as well.
for name, param in tf.named_parameters(recurse=False):
    print(f" {name:30s} params={param.numel()/1e6:8.3f}M (direct nn.Parameter)")
print("\nDONE")

Xet Storage Details

Size:
2.41 kB
·
Xet hash:
3af8c0e368589ce78d8dfbaa6698d759b5305641b2fddd3824c1c3d8641faa62

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.