"""
WAN IP-Adapter – zero-shot face conditioning via T5 cross-attention injection.
Strategy
────────
Instead of patching WAN's self-attention blocks (which requires trained K/V
projections that don't exist for WAN), we inject face identity through the
cross-attention pathway that WAN already uses for text conditioning.
Pipeline
────────
1. SigLIP so400m (1152-d patch tokens)
   ↓ TimeResampler (SD3.5 trained weights, 8 queries → 1024-d)
   ↓ proj_face nn.Linear(1024 → 4096, xavier_uniform init)
   = 8 face tokens in T5 space (1, 8, 4096)
2. These are appended to the T5 prompt_embeds before each pipe() call.
   WAN's cross-attention naturally attends to all tokens in encoder_hidden_states,
   so no transformer surgery is needed (net shape effect sketched below).
3. For CFG (guidance_scale > 1), zeros are appended to the negative embeds
   so the unconditional branch is face-neutral, not anti-face.
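Net effect on the conditioning tensors (L = text token count; shapes are
illustrative, not tied to a particular max sequence length):
    (1, L, 4096) text ++ (1, 8, 4096) face → (1, L+8, 4096) embeds
    (1, L)       mask ++ (1, 8)       ones → (1, L+8)       mask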
Why this works zero-shot
────────────────────────
The TimeResampler is trained (SD3.5 weights) and produces semantically
structured 1024-d tokens. proj_face (xavier_uniform) is randomly initialized
but never updated, so it acts as one fixed linear map: it approximately
preserves the relative geometry of the resampler space, so the same face
always maps to the same region of T5 space and similar faces map to nearby
regions. WAN's cross-attention therefore sees consistent identity tokens
for consistent faces.
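A quick, illustrative sanity check of the geometry claim (not part of the
pipeline; any sufficiently wide random projection behaves similarly):
    W = nn.Linear(1024, 4096, bias=False)
    nn.init.xavier_uniform_(W.weight)
    a, b = torch.randn(1, 1024), torch.randn(1, 1024)
    F.cosine_similarity(a, b)        # cosine before projection…
    F.cosine_similarity(W(a), W(b))  # …is approximately preserved after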
Usage in app.py
───────────────
Init (once, inside _init_pipeline):
ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16)
Per-generation (inside run_inference, before pipe()):
prompt_embeds, neg_embeds, prompt_mask, neg_mask = ip_adapter.encode_prompt(
face_image=face_ref_image, # PIL Image or None
prompt=effective_prompt,
negative_prompt=negative_prompt,
    ip_scale=ip_scale,              # 0.0 → 1.0
)
result = pipe(
...
prompt_embeds=prompt_embeds,
negative_prompt_embeds=neg_embeds,
prompt_attention_mask=prompt_mask,
negative_prompt_attention_mask=neg_mask,
)
"""
from __future__ import annotations
import math
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoProcessor, SiglipVisionModel
# ── Perceiver resampler (unchanged from original – SD3.5 weights load here) ───
class _FeedForward(nn.Module):
def __init__(self, dim: int, mult: int = 4):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult, bias=False),
nn.GELU(),
nn.Linear(dim * mult, dim, bias=False),
)
def forward(self, x):
return self.net(x)
def _reshape(t: torch.Tensor, heads: int) -> torch.Tensor:
    # (B, N, heads * dim_head) -> (B, heads, N, dim_head), the layout SDPA expects.
    b, n, d = t.shape
    return t.reshape(b, n, heads, d // heads).transpose(1, 2)
class _PerceiverAttention(nn.Module):
def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
super().__init__()
self.heads = heads
inner = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner, bias=False)
self.to_kv = nn.Linear(dim, inner * 2, bias=False)
self.to_out = nn.Linear(inner, dim, bias=False)
    def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x)
        latents = self.norm2(latents)
        q = _reshape(self.to_q(latents), self.heads)
        # Perceiver-style attention: keys/values are computed from the
        # concatenation of the input features and the latents themselves.
        kv_in = torch.cat([x, latents], dim=1)
        k, v = self.to_kv(kv_in).chunk(2, dim=-1)
        k, v = _reshape(k, self.heads), _reshape(v, self.heads)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(latents.shape[0], -1, self.to_out.in_features)
        return self.to_out(out) + latents  # residual on the latent stream
class TimeResampler(nn.Module):
"""
    Perceiver resampler – architecture matches image_proj.* in
InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin so weights load cleanly.
Output: (batch, num_queries=8, output_dim=1024)
"""
def __init__(
self,
dim: int = 1024,
depth: int = 8,
dim_head: int = 64,
heads: int = 16,
num_queries: int = 8,
        embedding_dim: int = 1152,  # SigLIP so400m hidden size
output_dim: int = 1024,
ff_mult: int = 4,
timestep_in_dim: int = 320,
timestep_flip_sin_to_cos: bool = True,
timestep_freq_shift: int = 0,
):
super().__init__()
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
self.num_queries = num_queries
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
self.t_emb = TimestepEmbedding(timestep_in_dim, dim)
self.layers = nn.ModuleList([
nn.ModuleList([
_PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
_FeedForward(dim=dim, mult=ff_mult),
nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim)),
])
for _ in range(depth)
])
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        t = self.time_proj(timestep.flatten()).to(x.dtype)
        t_emb = self.t_emb(t)
        latents = self.latents.expand(x.size(0), -1, -1).clone()
        x = self.proj_in(x)
        for attn, ff, adaln in self.layers:
            # adaLN-style modulation: the timestep embedding is projected to
            # shift/scale pairs for the attention and MLP sub-blocks.
            s_msa, c_msa, s_mlp, c_mlp = adaln(t_emb).chunk(4, dim=-1)
            latents = latents * (1 + c_msa[:, None]) + s_msa[:, None]
            latents = attn(x, latents)
            latents = latents * (1 + c_mlp[:, None]) + s_mlp[:, None]
            latents = ff(latents) + latents
        return self.norm_out(self.proj_out(latents))  # (B, 8, 1024)
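
# Shape walk-through for the resampler (illustrative; with
# google/siglip-so400m-patch14-384 the encoder should emit 729 patch tokens):
#   x: (B, 729, 1152) --proj_in--> (B, 729, 1024)
#   8 learned query latents cross-attend into x across `depth` blocks
#   output: (B, 8, 1024)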
# ── Main class ─────────────────────────────────────────────────────────────────
class WanIPAdapter:
"""
Zero-shot face conditioning for WAN I2V via T5 cross-attention injection.
No transformer patching. Face tokens are appended to prompt_embeds and
WAN's existing cross-attention handles the rest.
"""
_IP_ADAPTER_REPO = "InstantX/SD3.5-Large-IP-Adapter"
_IP_ADAPTER_FILE = "ip-adapter.bin"
_VISION_MODEL = "google/siglip-so400m-patch14-384"
# WAN transformer cross-attention dim (text_dim in WanTransformer3DModel)
_T5_DIM = 4096
# TimeResampler output dim
_RESAMPLER_DIM = 1024
# Number of face tokens appended to the T5 sequence
_NUM_FACE_TOKENS = 8
def __init__(
self,
pipe,
device: torch.device,
dtype: torch.dtype = torch.bfloat16,
cache_dir: str = "/data/ip_adapter",
):
self.pipe = pipe
self.device = device
self.dtype = dtype
self._load_vision_encoder()
self._load_resampler(cache_dir)
self._build_proj_face()
print("[IP-Adapter] ready β€” T5-concat mode, no transformer patching")
# ── setup ──────────────────────────────────────────────────────────────────
def _load_vision_encoder(self):
print("[IP-Adapter] loading SigLIP vision encoder…")
self.vis_proc = AutoProcessor.from_pretrained(self._VISION_MODEL)
self.vis_model = SiglipVisionModel.from_pretrained(
self._VISION_MODEL, torch_dtype=self.dtype,
).to(self.device)
self.vis_model.eval()
print("[IP-Adapter] SigLIP loaded")
def _load_resampler(self, cache_dir: str):
print("[IP-Adapter] loading TimeResampler (SD3.5 ip-adapter.bin)…")
ckpt = hf_hub_download(
repo_id=self._IP_ADAPTER_REPO,
filename=self._IP_ADAPTER_FILE,
local_dir=cache_dir,
)
state = torch.load(ckpt, map_location="cpu", weights_only=True)
img_proj = {
k[len("image_proj."):]: v
for k, v in state.items()
if k.startswith("image_proj.")
}
self.resampler = TimeResampler().to(self.device, self.dtype)
missing, _ = self.resampler.load_state_dict(img_proj, strict=False)
if missing:
print(f"[IP-Adapter] resampler missing keys ({len(missing)}): {missing[:4]}…")
self.resampler.eval()
print("[IP-Adapter] resampler loaded")
def _build_proj_face(self):
"""
Fixed linear projection: resampler output (1024) β†’ T5 space (4096).
Xavier-uniform init so face tokens land at reasonable magnitude relative
to T5 embeddings. This projection is never trained β€” it's a fixed
consistent mapping that preserves the resampler's relative geometry.
"""
self.proj_face = nn.Linear(self._RESAMPLER_DIM, self._T5_DIM, bias=False)
nn.init.xavier_uniform_(self.proj_face.weight)
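        # Note: xavier_uniform_ samples from U(-a, a) with
        # a = sqrt(6 / (fan_in + fan_out)) = sqrt(6 / (1024 + 4096)) ≈ 0.034,
        # so the projected face tokens come out at a modest, T5-compatible scale.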
self.proj_face = self.proj_face.to(self.device, self.dtype)
self.proj_face.eval()
n_params = self.proj_face.weight.numel()
print(f"[IP-Adapter] proj_face built ({n_params:,} params, xavier_uniform, frozen)")
# ── encoding ───────────────────────────────────────────────────────────────
@torch.no_grad()
def _encode_face_tokens(self, image: Image.Image, timestep: int = 500) -> torch.Tensor:
"""
        Encode *image* → (1, 8, 4096) face tokens in T5 space.
The timestep passed to the TimeResampler controls which denoising
stage the resampler "thinks" it's at. 500 (mid-point) is a reasonable
default; lower values produce more detail-focused tokens.
"""
inputs = self.vis_proc(images=image, return_tensors="pt").to(self.device)
vis_out = self.vis_model(**inputs)
# Use patch tokens (last_hidden_state) rather than pooled for spatial detail
vis_feats = vis_out.last_hidden_state.to(self.dtype) # (1, N, 1152)
t = torch.tensor([timestep], device=self.device, dtype=torch.long)
emb = self.resampler(vis_feats, t) # (1, 8, 1024)
return self.proj_face(emb) # (1, 8, 4096)
# ── main API ───────────────────────────────────────────────────────────────
def encode_prompt(
self,
face_image: Optional[Image.Image],
prompt: str,
negative_prompt: str = "",
ip_scale: float = 0.6,
num_videos_per_prompt: int = 1,
do_classifier_free_guidance: bool = False,
timestep: int = 500,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns (prompt_embeds, negative_prompt_embeds,
prompt_attention_mask, negative_attention_mask)
ready to pass directly to pipe().
If *face_image* is None or *ip_scale* == 0, returns vanilla text embeds.
ip_scale blends face tokens into the prompt by scaling them before concat.
Scale of 1.0 = full face signal; 0.5 = half strength.
"""
# ── text embeddings ────────────────────────────────────────────────────
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_attention_mask,
) = self.pipe.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt if do_classifier_free_guidance else None,
do_classifier_free_guidance=do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
device=self.device,
)
if face_image is None or ip_scale == 0.0:
return (
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_attention_mask,
)
# ── face tokens ────────────────────────────────────────────────────────
face_tokens = self._encode_face_tokens(face_image, timestep=timestep)
# Repeat for batch if needed
B = prompt_embeds.shape[0]
if B > 1:
face_tokens = face_tokens.expand(B, -1, -1)
        # Scale face tokens – controls identity signal strength
face_tokens = face_tokens * ip_scale
# Append to prompt embeds
prompt_embeds = torch.cat([prompt_embeds, face_tokens], dim=1)
face_ones = torch.ones(B, self._NUM_FACE_TOKENS, device=self.device,
dtype=prompt_attention_mask.dtype)
prompt_attention_mask = torch.cat([prompt_attention_mask, face_ones], dim=1)
# For negative: append zeros (face-neutral, not anti-face)
if negative_prompt_embeds is not None:
B_neg = negative_prompt_embeds.shape[0]
neg_face = torch.zeros(B_neg, self._NUM_FACE_TOKENS, self._T5_DIM,
device=self.device, dtype=self.dtype)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, neg_face], dim=1)
neg_ones = torch.ones(B_neg, self._NUM_FACE_TOKENS, device=self.device,
dtype=negative_attention_mask.dtype)
negative_attention_mask = torch.cat([negative_attention_mask, neg_ones], dim=1)
print(f"[IP-Adapter] face tokens appended β€” "
f"prompt_embeds: {prompt_embeds.shape}, scale={ip_scale:.2f}")
return (
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_attention_mask,
)
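

# Minimal smoke test of the concat/mask bookkeeping in encode_prompt.
# Hypothetical and illustrative only: dummy tensors, no models loaded,
# not part of the Space's runtime.
if __name__ == "__main__":
    B, L, D, N_FACE = 2, 77, 4096, 8  # arbitrary dummy sizes
    prompt_embeds = torch.randn(B, L, D)
    prompt_mask = torch.ones(B, L, dtype=torch.long)
    face_tokens = torch.randn(1, N_FACE, D).expand(B, -1, -1) * 0.6  # ip_scale=0.6
    prompt_embeds = torch.cat([prompt_embeds, face_tokens], dim=1)
    prompt_mask = torch.cat(
        [prompt_mask, torch.ones(B, N_FACE, dtype=prompt_mask.dtype)], dim=1
    )
    assert prompt_embeds.shape == (B, L + N_FACE, D)
    assert prompt_mask.shape == (B, L + N_FACE)
    print("concat/mask OK:", tuple(prompt_embeds.shape), tuple(prompt_mask.shape))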