# ============================================================================
# TinyFlux-Deep Inference Cell - With ExpertPredictor
# ============================================================================
# Run the model cell before this one (defines TinyFluxDeep, TinyFluxDeepConfig)
# Loads from: AbstractPhil/tiny-flux-deep or local checkpoint
#
# The ExpertPredictor runs standalone at inference - no SD1.5-flow needed.
# It predicts timestep expertise from (time_emb, clip_pooled).
# ============================================================================
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from PIL import Image
import numpy as np
import os

# ============================================================================
# CONFIG
# ============================================================================
DEVICE = "cuda"
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# Model loading
HF_REPO = "AbstractPhil/tiny-flux-deep"  # stable v3 step_316875
LOAD_FROM = "hub:step_346875"  # "hub", "hub:step_XXXXX", "hub:step_XXXXX_ema", "local:/path/to/weights.safetensors"

# Generation settings
NUM_STEPS = 50
GUIDANCE_SCALE = 5.0  # Note: this is now just for CFG, not the broken guidance_in
HEIGHT = 512
WIDTH = 512
SEED = None
SHIFT = 3.0

# Model architecture (must match training)
USE_EXPERT_PREDICTOR = True
EXPERT_DIM = 1280
EXPERT_HIDDEN_DIM = 512

# ============================================================================
# LOAD TEXT ENCODERS
# ============================================================================
print("Loading text encoders...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval()
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()

# ============================================================================
# LOAD VAE
# ============================================================================
print("Loading Flux VAE...")
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=DTYPE
).to(DEVICE).eval()

# ============================================================================
# LOAD TINYFLUX-DEEP MODEL
# ============================================================================
print(f"Loading TinyFlux-Deep from: {LOAD_FROM}")

# Config with ExpertPredictor (no guidance_embeds)
config = TinyFluxDeepConfig(
    use_expert_predictor=USE_EXPERT_PREDICTOR,
    expert_dim=EXPERT_DIM,
    expert_hidden_dim=EXPERT_HIDDEN_DIM,
    guidance_embeds=False,  # Replaced by expert_predictor
)
model = TinyFluxDeep(config).to(DEVICE).to(DTYPE)

# Keys to handle during loading
DEPRECATED_KEYS = {
    'time_in.sin_basis',
    'guidance_in.sin_basis',
    'guidance_in.mlp.0.weight',
    'guidance_in.mlp.0.bias',
    'guidance_in.mlp.2.weight',
    'guidance_in.mlp.2.bias',
}

def load_weights(path):
    """Load weights from .safetensors or .pt file."""
    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    elif path.endswith(".pt"):
        ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
        if isinstance(ckpt, dict):
            if "model" in ckpt:
                state_dict = ckpt["model"]
            elif "state_dict" in ckpt:
                state_dict = ckpt["state_dict"]
            else:
                state_dict = ckpt
        else:
            state_dict = ckpt
    else:
        try:
            state_dict = load_file(path)
        except Exception:
            state_dict = torch.load(path, map_location=DEVICE, weights_only=False)

    # Strip "_orig_mod." prefix from keys (added by torch.compile)
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        print(" Stripping torch.compile prefix...")
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    return state_dict
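# Optional helper: peek at a checkpoint's tensor keys without materializing any
# weights. A minimal sketch using safetensors' safe_open - handy for spotting
# deprecated guidance_in.* / time_in.sin_basis entries before calling
# load_model_weights. The example path below is a placeholder.
from safetensors import safe_open

def list_checkpoint_keys(path, prefix=""):
    """Return the tensor keys stored in a .safetensors file, optionally filtered by prefix."""
    with safe_open(path, framework="pt", device="cpu") as f:
        return [k for k in f.keys() if k.startswith(prefix)]

# Example (placeholder path):
# print(list_checkpoint_keys("/path/to/step_XXXXX.safetensors", prefix="guidance_in."))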
def load_model_weights(model, weights, source_name):
    """Load weights with architecture upgrade support."""
    model_state = model.state_dict()
    loaded = []
    skipped_deprecated = []
    skipped_shape = []
    missing_new = []

    # Load matching weights
    for k, v in weights.items():
        if k in DEPRECATED_KEYS or k.startswith('guidance_in.'):
            skipped_deprecated.append(k)
        elif k in model_state:
            if v.shape == model_state[k].shape:
                model_state[k] = v
                loaded.append(k)
            else:
                skipped_shape.append((k, v.shape, model_state[k].shape))
        else:
            # Key not in model (maybe old architecture)
            skipped_deprecated.append(k)

    # Find new keys not in checkpoint
    for k in model_state:
        if k not in weights and not any(k.startswith(d.split('.')[0]) for d in DEPRECATED_KEYS if '.' in d):
            missing_new.append(k)

    # Apply loaded weights
    model.load_state_dict(model_state, strict=False)

    # Report
    print(f" ✓ Loaded: {len(loaded)} weights")
    if skipped_deprecated:
        print(f" ✓ Skipped deprecated: {len(skipped_deprecated)} (guidance_in, etc)")
    if skipped_shape:
        print(f" ⚠ Shape mismatch: {len(skipped_shape)}")
        for k, old, new in skipped_shape[:3]:
            print(f"   {k}: {old} vs {new}")
    if missing_new:
        # Group by module
        modules = set(k.split('.')[0] for k in missing_new)
        print(f" ℹ New modules (fresh init): {modules}")
    print(f"✓ Loaded from {source_name}")


if LOAD_FROM == "hub":
    try:
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
    except Exception:
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.pt")
    weights = load_weights(weights_path)
    load_model_weights(model, weights, HF_REPO)

elif LOAD_FROM.startswith("hub:"):
    ckpt_name = LOAD_FROM[4:]
    for ext in [".safetensors", ".pt", ""]:
        try:
            if ckpt_name.endswith((".safetensors", ".pt")):
                filename = ckpt_name if "/" in ckpt_name else f"checkpoints/{ckpt_name}"
            else:
                filename = f"checkpoints/{ckpt_name}{ext}"
            weights_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
            weights = load_weights(weights_path)
            load_model_weights(model, weights, f"{HF_REPO}/{filename}")
            break
        except Exception:
            continue
    else:
        raise ValueError(f"Could not find checkpoint: {ckpt_name}")

elif LOAD_FROM.startswith("local:"):
    weights_path = LOAD_FROM[6:]
    weights = load_weights(weights_path)
    load_model_weights(model, weights, weights_path)

else:
    raise ValueError(f"Unknown LOAD_FROM: {LOAD_FROM}")

model.eval()

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
expert_params = sum(p.numel() for p in model.expert_predictor.parameters()) if model.expert_predictor is not None else 0
print(f"Model params: {total_params:,} (expert_predictor: {expert_params:,})")

# ============================================================================
# ENCODING FUNCTIONS
# ============================================================================
@torch.inference_mode()
def encode_prompt(prompt: str, max_length: int = 128):
    """Encode prompt with flan-t5-base and CLIP-L."""
    t5_in = t5_tok(
        prompt, max_length=max_length, padding="max_length",
        truncation=True, return_tensors="pt"
    ).to(DEVICE)
    t5_out = t5_enc(
        input_ids=t5_in.input_ids,
        attention_mask=t5_in.attention_mask
    ).last_hidden_state

    clip_in = clip_tok(
        prompt, max_length=77, padding="max_length",
        truncation=True, return_tensors="pt"
    ).to(DEVICE)
    clip_out = clip_enc(
        input_ids=clip_in.input_ids,
        attention_mask=clip_in.attention_mask
    )
    clip_pooled = clip_out.pooler_output

    return t5_out.to(DTYPE), clip_pooled.to(DTYPE)
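# Quick shape sanity check on the conditioning tensors. Assuming flan-t5-base
# (d_model=768) and CLIP ViT-L/14 (text hidden size 768), a single prompt
# should yield t5_out of shape (1, 128, 768) and clip_pooled of shape (1, 768).
_t5_chk, _clip_chk = encode_prompt("shape check")
print(f"t5 sequence embedding: {tuple(_t5_chk.shape)}, clip pooled: {tuple(_clip_chk.shape)}")
del _t5_chk, _clip_chk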
# ============================================================================
# FLOW MATCHING HELPERS
# ============================================================================
def flux_shift(t, s=SHIFT):
    """Flux timestep shift - biases towards higher t (closer to data)."""
    return s * t / (1 + (s - 1) * t)

# ============================================================================
# EULER DISCRETE FLOW MATCHING SAMPLER
# ============================================================================
@torch.inference_mode()
def euler_sample(
    model,
    prompt: str,
    negative_prompt: str = "",
    num_steps: int = 28,
    guidance_scale: float = 3.5,
    height: int = 512,
    width: int = 512,
    seed: int = None,
):
    """
    Euler discrete sampler for rectified flow matching.

    Flow Matching formulation:
        x_t = (1 - t) * noise + t * data
        At t=0: noise, At t=1: data
        Velocity v = data - noise (constant)

    Sampling: Integrate from t=0 (noise) to t=1 (data)

    With ExpertPredictor:
        - No guidance embedding needed
        - Expert predictor runs internally from (time_emb, clip_pooled)
        - CFG still works via positive/negative prompt difference
    """
    if seed is not None:
        torch.manual_seed(seed)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
    else:
        generator = None

    H_lat = height // 8
    W_lat = width // 8
    C_lat = 16

    # Encode prompts
    t5_cond, clip_cond = encode_prompt(prompt)
    if guidance_scale > 1.0 and negative_prompt is not None:
        t5_uncond, clip_uncond = encode_prompt(negative_prompt)
    else:
        t5_uncond, clip_uncond = None, None

    # Start from pure noise (t=0)
    x = torch.randn(1, H_lat * W_lat, C_lat, device=DEVICE, dtype=DTYPE, generator=generator)

    # Create image position IDs
    img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)

    # Timesteps: 0 → 1 with flux shift
    t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
    timesteps = flux_shift(t_linear, s=SHIFT)

    print(f"Sampling with {num_steps} Euler steps (t: 0→1, shifted)...")

    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr

        t_batch = t_curr.unsqueeze(0)

        # Predict velocity (no guidance embedding, expert_predictor runs internally)
        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_cond,
            pooled_projections=clip_cond,
            timestep=t_batch,
            img_ids=img_ids,
            # No guidance parameter - ExpertPredictor handles timestep awareness
            # No expert_features - predictor runs standalone at inference
        )

        # Classifier-free guidance (true CFG via prompt difference)
        if guidance_scale > 1.0 and t5_uncond is not None:
            v_uncond = model(
                hidden_states=x,
                encoder_hidden_states=t5_uncond,
                pooled_projections=clip_uncond,
                timestep=t_batch,
                img_ids=img_ids,
            )
            v = v_uncond + guidance_scale * (v_cond - v_uncond)
        else:
            v = v_cond

        # Euler step: x_{t+dt} = x_t + v * dt
        x = x + v * dt

        if (i + 1) % max(1, num_steps // 5) == 0 or i == num_steps - 1:
            print(f" Step {i+1}/{num_steps}, t={t_next.item():.3f}")

    # Reshape: (1, H*W, C) -> (1, C, H, W)
    latents = x.reshape(1, H_lat, W_lat, C_lat).permute(0, 3, 1, 2)
    return latents
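# Small self-contained check of the flow-matching setup described in the
# euler_sample docstring (toy tensors only, no model involved):
#   * with the exact velocity v = data - noise, one Euler step over the full
#     interval recovers data from noise, since x1 = x0 + v * 1.0
#   * flux_shift with the default SHIFT=3.0 pushes intermediate timesteps up,
#     e.g. t=0.25 -> 0.50, t=0.50 -> 0.75, t=0.75 -> 0.90
_noise = torch.randn(4, 16)
_data = torch.randn(4, 16)
_v_true = _data - _noise                 # constant velocity of the linear path
_recon = _noise + _v_true * 1.0          # single Euler step from t=0 to t=1
assert torch.allclose(_recon, _data)
print("shifted timesteps:", [round(flux_shift(t), 3) for t in (0.25, 0.5, 0.75)])
del _noise, _data, _v_true, _recon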
# ============================================================================
# DECODE LATENTS TO IMAGE
# ============================================================================
@torch.inference_mode()
def decode_latents(latents):
    """Decode VAE latents to PIL Image."""
    latents = latents / vae.config.scaling_factor
    image = vae.decode(latents.to(vae.dtype)).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].float().permute(1, 2, 0).cpu().numpy()
    image = (image * 255).astype(np.uint8)
    return Image.fromarray(image)

# ============================================================================
# MAIN GENERATION FUNCTION
# ============================================================================
def generate(
    prompt: str,
    negative_prompt: str = "",
    num_steps: int = NUM_STEPS,
    guidance_scale: float = GUIDANCE_SCALE,
    height: int = HEIGHT,
    width: int = WIDTH,
    seed: int = SEED,
    save_path: str = None,
):
    """
    Generate an image from a text prompt.

    Args:
        prompt: Text description of desired image
        negative_prompt: What to avoid (empty string for none)
        num_steps: Number of Euler steps (20-50 recommended)
        guidance_scale: CFG scale (1.0=none, 3-7 typical)
        height: Output height in pixels (divisible by 8)
        width: Output width in pixels (divisible by 8)
        seed: Random seed (None for random)
        save_path: Path to save image (None to skip)

    Returns:
        PIL.Image
    """
    print(f"\nGenerating: '{prompt}'")
    print(f"Settings: {num_steps} steps, cfg={guidance_scale}, {width}x{height}, seed={seed}")

    latents = euler_sample(
        model=model,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        height=height,
        width=width,
        seed=seed,
    )

    print("Decoding latents...")
    image = decode_latents(latents)

    if save_path:
        image.save(save_path)
        print(f"✓ Saved to {save_path}")

    print("✓ Done!")
    return image

# ============================================================================
# BATCH GENERATION
# ============================================================================
def generate_batch(
    prompts: list,
    negative_prompt: str = "",
    num_steps: int = NUM_STEPS,
    guidance_scale: float = GUIDANCE_SCALE,
    height: int = HEIGHT,
    width: int = WIDTH,
    seed: int = SEED,
    output_dir: str = "./outputs",
):
    """Generate multiple images."""
    os.makedirs(output_dir, exist_ok=True)
    images = []
    for i, prompt in enumerate(prompts):
        img_seed = seed + i if seed is not None else None
        image = generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            seed=img_seed,
            save_path=os.path.join(output_dir, f"{i:03d}.png"),
        )
        images.append(image)
    return images
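# Illustrative batch call (commented out; the prompts are placeholders). With a
# fixed seed, prompt i runs with seed + i, so the whole batch is reproducible:
#
# images = generate_batch(
#     prompts=[
#         "subject, animal, canine, wolf, snowy forest",
#         "subject, vehicle, sailboat, open ocean, sunset",
#     ],
#     num_steps=30,
#     guidance_scale=5.0,
#     seed=1234,
#     output_dir="./outputs",
# )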
""" # With expert image_with = generate( prompt=prompt, negative_prompt=negative_prompt, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed, save_path=f"{save_prefix}_with_expert.png", ) # Without expert (temporarily disable) old_predictor = model.expert_predictor model.expert_predictor = None image_without = generate( prompt=prompt, negative_prompt=negative_prompt, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed, save_path=f"{save_prefix}_without_expert.png", ) # Restore model.expert_predictor = old_predictor # Side by side combined = Image.new('RGB', (image_with.width * 2, image_with.height)) combined.paste(image_without, (0, 0)) combined.paste(image_with, (image_with.width, 0)) combined.save(f"{save_prefix}_comparison.png") print(f"\n✓ Comparison saved: {save_prefix}_comparison.png") print(f" Left: without expert | Right: with expert") return image_without, image_with, combined # ============================================================================ # QUICK TEST # ============================================================================ print("\n" + "="*60) print("TinyFlux-Deep + ExpertPredictor Inference Ready!") print("="*60) print(f"Config: {config.hidden_size} hidden, {config.num_attention_heads} heads") print(f" {config.num_double_layers} double, {config.num_single_layers} single layers") print(f" ExpertPredictor: {config.use_expert_predictor} (dim={config.expert_dim})") print(f"Total: {total_params:,} parameters") # Example usage: image = generate( prompt="subject, animal, feline, lion, natural habitat", negative_prompt="", num_steps=50, guidance_scale=5.0, seed=4545, width=512, height=512, ) image