| """ | |
| WAN IP-Adapter β zero-shot face conditioning via T5 cross-attention injection. | |
| Strategy | |
| ββββββββ | |
| Instead of patching WAN's self-attention blocks (which requires trained K/V | |
| projections that don't exist for WAN), we inject face identity through the | |
| cross-attention pathway that WAN already uses for text conditioning. | |
| Pipeline | |
| 1. SigLIP2 so400m (1152-d patch tokens) | |
| β TimeResampler (SD3.5 trained weights, 8 queries β 1024-d) | |
| β proj_face nn.Linear(1024 β 4096, xavier_uniform init) | |
| = 8 face tokens in T5 space (1, 8, 4096) | |
| 2. These are appended to the T5 prompt_embeds before each pipe() call. | |
| WAN's cross-attention naturally attends to all tokens in encoder_hidden_states, | |
| so no transformer surgery is needed. | |
| 3. For CFG (guidance_scale > 1), zeros are appended to the negative embeds | |
| so the unconditional branch is face-neutral, not anti-face. | |
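
A minimal shape sketch of the concat in step 2 (illustrative only; the 512-token
text length and the 0.6 scale below are placeholders, not values taken from WAN):

    import torch
    prompt_embeds = torch.randn(1, 512, 4096)     # T5 text tokens
    face_tokens = torch.randn(1, 8, 4096) * 0.6   # 8 face tokens, ip_scale=0.6
    combined = torch.cat([prompt_embeds, face_tokens], dim=1)
    assert combined.shape == (1, 520, 4096)       # what WAN cross-attn sees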

Why this works zero-shot
────────────────────────
The TimeResampler is trained (SD3.5 weights) and produces semantically
structured 1024-d tokens. The random proj_face (xavier_uniform) is a
fixed linear map: it preserves the relative geometry of the resampler
space, so the same face always maps to the same region of T5 space and
similar faces map to nearby regions. WAN's cross-attention sees consistent
identity tokens for consistent faces.
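
A quick self-contained check of that geometry claim (an illustrative sketch,
not part of the adapter; the seed and the random test vectors are arbitrary):

    import torch
    import torch.nn as nn
    torch.manual_seed(0)
    proj = nn.Linear(1024, 4096, bias=False)
    nn.init.xavier_uniform_(proj.weight)
    a, b = torch.randn(1024), torch.randn(1024)
    cos = nn.functional.cosine_similarity
    # Cosine similarity before and after the fixed random projection agrees
    # closely, so relative geometry survives the map.
    print(cos(a, b, dim=0).item(), cos(proj(a), proj(b), dim=0).item())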

Usage in app.py
───────────────
Init (once, inside _init_pipeline):

    ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16)

Per-generation (inside run_inference, before pipe()):

    prompt_embeds, neg_embeds, prompt_mask, neg_mask = ip_adapter.encode_prompt(
        face_image=face_ref_image,      # PIL Image or None
        prompt=effective_prompt,
        negative_prompt=negative_prompt,
        ip_scale=ip_scale,              # 0.0 - 1.0
    )
    result = pipe(
        ...
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        prompt_attention_mask=prompt_mask,
        negative_prompt_attention_mask=neg_mask,
    )
"""
from __future__ import annotations

import math
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoProcessor, SiglipVisionModel


# ── Perceiver resampler (unchanged from original; SD3.5 weights load here) ──────

class _FeedForward(nn.Module):
    def __init__(self, dim: int, mult: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult, bias=False),
            nn.GELU(),
            nn.Linear(dim * mult, dim, bias=False),
        )

    def forward(self, x):
        return self.net(x)


def _reshape(t: torch.Tensor, heads: int) -> torch.Tensor:
    b, n, d = t.shape
    return t.reshape(b, n, heads, d // heads).transpose(1, 2)


class _PerceiverAttention(nn.Module):
    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
        super().__init__()
        self.heads = heads
        inner = dim_head * heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_kv = nn.Linear(dim, inner * 2, bias=False)
        self.to_out = nn.Linear(inner, dim, bias=False)

    def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x)
        latents = self.norm2(latents)
        q = _reshape(self.to_q(latents), self.heads)
        kv_in = torch.cat([x, latents], dim=1)
        k, v = self.to_kv(kv_in).chunk(2, dim=-1)
        k, v = _reshape(k, self.heads), _reshape(v, self.heads)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(latents.shape[0], -1, self.to_out.in_features)
        return self.to_out(out) + latents


class TimeResampler(nn.Module):
    """
    Perceiver resampler whose architecture matches image_proj.* in
    InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin, so the weights load cleanly.

    Output: (batch, num_queries=8, output_dim=1024)
    """

    def __init__(
        self,
        dim: int = 1024,
        depth: int = 8,
        dim_head: int = 64,
        heads: int = 16,
        num_queries: int = 8,
        embedding_dim: int = 1152,  # SigLIP so400m hidden size
        output_dim: int = 1024,
        ff_mult: int = 4,
        timestep_in_dim: int = 320,
        timestep_flip_sin_to_cos: bool = True,
        timestep_freq_shift: int = 0,
    ):
        super().__init__()
        from diffusers.models.embeddings import Timesteps, TimestepEmbedding

        self.num_queries = num_queries
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
        self.t_emb = TimestepEmbedding(timestep_in_dim, dim)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                _PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                _FeedForward(dim=dim, mult=ff_mult),
                nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim)),
            ])
            for _ in range(depth)
        ])
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        t = self.time_proj(timestep.flatten()).to(x.dtype)
        t_emb = self.t_emb(t)
        latents = self.latents.expand(x.size(0), -1, -1).clone()
        x = self.proj_in(x)
        for attn, ff, adaln in self.layers:
            s_msa, c_msa, s_mlp, c_mlp = adaln(t_emb).chunk(4, dim=-1)
            latents = latents * (1 + c_msa[:, None]) + s_msa[:, None]
            latents = attn(x, latents)
            latents = latents * (1 + c_mlp[:, None]) + s_mlp[:, None]
            latents = ff(latents) + latents
        return self.norm_out(self.proj_out(latents))  # (B, 8, 1024)
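

# A minimal shape check for TimeResampler (illustrative; with the random init the
# output values are meaningless until the SD3.5 weights are loaded below):
#
#     feats = torch.randn(1, 729, 1152)   # SigLIP so400m patch tokens (27 x 27)
#     t = torch.tensor([500])
#     assert TimeResampler()(feats, t).shape == (1, 8, 1024)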


# ── Main class ───────────────────────────────────────────────────────────────────

class WanIPAdapter:
    """
    Zero-shot face conditioning for WAN I2V via T5 cross-attention injection.

    No transformer patching. Face tokens are appended to prompt_embeds and
    WAN's existing cross-attention handles the rest.
    """

    _IP_ADAPTER_REPO = "InstantX/SD3.5-Large-IP-Adapter"
    _IP_ADAPTER_FILE = "ip-adapter.bin"
    _VISION_MODEL = "google/siglip-so400m-patch14-384"

    # WAN transformer cross-attention dim (text_dim in WanTransformer3DModel)
    _T5_DIM = 4096
    # TimeResampler output dim
    _RESAMPLER_DIM = 1024
    # Number of face tokens appended to the T5 sequence
    _NUM_FACE_TOKENS = 8

    def __init__(
        self,
        pipe,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: str = "/data/ip_adapter",
    ):
        self.pipe = pipe
        self.device = device
        self.dtype = dtype

        self._load_vision_encoder()
        self._load_resampler(cache_dir)
        self._build_proj_face()
        print("[IP-Adapter] ready: T5-concat mode, no transformer patching")

    # ── setup ──────────────────────────────────────────────────────────────────

    def _load_vision_encoder(self):
        print("[IP-Adapter] loading SigLIP vision encoder...")
        self.vis_proc = AutoProcessor.from_pretrained(self._VISION_MODEL)
        self.vis_model = SiglipVisionModel.from_pretrained(
            self._VISION_MODEL, torch_dtype=self.dtype,
        ).to(self.device)
        self.vis_model.eval()
        print("[IP-Adapter] SigLIP loaded")

    def _load_resampler(self, cache_dir: str):
        print("[IP-Adapter] loading TimeResampler (SD3.5 ip-adapter.bin)...")
        ckpt = hf_hub_download(
            repo_id=self._IP_ADAPTER_REPO,
            filename=self._IP_ADAPTER_FILE,
            local_dir=cache_dir,
        )
        state = torch.load(ckpt, map_location="cpu", weights_only=True)
        img_proj = {
            k[len("image_proj."):]: v
            for k, v in state.items()
            if k.startswith("image_proj.")
        }
        self.resampler = TimeResampler().to(self.device, self.dtype)
        missing, _ = self.resampler.load_state_dict(img_proj, strict=False)
        if missing:
            print(f"[IP-Adapter] resampler missing keys ({len(missing)}): {missing[:4]}...")
        self.resampler.eval()
        print("[IP-Adapter] resampler loaded")

    def _build_proj_face(self):
        """
        Fixed linear projection: resampler output (1024) → T5 space (4096).

        Xavier-uniform init so face tokens land at a reasonable magnitude relative
        to T5 embeddings. This projection is never trained; it is a fixed,
        consistent mapping that preserves the resampler's relative geometry.
        """
        self.proj_face = nn.Linear(self._RESAMPLER_DIM, self._T5_DIM, bias=False)
        nn.init.xavier_uniform_(self.proj_face.weight)
        self.proj_face = self.proj_face.to(self.device, self.dtype)
        self.proj_face.eval()
        n_params = self.proj_face.weight.numel()
        print(f"[IP-Adapter] proj_face built ({n_params:,} params, xavier_uniform, frozen)")

    # ── encoding ───────────────────────────────────────────────────────────────

    @torch.no_grad()  # inference only: no gradients need to flow through the encoders
    def _encode_face_tokens(self, image: Image.Image, timestep: int = 500) -> torch.Tensor:
        """
        Encode *image* → (1, 8, 4096) face tokens in T5 space.

        The timestep passed to the TimeResampler controls which denoising
        stage the resampler "thinks" it's at. 500 (mid-point) is a reasonable
        default; lower values produce more detail-focused tokens.
        """
        inputs = self.vis_proc(images=image, return_tensors="pt").to(self.device)
        vis_out = self.vis_model(**inputs)
        # Use patch tokens (last_hidden_state) rather than the pooled output for spatial detail
        vis_feats = vis_out.last_hidden_state.to(self.dtype)  # (1, N, 1152)
        t = torch.tensor([timestep], device=self.device, dtype=torch.long)
        emb = self.resampler(vis_feats, t)  # (1, 8, 1024)
        return self.proj_face(emb)          # (1, 8, 4096)

    # ── main API ───────────────────────────────────────────────────────────────

    def encode_prompt(
        self,
        face_image: Optional[Image.Image],
        prompt: str,
        negative_prompt: str = "",
        ip_scale: float = 0.6,
        num_videos_per_prompt: int = 1,
        do_classifier_free_guidance: bool = False,
        timestep: int = 500,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns (prompt_embeds, negative_prompt_embeds,
                 prompt_attention_mask, negative_attention_mask)
        ready to pass directly to pipe().

        If *face_image* is None or *ip_scale* == 0, returns vanilla text embeds.

        ip_scale blends face tokens into the prompt by scaling them before concat.
        A scale of 1.0 = full face signal; 0.5 = half strength.
        """
        # ── text embeddings ──────────────────────────────────────────────────
        (
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_attention_mask,
        ) = self.pipe.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt if do_classifier_free_guidance else None,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            device=self.device,
        )

        if face_image is None or ip_scale == 0.0:
            return (
                prompt_embeds,
                negative_prompt_embeds,
                prompt_attention_mask,
                negative_attention_mask,
            )

        # ── face tokens ──────────────────────────────────────────────────────
        face_tokens = self._encode_face_tokens(face_image, timestep=timestep)

        # Repeat for batch if needed
        B = prompt_embeds.shape[0]
        if B > 1:
            face_tokens = face_tokens.expand(B, -1, -1)

        # Scale face tokens: controls identity signal strength
        face_tokens = face_tokens * ip_scale

        # Append to prompt embeds
        prompt_embeds = torch.cat([prompt_embeds, face_tokens], dim=1)
        face_ones = torch.ones(B, self._NUM_FACE_TOKENS, device=self.device,
                               dtype=prompt_attention_mask.dtype)
        prompt_attention_mask = torch.cat([prompt_attention_mask, face_ones], dim=1)

        # For the negative branch: append zeros (face-neutral, not anti-face)
        if negative_prompt_embeds is not None:
            B_neg = negative_prompt_embeds.shape[0]
            neg_face = torch.zeros(B_neg, self._NUM_FACE_TOKENS, self._T5_DIM,
                                   device=self.device, dtype=self.dtype)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, neg_face], dim=1)
            neg_ones = torch.ones(B_neg, self._NUM_FACE_TOKENS, device=self.device,
                                  dtype=negative_attention_mask.dtype)
            negative_attention_mask = torch.cat([negative_attention_mask, neg_ones], dim=1)

        print(f"[IP-Adapter] face tokens appended -> "
              f"prompt_embeds: {prompt_embeds.shape}, scale={ip_scale:.2f}")

        return (
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_attention_mask,
        )