File size: 1,391 Bytes
# Restore unirig_skin.py to use flash_attn MHA (which will be served by our shim in run.py).
# Our earlier patch replaced flash_attn MHA with nn.MultiheadAttention, which breaks
# checkpoint loading because the weight names don't match.
# Patch script: rewrite unirig_skin.py in place so that it uses flash_attn's
# MHA again instead of nn.MultiheadAttention, then print the lines of the
# patched file that mention the attention module so the result can be checked
# by eye. Every step reports when its pattern is absent instead of silently
# doing nothing.
import subprocess

# File that is read, patched in memory, and written back.
path = '/root/UniRig/src/model/unirig_skin.py'

with open(path, encoding='utf-8') as f:
    src = f.read()

# Restore flash_attn import (remove it if replaced, add it back)
if 'from flash_attn.modules.mha import MHA' not in src:
    # The new import is anchored on the existing torch_scatter import line.
    import_anchor = 'import torch_scatter\n'
    if import_anchor in src:
        src = src.replace(
            import_anchor,
            'import torch_scatter\nfrom flash_attn.modules.mha import MHA\n',
        )
    else:
        # Without the anchor nothing is inserted, so MHA would stay unimported
        # in the target file; say so rather than claiming success silently.
        print("WARNING: 'import torch_scatter' not found; "
              "flash_attn MHA import was NOT added")

# (description, text written by the earlier patch, original text to restore).
# str.replace matches substrings, so the single leading space in each literal
# also matches lines that are indented more deeply.
# NOTE(review): confirm these literals against the target file's indentation.
restorations = [
    (
        'MHA init',
        ' self.attention = nn.MultiheadAttention(embed_dim=feat_dim, num_heads=num_heads, batch_first=True)',
        ' self.attention = MHA(embed_dim=feat_dim, num_heads=num_heads, cross_attn=True)',
    ),
    (
        'MHA forward call',
        ' attn_output, _ = self.attention(q, kv, kv)',
        ' attn_output = self.attention(q, x_kv=kv)',
    ),
]
for description, patched_text, original_text in restorations:
    if patched_text in src:
        src = src.replace(patched_text, original_text)
    else:
        # Expected on a re-run (already restored); otherwise the file layout
        # differs from what this script assumes.
        print(f'NOTE: {description}: nn.MultiheadAttention pattern not found; '
              'left unchanged')

with open(path, 'w', encoding='utf-8') as f:
    f.write(src)
print('unirig_skin.py restored to use flash_attn MHA (shim will serve it)')

# Verify: list every line that mentions the attention module or its weights.
# Raw string so that "\|" (grep alternation) reaches grep unchanged without
# being an invalid Python escape sequence.
r = subprocess.run(['grep', '-n', r'MHA\|flash_attn\|Wq\|in_proj', path],
                   capture_output=True, text=True)
print(r.stdout)
|