Daankular committed on
Commit 395e472 · verified · 1 Parent(s): 1b01b22

Update ip_adapter.py

Files changed (1)
  1. ip_adapter.py +198 -209
ip_adapter.py CHANGED
@@ -1,21 +1,53 @@
  """
- WAN 2.1 IP-Adapter - diffusers-native port of kaaskoek232/IPAdapterWAN.
-
- Architecture
-     SigLIP2 so400m (1152-d) → TimeResampler (1024-d, 8 queries)
-     → per-block WanIPAttnProcessor injected into every self-attention of
-     pipe.transformer
-
- Weights
-     Resampler : loaded from InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin
-                 key prefix "image_proj" (architecture-matched)
-     IP proj   : to_k_ip / to_v_ip initialised from the model's own to_k / to_v
-                 weights (zero-shot reference-attention style - works without
-                 Wan-specific training and produces real identity signal)
-
- LoRA compatibility
-     IP processors sit on top of whatever to_q/to_k/to_v the LoRA has patched;
-     they are orthogonal (IP adds extra KV, LoRA modifies weight matrices).
  """

  from __future__ import annotations
@@ -27,20 +59,12 @@ from typing import Optional
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from einops import rearrange
  from huggingface_hub import hf_hub_download
  from PIL import Image
  from transformers import AutoProcessor, SiglipVisionModel


- # ── Helpers ────────────────────────────────────────────────────────────────────
-
- def _reshape(t: torch.Tensor, heads: int) -> torch.Tensor:
-     b, n, d = t.shape
-     return t.reshape(b, n, heads, d // heads).transpose(1, 2)
-
-
- # ── Perceiver / TimeResampler (matches SD3.5 ip-adapter.bin image_proj.*) ─────

  class _FeedForward(nn.Module):
      def __init__(self, dim: int, mult: int = 4):
@@ -56,16 +80,21 @@ class _FeedForward(nn.Module):
          return self.net(x)


  class _PerceiverAttention(nn.Module):
      def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
          super().__init__()
-         self.heads = heads
-         inner = dim_head * heads
-         self.norm1 = nn.LayerNorm(dim)
-         self.norm2 = nn.LayerNorm(dim)
-         self.to_q = nn.Linear(dim, inner, bias=False)
-         self.to_kv = nn.Linear(dim, inner * 2, bias=False)
-         self.to_out = nn.Linear(inner, dim, bias=False)

      def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
          x = self.norm1(x)
@@ -80,10 +109,10 @@ class _PerceiverAttention(nn.Module):


  class TimeResampler(nn.Module):
-     """Perceiver resampler with adaLN timestep conditioning.
-
-     Architecture mirrors the image_proj section of
-     InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin so its weights load cleanly.
      """

      def __init__(
@@ -93,7 +122,7 @@ class TimeResampler(nn.Module):
          dim_head: int = 64,
          heads: int = 16,
          num_queries: int = 8,
-         embedding_dim: int = 1152,  # SigLIP2 so400m
          output_dim: int = 1024,
          ff_mult: int = 4,
          timestep_in_dim: int = 320,
@@ -111,20 +140,16 @@ class TimeResampler(nn.Module):
              nn.ModuleList([
                  _PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                  _FeedForward(dim=dim, mult=ff_mult),
-                 nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim)),  # adaLN
              ])
              for _ in range(depth)
          ])
          self.proj_out = nn.Linear(dim, output_dim)
          self.norm_out = nn.LayerNorm(output_dim)

-     def forward(
-         self,
-         x: torch.Tensor,
-         timestep: torch.Tensor,
-     ) -> tuple[torch.Tensor, torch.Tensor]:
          t = self.time_proj(timestep.flatten()).to(x.dtype)
-         t_emb = self.t_emb(t)  # (B, dim)
          latents = self.latents.expand(x.size(0), -1, -1).clone()
          x = self.proj_in(x)
          for attn, ff, adaln in self.layers:
@@ -133,96 +158,30 @@ class TimeResampler(nn.Module):
              latents = attn(x, latents)
              latents = latents * (1 + c_mlp[:, None]) + s_mlp[:, None]
              latents = ff(latents) + latents
-         latents = self.norm_out(self.proj_out(latents))
-         return latents, t_emb
-
-
- # ── Per-block attention processor ─────────────────────────────────────────────
-
- class WanIPAttnProcessor:
-     """Wraps an existing Attention processor and adds IP face KV injection.
-
-     The IP keys/values are initialised from the model's own to_k / to_v weights
-     (zero-shot reference-attention), so no separate IP training is needed.
-     Conditioned frames attend to the face tokens in every self-attention block.
-     """
-
-     def __init__(
-         self,
-         original_processor,
-         to_k_ip: nn.Linear,
-         to_v_ip: nn.Linear,
-         norm_k_ip: Optional[nn.Module] = None,
-         norm_v_ip: Optional[nn.Module] = None,
-         scale: float = 1.0,
-     ):
-         self.original = original_processor
-         self.to_k_ip = to_k_ip
-         self.to_v_ip = to_v_ip
-         self.norm_k_ip = norm_k_ip
-         self.norm_v_ip = norm_v_ip
-         self.scale = scale
-         # Set before each pipeline call; cleared after.
-         self.ip_hidden_states: Optional[torch.Tensor] = None
-
-     def __call__(self, attn, hidden_states, *args, **kwargs):
-         out = self.original(attn, hidden_states, *args, **kwargs)
-
-         if self.ip_hidden_states is None or self.scale == 0:
-             return out
-
-         hs = self.ip_hidden_states
-         h = attn.heads
-         # Compute Q from hidden_states (re-use the model's normalised projection)
-         q = attn.to_q(hidden_states)
-         norm_q = getattr(attn, "norm_q", None)
-         if norm_q is not None:
-             q = norm_q(q)
-
-         # Compute IP K / V
-         k_ip = self.to_k_ip(hs)
-         v_ip = self.to_v_ip(hs)
-         if self.norm_k_ip is not None:
-             k_ip = self.norm_k_ip(k_ip)
-         if self.norm_v_ip is not None:
-             v_ip = self.norm_v_ip(v_ip)
-
-         q = _reshape(q, h)
-         k_ip = _reshape(k_ip, h)
-         v_ip = _reshape(v_ip, h)
-
-         ip_attn = F.scaled_dot_product_attention(q, k_ip, v_ip)
-         inner_dim = getattr(attn, "inner_dim", q.shape[-1] * h)
-         ip_attn = ip_attn.transpose(1, 2).reshape(
-             hidden_states.shape[0], -1, inner_dim
-         )
-         ip_attn = attn.to_out[0](ip_attn)
-         if len(attn.to_out) > 1:
-             ip_attn = attn.to_out[1](ip_attn)
-
-         return out + ip_attn * self.scale


  # ── Main class ─────────────────────────────────────────────────────────────────

  class WanIPAdapter:
-     """Loads the IP-Adapter and patches pipe.transformer for face conditioning.
-
-     Usage inside _init_pipeline():
-         ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16)
-
-     Usage inside run_inference() (before pipe()):
-         if face_ref is not None:
-             emb = ip_adapter.encode(face_ref, timestep=500)
-             ip_adapter.set_hidden_states(emb, scale=ip_scale)
-         result = pipe(...)
-         ip_adapter.clear_hidden_states()
      """

      _IP_ADAPTER_REPO = "InstantX/SD3.5-Large-IP-Adapter"
      _IP_ADAPTER_FILE = "ip-adapter.bin"
      _VISION_MODEL = "google/siglip-so400m-patch14-384"

      def __init__(
          self,
          pipe,
@@ -236,8 +195,8 @@ class WanIPAdapter:

          self._load_vision_encoder()
          self._load_resampler(cache_dir)
-         self._patch_transformer(pipe.transformer)
-         print("[IP-Adapter] ready")

      # ── setup ──────────────────────────────────────────────────────────────────

@@ -245,112 +204,142 @@ class WanIPAdapter:
          print("[IP-Adapter] loading SigLIP vision encoder…")
          self.vis_proc = AutoProcessor.from_pretrained(self._VISION_MODEL)
          self.vis_model = SiglipVisionModel.from_pretrained(
-             self._VISION_MODEL, torch_dtype=self.dtype
          ).to(self.device)
          self.vis_model.eval()
          print("[IP-Adapter] SigLIP loaded")

      def _load_resampler(self, cache_dir: str):
-         print("[IP-Adapter] loading TimeResampler from SD3.5 ip-adapter.bin…")
          ckpt = hf_hub_download(
              repo_id=self._IP_ADAPTER_REPO,
              filename=self._IP_ADAPTER_FILE,
              local_dir=cache_dir,
          )
          state = torch.load(ckpt, map_location="cpu", weights_only=True)
-
-         # Detect checkpoint key prefix (ip-adapter.bin uses "image_proj.*")
-         prefix = "image_proj"
          img_proj = {
-             k[len(prefix) + 1:]: v
              for k, v in state.items()
-             if k.startswith(prefix + ".")
          }
-
          self.resampler = TimeResampler().to(self.device, self.dtype)
-         missing, unexpected = self.resampler.load_state_dict(img_proj, strict=False)
          if missing:
              print(f"[IP-Adapter] resampler missing keys ({len(missing)}): {missing[:4]}…")
          print("[IP-Adapter] resampler loaded")

-     def _patch_transformer(self, transformer: nn.Module):
-         """Replace every self-attention processor with WanIPAttnProcessor."""
-         self._processors: list[WanIPAttnProcessor] = []
-
-         for name, mod in transformer.named_modules():
-             if not (hasattr(mod, "processor") and hasattr(mod, "to_k")):
-                 continue
-
-             # Build IP projections mirroring the model's own K/V projections
-             to_k_ip = nn.Linear(
-                 self.resampler.proj_out.out_features,
-                 mod.to_k.out_features,
-                 bias=False,
-             ).to(self.device, self.dtype)
-             to_v_ip = nn.Linear(
-                 self.resampler.proj_out.out_features,
-                 mod.to_v.out_features,
-                 bias=False,
-             ).to(self.device, self.dtype)
-
-             # Zero-shot init: copy model's own projection weights then scale down
-             # so the initial IP signal is small but directionally meaningful.
-             k_w = mod.to_k.weight.data
-             v_w = mod.to_v.weight.data
-             out_f, in_f = to_k_ip.weight.shape
-             # in_f = resampler output (1024); in_f may differ from k_w.shape[1]
-             # - just use kaiming init if shapes differ
-             if in_f == k_w.shape[1]:
-                 to_k_ip.weight.data.copy_(k_w[:out_f] * 0.01)
-                 to_v_ip.weight.data.copy_(v_w[:out_f] * 0.01)
-             else:
-                 nn.init.kaiming_uniform_(to_k_ip.weight, a=math.sqrt(5))
-                 nn.init.kaiming_uniform_(to_v_ip.weight, a=math.sqrt(5))
-                 to_k_ip.weight.data *= 0.01
-                 to_v_ip.weight.data *= 0.01
-
-             # Clone existing norms if present
-             norm_k = mod.norm_k.__class__(mod.norm_k.normalized_shape[0]) \
-                 if hasattr(mod, "norm_k") and mod.norm_k is not None else None
-             norm_v = mod.norm_v.__class__(mod.norm_v.normalized_shape[0]) \
-                 if hasattr(mod, "norm_v") and mod.norm_v is not None else None
-             if norm_k is not None:
-                 norm_k = norm_k.to(self.device, self.dtype)
-             if norm_v is not None:
-                 norm_v = norm_v.to(self.device, self.dtype)
-
-             ip_proc = WanIPAttnProcessor(
-                 original_processor=mod.processor,
-                 to_k_ip=to_k_ip,
-                 to_v_ip=to_v_ip,
-                 norm_k_ip=norm_k,
-                 norm_v_ip=norm_v,
-             )
-             mod.processor = ip_proc
-             self._processors.append(ip_proc)

-         print(f"[IP-Adapter] patched {len(self._processors)} attention blocks")

-     # ── inference API ─────────────────────────────────────────────────────────

      @torch.no_grad()
-     def encode(self, image: Image.Image, timestep: int = 500) -> torch.Tensor:
-         """Encode *image* through SigLIP2 + TimeResampler → (1, 8, 1024)."""
-         inputs = self.vis_proc(images=image, return_tensors="pt").to(self.device)
          vis_out = self.vis_model(**inputs)
-         # Use last_hidden_state (patch tokens) rather than pooled for richer features
          vis_feats = vis_out.last_hidden_state.to(self.dtype)  # (1, N, 1152)
-         t = torch.tensor([timestep], device=self.device, dtype=torch.long)
-         emb, _ = self.resampler(vis_feats, t)  # (1, 8, 1024)
-         return emb
-
-     def set_hidden_states(self, emb: torch.Tensor, scale: float = 0.6):
-         """Broadcast *emb* to all processors before a pipe() call."""
-         for p in self._processors:
-             p.ip_hidden_states = emb
-             p.scale = scale
-
-     def clear_hidden_states(self):
-         """Remove face embeddings after pipe() returns."""
-         for p in self._processors:
-             p.ip_hidden_states = None
  """
+ WAN IP-Adapter - zero-shot face conditioning via T5 cross-attention injection.
+
+ Strategy
+ ────────
+ Instead of patching WAN's self-attention blocks (which requires trained K/V
+ projections that don't exist for WAN), we inject face identity through the
+ cross-attention pathway that WAN already uses for text conditioning.
+
+ Pipeline
+     1. SigLIP2 so400m (1152-d patch tokens)
+          ↓ TimeResampler (SD3.5 trained weights, 8 queries → 1024-d)
+          ↓ proj_face nn.Linear(1024 → 4096, xavier_uniform init)
+        = 8 face tokens in T5 space (1, 8, 4096)
+
+     2. These are appended to the T5 prompt_embeds before each pipe() call.
+        WAN's cross-attention naturally attends to all tokens in encoder_hidden_states,
+        so no transformer surgery is needed.
+
+     3. For CFG (guidance_scale > 1), zeros are appended to the negative embeds
+        so the unconditional branch is face-neutral, not anti-face.
+
+ Why this works zero-shot
+ ────────────────────────
+ The TimeResampler is trained (SD3.5 weights) and produces semantically
+ structured 1024-d tokens. The random proj_face (xavier_uniform) is a
+ fixed linear map - it preserves the relative geometry of the resampler
+ space, so the same face always maps to the same region of T5 space and
+ similar faces map to nearby regions. WAN's cross-attention sees consistent
+ identity tokens for consistent faces.
+
+ Usage in app.py
+ ───────────────
+ Init (once, inside _init_pipeline):
+     ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16)
+
+ Per-generation (inside run_inference, before pipe()):
+     prompt_embeds, neg_embeds, prompt_mask, neg_mask = ip_adapter.encode_prompt(
+         face_image=face_ref_image,    # PIL Image or None
+         prompt=effective_prompt,
+         negative_prompt=negative_prompt,
+         ip_scale=ip_scale,            # 0.0 → 1.0
+     )
+     result = pipe(
+         ...
+         prompt_embeds=prompt_embeds,
+         negative_prompt_embeds=neg_embeds,
+         prompt_attention_mask=prompt_mask,
+         negative_prompt_attention_mask=neg_mask,
+     )
  """

  from __future__ import annotations

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from huggingface_hub import hf_hub_download
  from PIL import Image
  from transformers import AutoProcessor, SiglipVisionModel


+ # ── Perceiver resampler (unchanged from original - SD3.5 weights load here) ───

  class _FeedForward(nn.Module):
      def __init__(self, dim: int, mult: int = 4):

          return self.net(x)


+ def _reshape(t: torch.Tensor, heads: int) -> torch.Tensor:
+     b, n, d = t.shape
+     return t.reshape(b, n, heads, d // heads).transpose(1, 2)
+
+
  class _PerceiverAttention(nn.Module):
      def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
          super().__init__()
+         self.heads = heads
+         inner = dim_head * heads
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+         self.to_q = nn.Linear(dim, inner, bias=False)
+         self.to_kv = nn.Linear(dim, inner * 2, bias=False)
+         self.to_out = nn.Linear(inner, dim, bias=False)

      def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
          x = self.norm1(x)


  class TimeResampler(nn.Module):
+     """
+     Perceiver resampler - architecture matches image_proj.* in
+     InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin so weights load cleanly.
+     Output: (batch, num_queries=8, output_dim=1024)
      """

      def __init__(

          dim_head: int = 64,
          heads: int = 16,
          num_queries: int = 8,
+         embedding_dim: int = 1152,  # SigLIP2 so400m hidden size
          output_dim: int = 1024,
          ff_mult: int = 4,
          timestep_in_dim: int = 320,

              nn.ModuleList([
                  _PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                  _FeedForward(dim=dim, mult=ff_mult),
+                 nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim)),
              ])
              for _ in range(depth)
          ])
          self.proj_out = nn.Linear(dim, output_dim)
          self.norm_out = nn.LayerNorm(output_dim)

+     def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
          t = self.time_proj(timestep.flatten()).to(x.dtype)
+         t_emb = self.t_emb(t)
          latents = self.latents.expand(x.size(0), -1, -1).clone()
          x = self.proj_in(x)
          for attn, ff, adaln in self.layers:

              latents = attn(x, latents)
              latents = latents * (1 + c_mlp[:, None]) + s_mlp[:, None]
              latents = ff(latents) + latents
+         return self.norm_out(self.proj_out(latents))  # (B, 8, 1024)


  # ── Main class ─────────────────────────────────────────────────────────────────

  class WanIPAdapter:
+     """
+     Zero-shot face conditioning for WAN I2V via T5 cross-attention injection.
+
+     No transformer patching. Face tokens are appended to prompt_embeds and
+     WAN's existing cross-attention handles the rest.
      """

      _IP_ADAPTER_REPO = "InstantX/SD3.5-Large-IP-Adapter"
      _IP_ADAPTER_FILE = "ip-adapter.bin"
      _VISION_MODEL = "google/siglip-so400m-patch14-384"

+     # WAN transformer cross-attention dim (text_dim in WanTransformer3DModel)
+     _T5_DIM = 4096
+     # TimeResampler output dim
+     _RESAMPLER_DIM = 1024
+     # Number of face tokens appended to the T5 sequence
+     _NUM_FACE_TOKENS = 8
+
      def __init__(
          self,
          pipe,


          self._load_vision_encoder()
          self._load_resampler(cache_dir)
+         self._build_proj_face()
+         print("[IP-Adapter] ready - T5-concat mode, no transformer patching")

      # ── setup ──────────────────────────────────────────────────────────────────

          print("[IP-Adapter] loading SigLIP vision encoder…")
          self.vis_proc = AutoProcessor.from_pretrained(self._VISION_MODEL)
          self.vis_model = SiglipVisionModel.from_pretrained(
+             self._VISION_MODEL, torch_dtype=self.dtype,
          ).to(self.device)
          self.vis_model.eval()
          print("[IP-Adapter] SigLIP loaded")

      def _load_resampler(self, cache_dir: str):
+         print("[IP-Adapter] loading TimeResampler (SD3.5 ip-adapter.bin)…")
          ckpt = hf_hub_download(
              repo_id=self._IP_ADAPTER_REPO,
              filename=self._IP_ADAPTER_FILE,
              local_dir=cache_dir,
          )
          state = torch.load(ckpt, map_location="cpu", weights_only=True)
          img_proj = {
+             k[len("image_proj."):]: v
              for k, v in state.items()
+             if k.startswith("image_proj.")
          }
          self.resampler = TimeResampler().to(self.device, self.dtype)
+         missing, _ = self.resampler.load_state_dict(img_proj, strict=False)
          if missing:
              print(f"[IP-Adapter] resampler missing keys ({len(missing)}): {missing[:4]}…")
+         self.resampler.eval()
          print("[IP-Adapter] resampler loaded")

+     def _build_proj_face(self):
+         """
+         Fixed linear projection: resampler output (1024) → T5 space (4096).
+
+         Xavier-uniform init so face tokens land at reasonable magnitude relative
+         to T5 embeddings. This projection is never trained - it's a fixed,
+         consistent mapping that preserves the resampler's relative geometry.
+         """
+         self.proj_face = nn.Linear(self._RESAMPLER_DIM, self._T5_DIM, bias=False)
+         nn.init.xavier_uniform_(self.proj_face.weight)
+         self.proj_face = self.proj_face.to(self.device, self.dtype)
+         self.proj_face.eval()
+         n_params = self.proj_face.weight.numel()
+         print(f"[IP-Adapter] proj_face built ({n_params:,} params, xavier_uniform, frozen)")

+     # ── encoding ───────────────────────────────────────────────────────────────
+
      @torch.no_grad()
+     def _encode_face_tokens(self, image: Image.Image, timestep: int = 500) -> torch.Tensor:
+         """
+         Encode *image* → (1, 8, 4096) face tokens in T5 space.
+
+         The timestep passed to the TimeResampler controls which denoising
+         stage the resampler "thinks" it's at. 500 (mid-point) is a reasonable
+         default; lower values produce more detail-focused tokens.
+         """
+         inputs = self.vis_proc(images=image, return_tensors="pt").to(self.device)
          vis_out = self.vis_model(**inputs)
+         # Use patch tokens (last_hidden_state) rather than pooled for spatial detail
          vis_feats = vis_out.last_hidden_state.to(self.dtype)  # (1, N, 1152)
+
+         t = torch.tensor([timestep], device=self.device, dtype=torch.long)
+         emb = self.resampler(vis_feats, t)   # (1, 8, 1024)
+         return self.proj_face(emb)           # (1, 8, 4096)
+
+     # ── main API ───────────────────────────────────────────────────────────────
+
+     def encode_prompt(
+         self,
+         face_image: Optional[Image.Image],
+         prompt: str,
+         negative_prompt: str = "",
+         ip_scale: float = 0.6,
+         num_videos_per_prompt: int = 1,
+         do_classifier_free_guidance: bool = False,
+         timestep: int = 500,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Returns (prompt_embeds, negative_prompt_embeds,
+                  prompt_attention_mask, negative_attention_mask)
+         ready to pass directly to pipe().
+
+         If *face_image* is None or *ip_scale* == 0, returns vanilla text embeds.
+
+         ip_scale blends face tokens into the prompt by scaling them before concat.
+         Scale of 1.0 = full face signal; 0.5 = half strength.
+         """
+         # ── text embeddings ────────────────────────────────────────────────────
+         (
+             prompt_embeds,
+             negative_prompt_embeds,
+             prompt_attention_mask,
+             negative_attention_mask,
+         ) = self.pipe.encode_prompt(
+             prompt=prompt,
+             negative_prompt=negative_prompt if do_classifier_free_guidance else None,
+             do_classifier_free_guidance=do_classifier_free_guidance,
+             num_videos_per_prompt=num_videos_per_prompt,
+             device=self.device,
+         )
+
+         if face_image is None or ip_scale == 0.0:
+             return (
+                 prompt_embeds,
+                 negative_prompt_embeds,
+                 prompt_attention_mask,
+                 negative_attention_mask,
+             )
+
+         # ── face tokens ────────────────────────────────────────────────────────
+         face_tokens = self._encode_face_tokens(face_image, timestep=timestep)
+         # Repeat for batch if needed
+         B = prompt_embeds.shape[0]
+         if B > 1:
+             face_tokens = face_tokens.expand(B, -1, -1)
+
+         # Scale face tokens - controls identity signal strength
+         face_tokens = face_tokens * ip_scale
+
+         # Append to prompt embeds
+         prompt_embeds = torch.cat([prompt_embeds, face_tokens], dim=1)
+         face_ones = torch.ones(B, self._NUM_FACE_TOKENS, device=self.device,
+                                dtype=prompt_attention_mask.dtype)
+         prompt_attention_mask = torch.cat([prompt_attention_mask, face_ones], dim=1)
+
+         # For negative: append zeros (face-neutral, not anti-face)
+         if negative_prompt_embeds is not None:
+             B_neg = negative_prompt_embeds.shape[0]
+             neg_face = torch.zeros(B_neg, self._NUM_FACE_TOKENS, self._T5_DIM,
+                                    device=self.device, dtype=self.dtype)
+             negative_prompt_embeds = torch.cat([negative_prompt_embeds, neg_face], dim=1)
+             neg_ones = torch.ones(B_neg, self._NUM_FACE_TOKENS, device=self.device,
+                                   dtype=negative_attention_mask.dtype)
+             negative_attention_mask = torch.cat([negative_attention_mask, neg_ones], dim=1)
+
+         print(f"[IP-Adapter] face tokens appended - "
+               f"prompt_embeds: {prompt_embeds.shape}, scale={ip_scale:.2f}")
+
+         return (
+             prompt_embeds,
+             negative_prompt_embeds,
+             prompt_attention_mask,
+             negative_attention_mask,
+         )
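
For reference, the token-append step that the new encode_prompt performs can be reproduced in isolation. The sketch below is illustrative only: dummy tensors stand in for the real T5 prompt embeddings and SigLIP/resampler face tokens, and the text length of 512 is an arbitrary assumption, but the ip_scale scaling, the concatenation along the sequence dimension, the mask extension, and the zero-padded negative branch mirror the code above.

import torch

T5_DIM, NUM_FACE_TOKENS, ip_scale = 4096, 8, 0.6
L_TEXT = 512  # illustrative text-token count, not taken from the commit

# Stand-ins for real outputs (T5 prompt embeddings, proj_face face tokens).
prompt_embeds = torch.randn(1, L_TEXT, T5_DIM)
negative_embeds = torch.randn(1, L_TEXT, T5_DIM)
prompt_mask = torch.ones(1, L_TEXT, dtype=torch.long)
negative_mask = torch.ones(1, L_TEXT, dtype=torch.long)
face_tokens = torch.randn(1, NUM_FACE_TOKENS, T5_DIM)  # would come from proj_face

# Positive branch: scaled face tokens are appended, mask is extended with ones.
prompt_embeds = torch.cat([prompt_embeds, face_tokens * ip_scale], dim=1)
prompt_mask = torch.cat(
    [prompt_mask, torch.ones(1, NUM_FACE_TOKENS, dtype=prompt_mask.dtype)], dim=1)

# Negative branch: zero tokens keep the unconditional branch face-neutral under CFG.
negative_embeds = torch.cat(
    [negative_embeds, torch.zeros(1, NUM_FACE_TOKENS, T5_DIM)], dim=1)
negative_mask = torch.cat(
    [negative_mask, torch.ones(1, NUM_FACE_TOKENS, dtype=negative_mask.dtype)], dim=1)

assert prompt_embeds.shape == (1, L_TEXT + NUM_FACE_TOKENS, T5_DIM)
assert negative_embeds.shape == prompt_embeds.shape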