# Patch UniRig so src/model/unirig_skin.py can run without flash_attn installed:
# the `shim` source below is injected into run.py, where it registers a fake
# flash_attn.modules.mha.MHA before unirig_skin.py imports it.
# The checkpoint uses flash_attn MHA parameter names (Wq, Wkv, out_proj for
# cross-attention; flash_attn uses a fused Wqkv for self-attention).
# nn.MultiheadAttention stores in_proj_weight instead and cannot load them,
# so the shim reproduces flash_attn's parameter layout.
# NOTE: the text is written verbatim into run.py, so its indentation matters.
shim = '''
import sys
import types

import torch
import torch.nn as nn
import torch.nn.functional as F


class _FlashMHACompat(nn.Module):
    """Drop-in replacement for flash_attn.modules.mha.MHA.

    Parameter names follow flash_attn so its checkpoints load unchanged:
    self-attention uses one fused ``Wqkv`` projection, cross-attention uses
    separate ``Wq`` / ``Wkv`` projections, and both end in ``out_proj``.
    Attention is computed with torch's scaled_dot_product_attention.
    """

    def __init__(self, embed_dim, num_heads, cross_attn=False, *,
                 num_heads_kv=None, qkv_proj_bias=True, out_proj_bias=True,
                 softmax_scale=None, causal=False, rotary_emb_dim=0,
                 device=None, dtype=None, **kwargs):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        # Options that change the attention result but are not implemented
        # here are rejected rather than silently ignored.
        if num_heads_kv not in (None, num_heads):
            raise NotImplementedError("num_heads_kv != num_heads (MQA/GQA) is not supported")
        if causal:
            raise NotImplementedError("causal attention is not supported")
        if rotary_emb_dim:
            raise NotImplementedError("rotary embeddings are not supported")
        # Any other flash_attn option (dropout, use_flash_attn, ...) lands in
        # **kwargs and is ignored; in particular dropout is never applied.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.cross_attn = cross_attn
        # None -> SDPA's default scale of 1/sqrt(head_dim).
        self.softmax_scale = softmax_scale
        factory = {"device": device, "dtype": dtype}
        # Weight names must match flash_attn MHA exactly.
        if cross_attn:
            self.Wq = nn.Linear(embed_dim, embed_dim, bias=qkv_proj_bias, **factory)
            self.Wkv = nn.Linear(embed_dim, 2 * embed_dim, bias=qkv_proj_bias, **factory)
        else:
            self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=qkv_proj_bias, **factory)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias, **factory)

    def _split_heads(self, t):
        """Reshape (batch, seq, embed_dim) into (batch, heads, seq, head_dim)."""
        return t.reshape(t.shape[0], t.shape[1], self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, x_kv=None, key_padding_mask=None, **kwargs):
        """Attend from ``x`` to ``x_kv`` (cross-attention) or to ``x`` itself.

        ``x`` is unpacked as (batch, seq, embed_dim). ``key_padding_mask`` is
        an optional (batch, seq_kv) tensor where True marks keys that may be
        attended to. Other flash_attn keyword arguments are accepted and ignored.
        """
        batch, seq_q, _ = x.shape
        if self.cross_attn:
            q = self.Wq(x)
            k, v = self.Wkv(x if x_kv is None else x_kv).chunk(2, dim=-1)
        else:
            if x_kv is not None:
                raise ValueError("x_kv requires cross_attn=True")
            q, k, v = self.Wqkv(x).chunk(3, dim=-1)
        q, k, v = self._split_heads(q), self._split_heads(k), self._split_heads(v)
        attn_mask = None
        if key_padding_mask is not None:
            # Broadcast over heads and query positions: (batch, 1, 1, seq_kv).
            attn_mask = key_padding_mask.to(torch.bool)[:, None, None, :]
        # The `scale` argument needs torch >= 2.1, so pass it only when set.
        sdpa_kwargs = {} if self.softmax_scale is None else {"scale": self.softmax_scale}
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, **sdpa_kwargs)
        out = out.transpose(1, 2).reshape(batch, seq_q, self.embed_dim)
        return self.out_proj(out)


# Register a fake flash_attn package so `from flash_attn.modules.mha import MHA`
# resolves to the shim.
_fa = types.ModuleType("flash_attn")
_fa_modules = types.ModuleType("flash_attn.modules")
_fa_mha = types.ModuleType("flash_attn.modules.mha")
# Mark the parents as packages so unknown submodules raise ModuleNotFoundError.
_fa.__path__ = []
_fa_modules.__path__ = []
_fa_mha.MHA = _FlashMHACompat
# Attribute links make `import flash_attn.modules.mha` + dotted access work too.
_fa.modules = _fa_modules
_fa_modules.mha = _fa_mha
sys.modules["flash_attn"] = _fa
sys.modules["flash_attn.modules"] = _fa_modules
sys.modules["flash_attn.modules.mha"] = _fa_mha
'''
# Inject the shim into run.py so the fake flash_attn package is registered
# before UniRig's own imports run. The injected text is wrapped in marker
# comments so re-running this script replaces the previous copy instead of
# stacking another one on top of it.
import ast

_SHIM_BEGIN = '# >>> flash_attn MHA shim (auto-inserted, do not edit) >>>\n'
_SHIM_END = '# <<< flash_attn MHA shim (auto-inserted, do not edit) <<<\n'


def _strip_previous_shim(source, shim_text):
    """Return *source* without a shim inserted by an earlier run, if any."""
    start = source.find(_SHIM_BEGIN)
    if start != -1:
        end = source.find(_SHIM_END, start)
        if end != -1:
            return source[:start] + source[end + len(_SHIM_END):]
    # Earlier versions of this script prepended the shim without markers.
    return source.replace(shim_text, '', 1)


def _insertion_offset(source):
    """Return the offset just past a leading module docstring / __future__ imports.

    `from __future__` imports must precede every other statement, so the shim
    has to go after them. Falls back to 0 (the top) when *source* does not parse.
    """
    try:
        body = ast.parse(source).body
    except SyntaxError:
        return 0
    last_line = 0
    for index, node in enumerate(body):
        is_docstring = (
            index == 0
            and isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
        )
        is_future = isinstance(node, ast.ImportFrom) and node.module == '__future__'
        if not (is_docstring or is_future):
            break
        last_line = node.end_lineno
    offset = 0
    for _ in range(last_line):
        newline = source.find('\n', offset)
        if newline == -1:  # last statement has no trailing newline
            return len(source)
        offset = newline + 1
    return offset


path = '/root/UniRig/run.py'
with open(path, encoding='utf-8') as f:
    src = f.read()

src = _strip_previous_shim(src, shim)
offset = _insertion_offset(src)
head, tail = src[:offset], src[offset:]
if head and not head.endswith('\n'):
    head += '\n'
src = head + _SHIM_BEGIN + shim.strip('\n') + '\n' + _SHIM_END + tail

with open(path, 'w', encoding='utf-8') as f:
    f.write(src)
print('run.py patched: flash_attn MHA shim injected')
# Sanity check: print the lines of unirig_skin.py that mention flash_attn / MHA /
# Wq / Wkv so the parameter names it expects can be compared with the shim by eye.
import subprocess

result = subprocess.run(
    # Raw string: `\|` is GNU grep's basic-regex alternation, not a Python escape.
    ['grep', '-n', r'flash_attn\|MHA\|Wq\|Wkv', '/root/UniRig/src/model/unirig_skin.py'],
    capture_output=True,
    text=True,
)
print('unirig_skin.py relevant lines:')
print(result.stdout[:500])
# grep exits 0 on matches, 1 on no matches, and >1 on errors (e.g. missing file).
if result.returncode > 1:
    print(f'grep failed (exit {result.returncode}): {result.stderr.strip()}')