Fix for enable_gqa error + WAN working on CPU (AMD, 16GB RAM)

#32, opened by amorecaro

I am not a developer; I'm just sharing a working setup in case it helps others.

System:

  • AMD Ryzen 5 7640HS
  • 16GB RAM
  • CPU only (no GPU)
  • Windows

Problem:
TypeError: scaled_dot_product_attention() got an unexpected keyword argument 'enable_gqa'

Solution:
Patch torch.nn.functional.scaled_dot_product_attention so it ignores the enable_gqa argument. Do this before loading the pipeline:

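import torch.nn.functional as F

# Keep a reference to the original function
_orig_sdpa = F.scaled_dot_product_attention

def _sdpa_drop_enable_gqa(*args, **kwargs):
    # This PyTorch build rejects enable_gqa, so remove it before calling the original
    kwargs.pop("enable_gqa", None)
    return _orig_sdpa(*args, **kwargs)

# Replace the function so code that calls F.scaled_dot_product_attention uses the patched version
F.scaled_dot_product_attention = _sdpa_drop_enable_gqa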
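A slightly more defensive variant, offered as a sketch rather than a tested fix: instead of always discarding enable_gqa, it emulates grouped-query attention when a caller actually sets enable_gqa=True with fewer key/value heads than query heads, by repeating the key/value heads (the reference behaviour described in the PyTorch documentation). `make_sdpa_compat` is a hypothetical helper name, not a PyTorch API. Like the patch above, it only handles enable_gqa passed as a keyword, and it is meant for PyTorch versions that reject that argument. As far as I know enable_gqa was added in PyTorch 2.5, so upgrading PyTorch may remove the need for any patch.

```python
import functools


def make_sdpa_compat(sdpa):
    """Hypothetical helper: wrap an SDPA function so callers can pass enable_gqa
    even when the installed PyTorch does not accept that keyword."""

    @functools.wraps(sdpa)
    def wrapper(query, key, value, *args, enable_gqa=False, **kwargs):
        if enable_gqa:
            # Head dimension is third from the end: (..., heads, seq_len, head_dim)
            q_heads, kv_heads = query.size(-3), key.size(-3)
            if q_heads != kv_heads:
                if q_heads % kv_heads != 0:
                    raise ValueError(
                        f"query heads ({q_heads}) must be a multiple of "
                        f"key/value heads ({kv_heads})"
                    )
                # Emulate grouped-query attention: repeat each key/value head
                # so the head counts match the query's.
                groups = q_heads // kv_heads
                key = key.repeat_interleave(groups, dim=-3)
                value = value.repeat_interleave(groups, dim=-3)
        # enable_gqa is consumed above and never forwarded to the wrapped function.
        return sdpa(query, key, value, *args, **kwargs)

    return wrapper


# Usage (apply before importing or loading the pipeline):
# import torch.nn.functional as F
# F.scaled_dot_product_attention = make_sdpa_compat(F.scaled_dot_product_attention)
```

Applying it early matters because a module that did `from torch.nn.functional import scaled_dot_product_attention` before the patch keeps a reference to the unpatched function.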

Settings:

  • 256x256
  • 5 frames
  • 8 steps
  • guidance 3.5
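For anyone reproducing this with Hugging Face diffusers (an assumption: the post does not say which library or WAN checkpoint was used), the settings above map onto WanPipeline keyword arguments roughly as below. The check encodes my understanding that Wan expects frame counts of the form 4k + 1 and image sides divisible by 16; `CPU_SETTINGS`, `check_wan_settings` and `MODEL_ID` are hypothetical names for illustration.

```python
# The settings from this post as keyword arguments. The names follow diffusers'
# WanPipeline call signature; that library choice is an assumption.
CPU_SETTINGS = {
    "height": 256,
    "width": 256,
    "num_frames": 5,
    "num_inference_steps": 8,
    "guidance_scale": 3.5,
}


def check_wan_settings(settings):
    """Hypothetical sanity check. Assumes Wan's VAE compresses time by a factor
    of 4 (so frame counts of the form 4k + 1) and that height/width must be
    multiples of 16."""
    problems = []
    if (settings["num_frames"] - 1) % 4 != 0:
        problems.append("num_frames should be 4k + 1")
    if settings["height"] % 16 or settings["width"] % 16:
        problems.append("height and width should be multiples of 16")
    return problems


# Hypothetical usage with diffusers (MODEL_ID is a placeholder, not from the post):
# import torch
# from diffusers import WanPipeline
# pipe = WanPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to("cpu")
# frames = pipe(prompt="...", **CPU_SETTINGS).frames[0]
```

With these values `check_wan_settings(CPU_SETTINGS)` returns an empty list, since 5 = 4·1 + 1 and 256 = 16·16.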

Result:
Working and stable on CPU (tested on AMD Ryzen 5 7640HS).

Performance:
~3 minutes for 5 frames on CPU.

Just sharing in case it helps someone.
