File size: 2,771 Bytes
5122d12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Patch UniRig's run.py so that the flash_attn MHA used by unirig_skin.py resolves to a weight-compatible shim
# The checkpoint uses flash_attn MHA weight names (Wq, Wkv, out_proj)
# nn.MultiheadAttention uses in_proj_weight, which is incompatible with the saved checkpoints
# The shim reproduces flash_attn MHA's parameter names and shapes so those checkpoints load unchanged

# Source text that gets prepended to run.py. It is kept as a string (not executed
# here) because it has to run inside run.py's own process, before anything there
# imports flash_attn.
shim = '''
import importlib.machinery
import sys
import types

import torch
import torch.nn as nn
import torch.nn.functional as F

class _FlashMHACompat(nn.Module):
    """
    Drop-in for flash_attn.modules.mha.MHA.

    Reproduces flash_attn's parameter layout so checkpoints load cleanly:
    cross-attention uses Wq + Wkv, self-attention uses a fused Wqkv, and
    both use out_proj. Attention itself is computed with torch SDPA.

    Only embed_dim, num_heads, cross_attn and the two bias flags are
    honoured; any other keyword argument is accepted and ignored.
    """
    def __init__(self, embed_dim, num_heads, cross_attn=False,
                 qkv_proj_bias=True, out_proj_bias=True, **kwargs):
        super().__init__()
        # Fail at construction time instead of with an opaque view() error in forward
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.cross_attn = cross_attn
        # Weight names must match flash_attn MHA exactly
        if cross_attn:
            self.Wq = nn.Linear(embed_dim, embed_dim, bias=qkv_proj_bias)
            self.Wkv = nn.Linear(embed_dim, 2 * embed_dim, bias=qkv_proj_bias)
        else:
            # flash_attn fuses q, k and v into a single projection for self-attention
            self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=qkv_proj_bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias)

    def forward(self, x, x_kv=None):
        B, Sq, D = x.shape
        if self.cross_attn:
            q = self.Wq(x)
            # Without an explicit context, attend over x itself
            k, v = self.Wkv(x if x_kv is None else x_kv).chunk(2, dim=-1)
        else:
            # Self-attention: x_kv is ignored, q/k/v all come from x
            q, k, v = self.Wqkv(x).chunk(3, dim=-1)

        def _heads(t):
            # (B, S, D) -> (B, num_heads, S, head_dim), the layout SDPA expects
            return t.view(B, t.shape[1], self.num_heads, self.head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(_heads(q), _heads(k), _heads(v))
        # (B, num_heads, Sq, head_dim) -> (B, Sq, D)
        out = out.transpose(1, 2).contiguous().view(B, Sq, D)
        return self.out_proj(out)

def _fake_module(name, is_package):
    """Create an empty stand-in module that the import system treats as real."""
    mod = types.ModuleType(name)
    # importlib.util.find_spec() raises ValueError for a sys.modules entry whose
    # __spec__ is None, so give the stand-in a real spec
    mod.__spec__ = importlib.machinery.ModuleSpec(name, loader=None, is_package=is_package)
    if is_package:
        mod.__path__ = []
    return mod

# Inject into a fake flash_attn package so imports resolve
_fa = _fake_module("flash_attn", True)
_fa_modules = _fake_module("flash_attn.modules", True)
_fa_mha = _fake_module("flash_attn.modules.mha", False)
# Link children as attributes so flash_attn.modules.mha.MHA also works as an
# attribute chain, not only through "from ... import MHA"
_fa.modules = _fa_modules
_fa_modules.mha = _fa_mha
_fa_mha.MHA = _FlashMHACompat
sys.modules["flash_attn"] = _fa
sys.modules["flash_attn.modules"] = _fa_modules
sys.modules["flash_attn.modules.mha"] = _fa_mha
'''

# Prepend shim to run.py so it injects the fake module before any imports
import ast

path = '/root/UniRig/run.py'
with open(path, encoding='utf-8') as f:
    src = f.read()

# Re-running this script must not stack a second copy of the shim, so the shim's
# class name doubles as an "already patched" marker. To apply a modified shim,
# restore the original run.py first.
if 'class _FlashMHACompat' in src:
    print('run.py already contains the flash_attn MHA shim; left unchanged')
else:
    # `from __future__` imports must stay ahead of all other code, otherwise the
    # patched file no longer compiles; insert the shim right after the last one.
    # With no such import (or an unparsable file) this is a plain prepend.
    try:
        future_end = max(
            (node.end_lineno for node in ast.parse(src).body
             if isinstance(node, ast.ImportFrom) and node.module == '__future__'),
            default=0,
        )
    except SyntaxError:
        future_end = 0

    lines = src.split('\n')
    head = ''.join(line + '\n' for line in lines[:future_end])
    tail = '\n'.join(lines[future_end:])
    # Everything else in run.py (including an earlier add_safe_globals patch, if
    # any) is kept as-is after the shim
    src = head + shim + tail

    with open(path, 'w', encoding='utf-8') as f:
        f.write(src)

    print('run.py patched: flash_attn MHA shim injected')

# Show where unirig_skin.py mentions flash_attn / MHA / the weight names, so the
# shim's parameter names can be checked against the model code by eye
import subprocess

skin_path = '/root/UniRig/src/model/unirig_skin.py'
# One -e per literal instead of a single 'a\|b' pattern: same matches, but no
# backslash escapes in a Python string (an invalid-escape SyntaxWarning on 3.12+)
# and no reliance on grep's alternation extension.
result = subprocess.run(
    ['grep', '-n', '-e', 'flash_attn', '-e', 'MHA', '-e', 'Wq', '-e', 'Wkv', skin_path],
    capture_output=True, text=True
)
print('unirig_skin.py relevant lines:')
if result.returncode > 1:
    # grep exits 0 on a match, 1 on no match and >1 on errors such as a missing
    # file; surface the error instead of printing an empty result
    print(result.stderr.strip())
else:
    print(result.stdout[:500])