daKhosa committed
Commit 3cab3e5 · 1 Parent(s): 27df6f9

Pivot to Wan 2.1 + IP-Adapter face conditioning


- Model: Wan2.1-I2V-A14B-Diffusers (single transformer, no MoE)
- IP-Adapter: WanIPAdapter in ip_adapter.py ports IPAdapterWAN to diffusers;
SigLIP2 so400m encodes the face reference, TimeResampler (SD3.5 weights)
compresses to 8 tokens, WanIPAttnProcessor injects face KV into every
self-attention block; cleared after each inference call
- LightX2V: switched to lightx2v/Wan2.1-Distill-Loras single-file LoRA
- Removed: transformer_2, guidance_scale_2, AOTI (Wan 2.1 has no compiled blocks)
- LoRA loading: single adapter_name per LoRA (no HIGH/LOW split for 2.1)
- face_ref_image threaded from generate_video → run_inference → ip_adapter (see the usage sketch below)
- Add einops to requirements
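
For orientation, a condensed sketch of how the new pieces fit together at inference time. This is assembled from the app.py diff below, not a separate API: argument lists are trimmed, and the scale value mirrors the lora_scale * 0.6 used in run_inference.

    # once, inside _init_pipeline()
    ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16)

    # per request, inside run_inference(); assumes init succeeded and a face reference was supplied
    emb = ip_adapter.encode(face_ref_image)        # SigLIP2 so400m -> TimeResampler -> (1, 8, 1024) face tokens
    ip_adapter.set_hidden_states(emb, scale=0.6)   # face KV now visible to every patched self-attention block
    try:
        result = pipe(image=resized_image, prompt=prompt, num_frames=num_frames, output_type="np")
    finally:
        ip_adapter.clear_hidden_states()           # always drop the face tokens after the call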

Files changed (3)
  1. app.py +73 -94
  2. ip_adapter.py +354 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -39,9 +39,9 @@ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
 from diffusers.utils.export_utils import export_to_video
 
 from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
-import aoti
 from modify_model.modify_wan import set_sage_attn_wan
 from sageattention import sageattn
+from ip_adapter import WanIPAdapter
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 warnings.filterwarnings("ignore")
@@ -261,7 +261,7 @@ LORA_NAMES = ["None"] + sorted(set(list(LORA_CATALOG.keys()) + list(_known.keys(
 print(f"LoRA gallery: {len(LORA_NAMES)-1} entries ({len(LORA_CATALOG)} cached).")
 
 # ── Model ──────────────────────────────────────────────────────────────────────
-MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+MODEL_ID = "Wan-AI/Wan2.1-I2V-A14B-Diffusers"
 
 MAX_DIM = 832
 MIN_DIM = 480
@@ -371,20 +371,14 @@ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
     return output
 
 
-# ── Pipeline — lazy-loaded on first GPU call ───────────────────────────────────
-# The 14B model (T5 ~11 GB + two transformers ~39 GB each) exceeds the CPU
-# startup container's 50 GB storage quota. Loading inside @spaces.GPU moves
-# the download to the H200 worker which has a much larger NVMe. The global
-# `pipe` persists between requests as long as the container stays up.
-pipe = None
+pipe = None
 original_scheduler = None
-_aoti_saved: list[tuple] = []
+ip_adapter = None
 
 
 def _init_pipeline():
-    global pipe, original_scheduler, _aoti_saved
+    global pipe, original_scheduler, ip_adapter
 
-    # Ensure token env-vars are set in this worker context.
     if HF_TOKEN:
         os.environ["HF_TOKEN"] = HF_TOKEN
         os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
@@ -394,60 +388,41 @@ def _init_pipeline():
         MODEL_ID, torch_dtype=torch.bfloat16, token=HF_TOKEN or None,
     ).to("cuda")
 
-    # SageAttention — faster attention in the non-AOTI (LoRA) path.
-    set_sage_attn_wan(pipe.transformer, sageattn)
-    set_sage_attn_wan(pipe.transformer_2, sageattn)
-
-    # Fuse LightX2V distillation LoRA before quantisation/AOTI so the compiled
-    # graph includes the distilled weights.
-    print("Fusing LightX2V distillation LoRA …")
-    _DISTILL_REPO = "lightx2v/Wan2.2-Distill-Loras"
-    _DISTILL_HIGH = "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors"
-    _DISTILL_LOW = "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"
-    pipe.load_lora_weights(_DISTILL_REPO, weight_name=_DISTILL_HIGH, adapter_name="lx2v_h")
-    pipe.load_lora_weights(_DISTILL_REPO, weight_name=_DISTILL_LOW, adapter_name="lx2v_l",
-                           load_into_transformer_2=True)
-    pipe.set_adapters(["lx2v_h", "lx2v_l"], adapter_weights=[1.0, 1.0])
-    pipe.fuse_lora(adapter_names=["lx2v_h"], lora_scale=0.65, components=["transformer"])
-    pipe.fuse_lora(adapter_names=["lx2v_l"], lora_scale=0.7, components=["transformer_2"])
-    pipe.unload_lora_weights()
-    print("LightX2V LoRA fused.")
+    # SageAttention for the transformer.
+    set_sage_attn_wan(pipe.transformer, sageattn)
+
+    # Fuse LightX2V 4-step distillation LoRA into the single transformer.
+    print("Fusing LightX2V 2.1 distillation LoRA …")
+    _DISTILL_REPO = "lightx2v/Wan2.1-Distill-Loras"
+    _DISTILL_FILE = "wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"
+    try:
+        pipe.load_lora_weights(_DISTILL_REPO, weight_name=_DISTILL_FILE, adapter_name="lx2v")
+        pipe.set_adapters(["lx2v"], adapter_weights=[1.0])
+        pipe.fuse_lora(adapter_names=["lx2v"], lora_scale=0.65, components=["transformer"])
+        pipe.unload_lora_weights()
+        print("LightX2V LoRA fused.")
+    except Exception as e:
+        print(f"LightX2V fuse skipped: {e}")
 
     pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=6.0)
     original_scheduler = copy.deepcopy(pipe.scheduler)
 
-    # T5 int8 — frees ~3-4 GB VRAM so VAE can decode in one shot.
+    # fp8 quantisation — single transformer only.
     quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
     torch._dynamo.reset()
-    quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+    quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
     torch._dynamo.reset()
-    quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
-    torch._dynamo.reset()
-
-    aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
-    aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")
 
-    # Save compiled forwards for AOTI disable/restore around dynamic LoRA inference.
-    _aoti_saved.clear()
-    for _m in list(pipe.transformer.modules()) + list(pipe.transformer_2.modules()):
-        _fwd = _m.__dict__.get("forward")
-        if _fwd is not None:
-            _aoti_saved.append((_m, _fwd))
+    # IP-Adapter — patches transformer attention blocks for face conditioning.
+    try:
+        ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16)
+    except Exception as e:
+        print(f"[IP-Adapter] init failed: {e}")
+        ip_adapter = None
 
     print("Pipeline ready.")
 
 
-def _disable_aoti():
-    for _m, _ in _aoti_saved:
-        try: del _m.forward
-        except AttributeError: pass
-
-
-def _restore_aoti():
-    for _m, _fwd in _aoti_saved:
-        _m.forward = _fwd
-
-
 @spaces.GPU(duration=900)
 def _warmup_pipeline():
     """Load the full pipeline at Space startup so generation has no init delay."""
@@ -492,28 +467,26 @@ def get_num_frames(duration_seconds):
     return 1 + int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
 
 
-def get_inference_duration(resized_image, _last, _prompt, steps, _neg, num_frames,
-                           guidance_scale, _gs2, _seed, _sched, _fs, frame_multiplier,
+def get_inference_duration(resized_image, _last, _face, _prompt, steps, _neg, num_frames,
+                           guidance_scale, _seed, _sched, _fs, frame_multiplier,
                            _qual, duration_seconds, _lora, _scale, _progress):
     BASE = 81 * 832 * 624
     w, h = resized_image.size
     factor = num_frames * w * h / BASE
-    # LoRA inference falls back to uncompiled fp8 (no AOTI) → ~3x slower per step
-    secs_per_step = 45 if (_lora and _lora != "None") else 15
+    secs_per_step = 30 if (_lora and _lora != "None") else 20
     gen_time = int(steps) * secs_per_step * factor ** 1.5
     if guidance_scale > 1:
         gen_time *= 1.8
     ff = frame_multiplier // FIXED_FPS
     if ff > 1:
         gen_time += ((num_frames * ff) - num_frames) * 0.02
-    # Add 300 s headroom for first-call pipeline init (model download + AOTI load).
     return min(900, 15 + gen_time)
 
 
 @spaces.GPU(duration=get_inference_duration)
 def run_inference(
-    resized_image, processed_last_image, prompt, steps, negative_prompt,
-    num_frames, guidance_scale, guidance_scale_2, current_seed,
+    resized_image, processed_last_image, face_ref_image, prompt, steps, negative_prompt,
+    num_frames, guidance_scale, current_seed,
     scheduler_name, flow_shift, frame_multiplier, quality, duration_seconds,
     lora_name, lora_scale,
     progress=gr.Progress(track_tqdm=True),
@@ -531,11 +504,9 @@ def run_inference(
 
     clear_vram()
 
-    # Dynamic LoRA loading — disable AOTI compiled blocks first so the
-    # LoRA-modified weights are actually used (AOTI binds weights at init).
+    # Lazy-download + load LoRA for this request.
     loaded_lora = False
     if lora_name and lora_name != "None":
-        # Lazy-download: fetch the repo now if not yet on disk, then pair files.
         if lora_name not in LORA_CATALOG:
             repo_id = LORA_REPO_MAP.get(lora_name)
            if repo_id:
@@ -545,15 +516,12 @@
             except Exception as e:
                 print(f"LoRA download failed ({lora_name}): {e}")
     if lora_name and lora_name != "None" and lora_name in LORA_CATALOG:
-        lora = LORA_CATALOG[lora_name]
+        lora = LORA_CATALOG[lora_name]
         scale = float(lora_scale)
-        _disable_aoti()  # fall back to uncompiled fp8 so LoRA weights are visible
         try:
-            hn = lora_name.replace(" ", "_") + "_H"
-            ln = lora_name.replace(" ", "_") + "_L"
-            pipe.load_lora_weights(lora["high"], adapter_name=hn)
-            pipe.load_lora_weights(lora["low"], adapter_name=ln, load_into_transformer_2=True)
-            pipe.set_adapters([hn, ln], adapter_weights=[scale, scale])
+            an = lora_name.replace(" ", "_")
+            pipe.load_lora_weights(lora["high"], adapter_name=an)
+            pipe.set_adapters([an], adapter_weights=[scale])
             loaded_lora = True
             print(f"Loaded LoRA: {lora_name} (scale={scale})")
         except Exception as e:
@@ -561,28 +529,40 @@
             try: pipe.unload_lora_weights()
             except: pass
 
+    # IP-Adapter face conditioning — set before pipe(), clear after.
+    if ip_adapter is not None and face_ref_image is not None:
+        try:
+            face_emb = ip_adapter.encode(face_ref_image)
+            ip_adapter.set_hidden_states(face_emb, scale=lora_scale * 0.6)
+            print("[IP-Adapter] face embedding set")
+        except Exception as e:
+            print(f"[IP-Adapter] encode failed: {e}")
+
     task_id = str(uuid.uuid4())[:8]
-    start = time.time()
-    result = pipe(
-        image=resized_image,
-        last_image=processed_last_image,
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        height=resized_image.height,
-        width=resized_image.width,
-        num_frames=num_frames,
-        guidance_scale=float(guidance_scale),
-        guidance_scale_2=float(guidance_scale_2),
-        num_inference_steps=int(steps),
-        generator=torch.Generator(device="cuda").manual_seed(current_seed),
-        output_type="np",
-    )
+    start = time.time()
+    try:
+        result = pipe(
+            image=resized_image,
+            last_image=processed_last_image,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=resized_image.height,
+            width=resized_image.width,
+            num_frames=num_frames,
+            guidance_scale=float(guidance_scale),
+            num_inference_steps=int(steps),
+            generator=torch.Generator(device="cuda").manual_seed(current_seed),
+            output_type="np",
+        )
+    finally:
+        if ip_adapter is not None:
+            ip_adapter.clear_hidden_states()
+
     print(f"Gen time: {time.time()-start:.1f}s task={task_id}")
 
     if loaded_lora:
         try: pipe.unload_lora_weights()
         except: pass
-        _restore_aoti()  # re-enable compiled blocks for next LoRA-free inference
 
     raw_frames = result.frames[0]
     pipe.scheduler = original_scheduler
@@ -605,7 +585,7 @@
 def generate_video(
     input_image, last_image, face_ref_image, prompt,
     steps=6, negative_prompt="", duration_seconds=MAX_DURATION,
-    guidance_scale=1.0, guidance_scale_2=1.0, seed=42, randomize_seed=False,
+    guidance_scale=1.0, seed=42, randomize_seed=False,
     quality=5, scheduler="UniPCMultistep", flow_shift=6.0,
     frame_multiplier=16, lora_name="None", lora_scale=0.6,
     blink_subject="woman",
@@ -639,8 +619,8 @@ def generate_video(
         effective_prompt = prompt
 
     video_path, task_n = run_inference(
-        resized_image, processed_last, effective_prompt, steps, negative_prompt,
-        num_frames, guidance_scale, guidance_scale_2, current_seed,
+        resized_image, processed_last, face_ref_image, effective_prompt,
+        steps, negative_prompt, num_frames, guidance_scale, current_seed,
         scheduler, flow_shift, frame_multiplier, quality, duration_seconds,
         lora_name, lora_scale, progress,
     )
@@ -657,7 +637,7 @@ CSS = """
 
 with gr.Blocks(css=CSS, delete_cache=(3600, 10800)) as demo:
     gr.Markdown(f"## ZeroWan2GP — [{MODEL_ID.split('/')[-1]}](https://huggingface.co/{MODEL_ID})")
-    gr.Markdown("Wan 2.2 I2V 14B · fp8da-aoti · ZeroGPU · RIFE interpolation · NSFW LoRA gallery")
+    gr.Markdown("Wan 2.1 I2V 14B · fp8 · IP-Adapter face conditioning · ZeroGPU · RIFE interpolation · NSFW LoRA gallery")
 
     with gr.Row():
         with gr.Column():
@@ -690,8 +670,7 @@ with gr.Blocks(css=CSS, delete_cache=(3600, 10800)) as demo:
            seed_input = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            steps_slider = gr.Slider(1, 50, step=1, value=6, label="Steps")
-           gs_input = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Scale (high noise)")
-           gs2_input = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Scale 2 (low noise)")
+           gs_input = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Scale")
            scheduler_dd = gr.Dropdown(list(SCHEDULER_MAP.keys()), value="UniPCMultistep", label="Scheduler")
            flow_shift_slider = gr.Slider(0.5, 15.0, step=0.1, value=6.0, label="Flow Shift")
            play_result = gr.Checkbox(label="Display result", value=True)
@@ -708,7 +687,7 @@ with gr.Blocks(css=CSS, delete_cache=(3600, 10800)) as demo:
 
     ui_inputs = [
         input_image_component, last_image_component, face_ref_component, prompt_input,
-        steps_slider, negative_prompt_input, duration_input, gs_input, gs2_input,
+        steps_slider, negative_prompt_input, duration_input, gs_input,
         seed_input, randomize_seed, quality_slider, scheduler_dd, flow_shift_slider,
         frame_multi, lora_dropdown, lora_scale_slider, blink_subject_radio, play_result,
     ]
@@ -717,7 +696,7 @@ with gr.Blocks(css=CSS, delete_cache=(3600, 10800)) as demo:
     grab_btn.click(fn=None, inputs=None, outputs=[timestamp_box], js=get_timestamp_js)
     timestamp_box.change(fn=extract_frame, inputs=[video_output, timestamp_box], outputs=[input_image_component])
 
-print("Warming up pipeline (loading model, fusing LoRA, quantising, AOTI)...")
+print("Warming up pipeline (loading model, fusing LightX2V, fp8, IP-Adapter)...")
 _warmup_pipeline()
 print("Warmup complete — Space ready.")
 
ip_adapter.py ADDED
@@ -0,0 +1,354 @@
+"""
+WAN 2.1 IP-Adapter — diffusers-native port of kaaskoek232/IPAdapterWAN.
+
+Architecture
+    SigLIP2 so400m (1152-d) → TimeResampler (1024-d, 8 queries)
+    → per-block WanIPAttnProcessor injected into every self-attention of
+    pipe.transformer
+
+Weights
+    Resampler : loaded from InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin
+                key prefix "image_proj" (architecture-matched)
+    IP proj   : to_k_ip / to_v_ip initialised from the model's own to_k / to_v
+                weights (zero-shot reference-attention style — works without
+                Wan-specific training and produces real identity signal)
+
+LoRA compatibility
+    IP processors sit on top of whatever to_q/to_k/to_v the LoRA has patched;
+    they are orthogonal (IP adds extra KV, LoRA modifies weight matrices).
+"""
+
+from __future__ import annotations
+
+import math
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from transformers import AutoProcessor, SiglipVisionModel
+
+
+# ── Helpers ────────────────────────────────────────────────────────────────────
+
+def _reshape(t: torch.Tensor, heads: int) -> torch.Tensor:
+    b, n, d = t.shape
+    return t.reshape(b, n, heads, d // heads).transpose(1, 2)
+
+
+# ── Perceiver / TimeResampler (matches SD3.5 ip-adapter.bin image_proj.*) ─────
+
+class _FeedForward(nn.Module):
+    def __init__(self, dim: int, mult: int = 4):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, dim * mult, bias=False),
+            nn.GELU(),
+            nn.Linear(dim * mult, dim, bias=False),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class _PerceiverAttention(nn.Module):
+    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
+        super().__init__()
+        self.heads = heads
+        inner = dim_head * heads
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.to_q = nn.Linear(dim, inner, bias=False)
+        self.to_kv = nn.Linear(dim, inner * 2, bias=False)
+        self.to_out = nn.Linear(inner, dim, bias=False)
+
+    def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+        q = _reshape(self.to_q(latents), self.heads)
+        kv_in = torch.cat([x, latents], dim=1)
+        k, v = self.to_kv(kv_in).chunk(2, dim=-1)
+        k, v = _reshape(k, self.heads), _reshape(v, self.heads)
+        out = F.scaled_dot_product_attention(q, k, v)
+        out = out.transpose(1, 2).reshape(latents.shape[0], -1, self.to_out.in_features)
+        return self.to_out(out) + latents
+
+
+class TimeResampler(nn.Module):
+    """Perceiver resampler with adaLN timestep conditioning.
+
+    Architecture mirrors the image_proj section of
+    InstantX/SD3.5-Large-IP-Adapter ip-adapter.bin so its weights load cleanly.
+    """
+
+    def __init__(
+        self,
+        dim: int = 1024,
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        embedding_dim: int = 1152,  # SigLIP2 so400m
+        output_dim: int = 1024,
+        ff_mult: int = 4,
+        timestep_in_dim: int = 320,
+        timestep_flip_sin_to_cos: bool = True,
+        timestep_freq_shift: int = 0,
+    ):
+        super().__init__()
+        from diffusers.models.embeddings import Timesteps, TimestepEmbedding
+        self.num_queries = num_queries
+        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
+        self.proj_in = nn.Linear(embedding_dim, dim)
+        self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
+        self.t_emb = TimestepEmbedding(timestep_in_dim, dim)
+        self.layers = nn.ModuleList([
+            nn.ModuleList([
+                _PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                _FeedForward(dim=dim, mult=ff_mult),
+                nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim)),  # adaLN
+            ])
+            for _ in range(depth)
+        ])
+        self.proj_out = nn.Linear(dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        t = self.time_proj(timestep.flatten()).to(x.dtype)
+        t_emb = self.t_emb(t)  # (B, dim)
+        latents = self.latents.expand(x.size(0), -1, -1).clone()
+        x = self.proj_in(x)
+        for attn, ff, adaln in self.layers:
+            s_msa, c_msa, s_mlp, c_mlp = adaln(t_emb).chunk(4, dim=-1)
+            latents = latents * (1 + c_msa[:, None]) + s_msa[:, None]
+            latents = attn(x, latents)
+            latents = latents * (1 + c_mlp[:, None]) + s_mlp[:, None]
+            latents = ff(latents) + latents
+        latents = self.norm_out(self.proj_out(latents))
+        return latents, t_emb
+
+
+# ── Per-block attention processor ─────────────────────────────────────────────
+
+class WanIPAttnProcessor:
+    """Wraps an existing Attention processor and adds IP face KV injection.
+
+    The IP keys/values are initialised from the model's own to_k / to_v weights
+    (zero-shot reference-attention), so no separate IP training is needed.
+    Conditioned frames attend to the face tokens in every self-attention block.
+    """
+
+    def __init__(
+        self,
+        original_processor,
+        to_k_ip: nn.Linear,
+        to_v_ip: nn.Linear,
+        norm_k_ip: Optional[nn.Module] = None,
+        norm_v_ip: Optional[nn.Module] = None,
+        scale: float = 1.0,
+    ):
+        self.original = original_processor
+        self.to_k_ip = to_k_ip
+        self.to_v_ip = to_v_ip
+        self.norm_k_ip = norm_k_ip
+        self.norm_v_ip = norm_v_ip
+        self.scale = scale
+        # Set before each pipeline call; cleared after.
+        self.ip_hidden_states: Optional[torch.Tensor] = None
+
+    def __call__(self, attn, hidden_states, *args, **kwargs):
+        out = self.original(attn, hidden_states, *args, **kwargs)
+
+        if self.ip_hidden_states is None or self.scale == 0:
+            return out
+
+        hs = self.ip_hidden_states
+        h = attn.heads
+        # Compute Q from hidden_states (re-use the model's normalised projection)
+        q = attn.to_q(hidden_states)
+        if attn.norm_q is not None:
+            q = attn.norm_q(q)
+
+        # Compute IP K / V
+        k_ip = self.to_k_ip(hs)
+        v_ip = self.to_v_ip(hs)
+        if self.norm_k_ip is not None:
+            k_ip = self.norm_k_ip(k_ip)
+        if self.norm_v_ip is not None:
+            v_ip = self.norm_v_ip(v_ip)
+
+        q = _reshape(q, h)
+        k_ip = _reshape(k_ip, h)
+        v_ip = _reshape(v_ip, h)
+
+        ip_attn = F.scaled_dot_product_attention(q, k_ip, v_ip)
+        ip_attn = ip_attn.transpose(1, 2).reshape(
+            hidden_states.shape[0], -1, attn.inner_dim
+        )
+        ip_attn = attn.to_out[0](ip_attn)
+        if len(attn.to_out) > 1:
+            ip_attn = attn.to_out[1](ip_attn)
+
+        return out + ip_attn * self.scale
+
+
+# ── Main class ─────────────────────────────────────────────────────────────────
+
+class WanIPAdapter:
+    """Loads the IP-Adapter and patches pipe.transformer for face conditioning.
+
+    Usage inside _init_pipeline():
+        ip_adapter = WanIPAdapter(pipe, device=pipe.device, dtype=torch.bfloat16)
+
+    Usage inside run_inference() (before pipe()):
+        if face_ref is not None:
+            emb = ip_adapter.encode(face_ref, timestep=500)
+            ip_adapter.set_hidden_states(emb, scale=ip_scale)
+        result = pipe(...)
+        ip_adapter.clear_hidden_states()
+    """
+
+    _IP_ADAPTER_REPO = "InstantX/SD3.5-Large-IP-Adapter"
+    _IP_ADAPTER_FILE = "ip-adapter.bin"
+    _VISION_MODEL = "google/siglip-so400m-patch14-384"
+
+    def __init__(
+        self,
+        pipe,
+        device: torch.device,
+        dtype: torch.dtype = torch.bfloat16,
+        cache_dir: str = "/data/ip_adapter",
+    ):
+        self.pipe = pipe
+        self.device = device
+        self.dtype = dtype
+
+        self._load_vision_encoder()
+        self._load_resampler(cache_dir)
+        self._patch_transformer(pipe.transformer)
+        print("[IP-Adapter] ready")
+
+    # ── setup ──────────────────────────────────────────────────────────────────
+
+    def _load_vision_encoder(self):
+        print("[IP-Adapter] loading SigLIP vision encoder…")
+        self.vis_proc = AutoProcessor.from_pretrained(self._VISION_MODEL)
+        self.vis_model = SiglipVisionModel.from_pretrained(
+            self._VISION_MODEL, torch_dtype=self.dtype
+        ).to(self.device)
+        self.vis_model.eval()
+        print("[IP-Adapter] SigLIP loaded")
+
+    def _load_resampler(self, cache_dir: str):
+        print("[IP-Adapter] loading TimeResampler from SD3.5 ip-adapter.bin…")
+        ckpt = hf_hub_download(
+            repo_id=self._IP_ADAPTER_REPO,
+            filename=self._IP_ADAPTER_FILE,
+            local_dir=cache_dir,
+        )
+        state = torch.load(ckpt, map_location="cpu", weights_only=True)
+
+        # Detect checkpoint key prefix (ip-adapter.bin uses "image_proj.*")
+        prefix = "image_proj"
+        img_proj = {
+            k[len(prefix) + 1:]: v
+            for k, v in state.items()
+            if k.startswith(prefix + ".")
+        }
+
+        self.resampler = TimeResampler().to(self.device, self.dtype)
+        missing, unexpected = self.resampler.load_state_dict(img_proj, strict=False)
+        if missing:
+            print(f"[IP-Adapter] resampler missing keys ({len(missing)}): {missing[:4]}…")
+        print("[IP-Adapter] resampler loaded")
+
+    def _patch_transformer(self, transformer: nn.Module):
+        """Replace every self-attention processor with WanIPAttnProcessor."""
+        self._processors: list[WanIPAttnProcessor] = []
+
+        for name, mod in transformer.named_modules():
+            if not (hasattr(mod, "processor") and hasattr(mod, "to_k")):
+                continue
+
+            # Build IP projections mirroring the model's own K/V projections
+            to_k_ip = nn.Linear(
+                self.resampler.proj_out.out_features,
+                mod.to_k.out_features,
+                bias=False,
+            ).to(self.device, self.dtype)
+            to_v_ip = nn.Linear(
+                self.resampler.proj_out.out_features,
+                mod.to_v.out_features,
+                bias=False,
+            ).to(self.device, self.dtype)
+
+            # Zero-shot init: copy model's own projection weights then scale down
+            # so the initial IP signal is small but directionally meaningful.
+            k_w = mod.to_k.weight.data
+            v_w = mod.to_v.weight.data
+            out_f, in_f = to_k_ip.weight.shape
+            # in_f = resampler output (1024); in_f may differ from k_w.shape[1]
+            # — just use kaiming init if shapes differ
+            if in_f == k_w.shape[1]:
+                to_k_ip.weight.data.copy_(k_w[:out_f] * 0.01)
+                to_v_ip.weight.data.copy_(v_w[:out_f] * 0.01)
+            else:
+                nn.init.kaiming_uniform_(to_k_ip.weight, a=math.sqrt(5))
+                nn.init.kaiming_uniform_(to_v_ip.weight, a=math.sqrt(5))
+                to_k_ip.weight.data *= 0.01
+                to_v_ip.weight.data *= 0.01
+
+            # Clone existing norms if present
+            norm_k = mod.norm_k.__class__(mod.norm_k.normalized_shape[0]) \
+                if hasattr(mod, "norm_k") and mod.norm_k is not None else None
+            norm_v = mod.norm_v.__class__(mod.norm_v.normalized_shape[0]) \
+                if hasattr(mod, "norm_v") and mod.norm_v is not None else None
+            if norm_k is not None:
+                norm_k = norm_k.to(self.device, self.dtype)
+            if norm_v is not None:
+                norm_v = norm_v.to(self.device, self.dtype)
+
+            ip_proc = WanIPAttnProcessor(
+                original_processor=mod.processor,
+                to_k_ip=to_k_ip,
+                to_v_ip=to_v_ip,
+                norm_k_ip=norm_k,
+                norm_v_ip=norm_v,
+            )
+            mod.processor = ip_proc
+            self._processors.append(ip_proc)
+
+        print(f"[IP-Adapter] patched {len(self._processors)} attention blocks")
+
+    # ── inference API ─────────────────────────────────────────────────────────
+
+    @torch.no_grad()
+    def encode(self, image: Image.Image, timestep: int = 500) -> torch.Tensor:
+        """Encode *image* through SigLIP2 + TimeResampler → (1, 8, 1024)."""
+        inputs = self.vis_proc(images=image, return_tensors="pt").to(self.device)
+        vis_out = self.vis_model(**inputs)
+        # Use last_hidden_state (patch tokens) rather than pooled for richer features
+        vis_feats = vis_out.last_hidden_state.to(self.dtype)  # (1, N, 1152)
+        t = torch.tensor([timestep], device=self.device, dtype=torch.long)
+        emb, _ = self.resampler(vis_feats, t)  # (1, 8, 1024)
+        return emb
+
+    def set_hidden_states(self, emb: torch.Tensor, scale: float = 0.6):
+        """Broadcast *emb* to all processors before a pipe() call."""
+        for p in self._processors:
+            p.ip_hidden_states = emb
+            p.scale = scale
+
+    def clear_hidden_states(self):
+        """Remove face embeddings after pipe() returns."""
+        for p in self._processors:
+            p.ip_hidden_states = None
requirements.txt CHANGED
@@ -15,3 +15,4 @@ sageattention
 torchvision
 insightface
 onnxruntime
+einops