""" WAN IP-Adapter — zero-shot face conditioning via T5 cross-attention injection. Strategy ──────── Instead of patching WAN's self-attention blocks (which requires trained K/V projections that don't exist for WAN), we inject face identity through the cross-attention pathway that WAN already uses for text conditioning. Pipeline 1. SigLIP2 so400m (1152-d patch tokens) ↓ TimeResampler (SD3.5 trained weights, 8 queries → 1024-d) ↓ proj_face nn.Linear(1024 → 4096, xavier_uniform init) = 8 face tokens in T5 space (1, 8, 4096) 2. These are appended to the T5 prompt_embeds before each pipe() call. WAN's cross-attention naturally attends to all tokens in encoder_hidden_states, so no transformer surgery is needed. 3. For CFG (guidance_scale > 1), zeros are appended to the negative embeds so the unconditional branch is face-neutral, not anti-face. Why this works zero-shot ──────────────────────── The TimeResampler is trained (SD3.5 weights) and produces semantically structured 1024-d tokens. The random proj_face (xavier_uniform) is a fixed linear map — it preserves the relative geometry of the resampler space, so the same face always maps to the same region of T5 space and similar faces map to nearby regions. WAN's cross-attention sees consistent identity tokens for consistent faces. Usage in app.py ─────────────── Init (once, inside _init_pipeline): ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16) Per-generation (inside run_inference, before pipe()): prompt_embeds, neg_embeds, prompt_mask, neg_mask = ip_adapter.encode_prompt( face_image=face_ref_image, # PIL Image or None prompt=effective_prompt, negative_prompt=negative_prompt, ip_scale=ip_scale, # 0.0 → 1.0 ) result = pipe( ... prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_embeds, prompt_attention_mask=prompt_mask, negative_prompt_attention_mask=neg_mask, ) """ from __future__ import annotations import math from pathlib import Path from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from huggingface_hub import hf_hub_download from PIL import Image from transformers import AutoProcessor, SiglipVisionModel # ── Perceiver resampler (unchanged from original — SD3.5 weights load here) ─── class _FeedForward(nn.Module): def __init__(self, dim: int, mult: int = 4): super().__init__() self.net = nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, dim * mult, bias=False), nn.GELU(), nn.Linear(dim * mult, dim, bias=False), ) def forward(self, x): return self.net(x) def _reshape(t: torch.Tensor, heads: int) -> torch.Tensor: b, n, d = t.shape return t.reshape(b, n, heads, d // heads).transpose(1, 2) class _PerceiverAttention(nn.Module): def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8): super().__init__() self.heads = heads inner = dim_head * heads self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.to_q = nn.Linear(dim, inner, bias=False) self.to_kv = nn.Linear(dim, inner * 2, bias=False) self.to_out = nn.Linear(inner, dim, bias=False) def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: x = self.norm1(x) latents = self.norm2(latents) q = _reshape(self.to_q(latents), self.heads) kv_in = torch.cat([x, latents], dim=1) k, v = self.to_kv(kv_in).chunk(2, dim=-1) k, v = _reshape(k, self.heads), _reshape(v, self.heads) out = F.scaled_dot_product_attention(q, k, v) out = out.transpose(1, 2).reshape(latents.shape[0], -1, self.to_out.in_features) return self.to_out(out) + latents class TimeResampler(nn.Module): """ Perceiver resampler — architecture matches image_proj.* in InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin so weights load cleanly. Output: (batch, num_queries=8, output_dim=1024) """ def __init__( self, dim: int = 1024, depth: int = 8, dim_head: int = 64, heads: int = 16, num_queries: int = 8, embedding_dim: int = 1152, # SigLIP2 so400m hidden size output_dim: int = 1024, ff_mult: int = 4, timestep_in_dim: int = 320, timestep_flip_sin_to_cos: bool = True, timestep_freq_shift: int = 0, ): super().__init__() from diffusers.models.embeddings import Timesteps, TimestepEmbedding self.num_queries = num_queries self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) self.proj_in = nn.Linear(embedding_dim, dim) self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift) self.t_emb = TimestepEmbedding(timestep_in_dim, dim) self.layers = nn.ModuleList([ nn.ModuleList([ _PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), _FeedForward(dim=dim, mult=ff_mult), nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim)), ]) for _ in range(depth) ]) self.proj_out = nn.Linear(dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: t = self.time_proj(timestep.flatten()).to(x.dtype) t_emb = self.t_emb(t) latents = self.latents.expand(x.size(0), -1, -1).clone() x = self.proj_in(x) for attn, ff, adaln in self.layers: s_msa, c_msa, s_mlp, c_mlp = adaln(t_emb).chunk(4, dim=-1) latents = latents * (1 + c_msa[:, None]) + s_msa[:, None] latents = attn(x, latents) latents = latents * (1 + c_mlp[:, None]) + s_mlp[:, None] latents = ff(latents) + latents return self.norm_out(self.proj_out(latents)) # (B, 8, 1024) # ── Main class ───────────────────────────────────────────────────────────────── class WanIPAdapter: """ Zero-shot face conditioning for WAN I2V via T5 cross-attention injection. No transformer patching. Face tokens are appended to prompt_embeds and WAN's existing cross-attention handles the rest. """ _IP_ADAPTER_REPO = "InstantX/SD3.5-Large-IP-Adapter" _IP_ADAPTER_FILE = "ip-adapter.bin" _VISION_MODEL = "google/siglip-so400m-patch14-384" # WAN transformer cross-attention dim (text_dim in WanTransformer3DModel) _T5_DIM = 4096 # TimeResampler output dim _RESAMPLER_DIM = 1024 # Number of face tokens appended to the T5 sequence _NUM_FACE_TOKENS = 8 def __init__( self, pipe, device: torch.device, dtype: torch.dtype = torch.bfloat16, cache_dir: str = "/data/ip_adapter", ): self.pipe = pipe self.device = device self.dtype = dtype self._load_vision_encoder() self._load_resampler(cache_dir) self._build_proj_face() print("[IP-Adapter] ready — T5-concat mode, no transformer patching") # ── setup ────────────────────────────────────────────────────────────────── def _load_vision_encoder(self): print("[IP-Adapter] loading SigLIP vision encoder…") self.vis_proc = AutoProcessor.from_pretrained(self._VISION_MODEL) self.vis_model = SiglipVisionModel.from_pretrained( self._VISION_MODEL, torch_dtype=self.dtype, ).to(self.device) self.vis_model.eval() print("[IP-Adapter] SigLIP loaded") def _load_resampler(self, cache_dir: str): print("[IP-Adapter] loading TimeResampler (SD3.5 ip-adapter.bin)…") ckpt = hf_hub_download( repo_id=self._IP_ADAPTER_REPO, filename=self._IP_ADAPTER_FILE, local_dir=cache_dir, ) state = torch.load(ckpt, map_location="cpu", weights_only=True) img_proj = { k[len("image_proj."):]: v for k, v in state.items() if k.startswith("image_proj.") } self.resampler = TimeResampler().to(self.device, self.dtype) missing, _ = self.resampler.load_state_dict(img_proj, strict=False) if missing: print(f"[IP-Adapter] resampler missing keys ({len(missing)}): {missing[:4]}…") self.resampler.eval() print("[IP-Adapter] resampler loaded") def _build_proj_face(self): """ Fixed linear projection: resampler output (1024) → T5 space (4096). Xavier-uniform init so face tokens land at reasonable magnitude relative to T5 embeddings. This projection is never trained — it's a fixed consistent mapping that preserves the resampler's relative geometry. """ self.proj_face = nn.Linear(self._RESAMPLER_DIM, self._T5_DIM, bias=False) nn.init.xavier_uniform_(self.proj_face.weight) self.proj_face = self.proj_face.to(self.device, self.dtype) self.proj_face.eval() n_params = self.proj_face.weight.numel() print(f"[IP-Adapter] proj_face built ({n_params:,} params, xavier_uniform, frozen)") # ── encoding ─────────────────────────────────────────────────────────────── @torch.no_grad() def _encode_face_tokens(self, image: Image.Image, timestep: int = 500) -> torch.Tensor: """ Encode *image* → (1, 8, 4096) face tokens in T5 space. The timestep passed to the TimeResampler controls which denoising stage the resampler "thinks" it's at. 500 (mid-point) is a reasonable default; lower values produce more detail-focused tokens. """ inputs = self.vis_proc(images=image, return_tensors="pt").to(self.device) vis_out = self.vis_model(**inputs) # Use patch tokens (last_hidden_state) rather than pooled for spatial detail vis_feats = vis_out.last_hidden_state.to(self.dtype) # (1, N, 1152) t = torch.tensor([timestep], device=self.device, dtype=torch.long) emb = self.resampler(vis_feats, t) # (1, 8, 1024) return self.proj_face(emb) # (1, 8, 4096) # ── main API ─────────────────────────────────────────────────────────────── def encode_prompt( self, face_image: Optional[Image.Image], prompt: str, negative_prompt: str = "", ip_scale: float = 0.6, num_videos_per_prompt: int = 1, do_classifier_free_guidance: bool = False, timestep: int = 500, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns (prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_attention_mask) ready to pass directly to pipe(). If *face_image* is None or *ip_scale* == 0, returns vanilla text embeds. ip_scale blends face tokens into the prompt by scaling them before concat. Scale of 1.0 = full face signal; 0.5 = half strength. """ # ── text embeddings ──────────────────────────────────────────────────── ( prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_attention_mask, ) = self.pipe.encode_prompt( prompt=prompt, negative_prompt=negative_prompt if do_classifier_free_guidance else None, do_classifier_free_guidance=do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, device=self.device, ) if face_image is None or ip_scale == 0.0: return ( prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_attention_mask, ) # ── face tokens ──────────────────────────────────────────────────────── face_tokens = self._encode_face_tokens(face_image, timestep=timestep) # Repeat for batch if needed B = prompt_embeds.shape[0] if B > 1: face_tokens = face_tokens.expand(B, -1, -1) # Scale face tokens — controls identity signal strength face_tokens = face_tokens * ip_scale # Append to prompt embeds prompt_embeds = torch.cat([prompt_embeds, face_tokens], dim=1) face_ones = torch.ones(B, self._NUM_FACE_TOKENS, device=self.device, dtype=prompt_attention_mask.dtype) prompt_attention_mask = torch.cat([prompt_attention_mask, face_ones], dim=1) # For negative: append zeros (face-neutral, not anti-face) if negative_prompt_embeds is not None: B_neg = negative_prompt_embeds.shape[0] neg_face = torch.zeros(B_neg, self._NUM_FACE_TOKENS, self._T5_DIM, device=self.device, dtype=self.dtype) negative_prompt_embeds = torch.cat([negative_prompt_embeds, neg_face], dim=1) neg_ones = torch.ones(B_neg, self._NUM_FACE_TOKENS, device=self.device, dtype=negative_attention_mask.dtype) negative_attention_mask = torch.cat([negative_attention_mask, neg_ones], dim=1) print(f"[IP-Adapter] face tokens appended — " f"prompt_embeds: {prompt_embeds.shape}, scale={ip_scale:.2f}") return ( prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_attention_mask, )