Buckets:
| """Introspect the klein-4B transformer: module tree, block containers, single-block | |
| forward signature, and exact per-block param counts. Read real names before surgery.""" | |
| import inspect | |
| import sys | |
| import torch | |
| import diffusers | |
# Print the installed diffusers version and every attribute of the package
# whose name contains "Flux2". This runs before the model-class import below,
# presumably so the diagnostics are already on screen if that import fails.
print("diffusers", diffusers.__version__)
flux2_classes = [attr for attr in dir(diffusers) if "Flux2" in attr]
print("Flux2 classes:", flux2_classes)

from diffusers import Flux2Transformer2DModel

# Checkpoint directory passed to from_pretrained (relative path).
LOCAL = "models/klein-4b/transformer"
print(f"\nloading transformer from {LOCAL} (cpu, bf16)...")
tf = Flux2Transformer2DModel.from_pretrained(
    LOCAL,
    torch_dtype=torch.bfloat16,
)
# Inspection only: put the module in eval mode.
tf.eval()
def count(m):
    """Return the total number of elements over all parameters of ``m``.

    ``m`` must expose ``parameters()`` yielding objects with ``numel()``.
    Returns 0 when ``m`` has no parameters.
    """
    n_elements = 0
    for param in m.parameters():
        n_elements += param.numel()
    return n_elements
# One row per direct child of the transformer: attribute name, class name,
# and parameter count in millions.
print("\n=== top-level children ===")
for child_name, module in tf.named_children():
    class_name = type(module).__name__
    millions = count(module) / 1e6
    print(f" {child_name:30s} {class_name:30s} params={millions:8.2f}M")
# Collect every direct child that is a torch.nn.ModuleList; these are treated
# as the block containers (expected to be the double / single block stacks).
# No filtering by length is done: all top-level ModuleLists are kept.
modulelists = [(n, c) for n, c in tf.named_children() if isinstance(c, torch.nn.ModuleList)]
print("\n=== ModuleList children (block containers) ===")
for n, c in modulelists:
    if len(c) == 0:
        # An empty container has no c[0] to inspect; report it rather than
        # aborting the whole script with IndexError.
        print(f" {n:30s} len={len(c):3d} (empty)")
        continue
    # Per-block figures come from the first block only.
    print(f" {n:30s} len={len(c):3d} block_type={type(c[0]).__name__} params/block={count(c[0])/1e6:.2f}M")
# Forward signature of the first block in each container (the surrogate must
# match this).
print("\n=== block forward signatures ===")
for n, c in modulelists:
    if len(c) == 0:
        # No first block exists; indexing c[0] would raise IndexError.
        print(f" {n}: empty container, no block to inspect")
        continue
    sig = inspect.signature(c[0].forward)
    print(f" {n}[0].forward{sig}")
# Internal structure of one block: list the direct submodules of the first
# block in the first container whose name contains "single" (case-insensitive).
print("\n=== single block submodules (first single block) ===")
single_name = None
for n, c in modulelists:
    if "single" not in n.lower():
        continue
    single_name = n
    if len(c) == 0:
        # Matching container exists but holds no blocks; c[0] would raise.
        print(f" {n}: empty container, no block to inspect")
    else:
        for sn, sc in c[0].named_children():
            print(f" {sn:24s} {type(sc).__name__:28s} params={count(sc)/1e6:.3f}M")
    # Only the first matching container is reported.
    break
if single_name is None:
    # Make the "nothing matched" case visible instead of printing an empty section.
    print(" (no top-level ModuleList with 'single' in its name)")
# Whole-model totals, split by tensor rank. The output labels tag 2D tensors
# as "Muon" and everything else as "AdamW". The split covers every parameter
# of tf, not only those inside the block containers.
print("\n=== param totals ===")
total = count(tf)
print(f" transformer total: {total/1e9:.4f}B")
twod = sum(param.numel() for param in tf.parameters() if param.ndim == 2)
non_twod = total - twod
print(f" 2D params (Muon): {twod/1e9:.4f}B ({100*twod/total:.1f}%)")
print(f" non-2D (AdamW): {non_twod/1e9:.4f}B ({100*non_twod/total:.1f}%)")
# Param-group preview: direct children that are not ModuleLists, i.e. the
# modules outside the block containers (embedders/proj_out/norms per the
# original note), with their parameter counts in millions.
print("\n=== non-block top-level params (AdamW candidates) ===")
for name, child in tf.named_children():
    if isinstance(child, torch.nn.ModuleList):
        continue
    print(f" {name:30s} params={count(child)/1e6:8.3f}M")
print("\nDONE")
Xet Storage Details
- Size:
- 2.41 kB
- Xet hash:
- 3af8c0e368589ce78d8dfbaa6698d759b5305641b2fddd3824c1c3d8641faa62
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.