"""Restore unirig_skin.py to use flash_attn's ``MHA`` attention module.

An earlier patch replaced flash_attn's ``MHA`` with ``nn.MultiheadAttention``.
That breaks checkpoint loading because the parameter names do not match.
This script reverts that patch. The ``flash_attn`` import is expected to be
served at runtime by a shim in run.py.

Running the script rewrites ``PATH`` in place, prints a status line, then
prints every line of the result mentioning MHA / flash_attn / Wq / in_proj
(``N:line`` format) for a quick visual check.
"""
import re
import sys

PATH = '/root/UniRig/src/model/unirig_skin.py'

MHA_IMPORT = 'from flash_attn.modules.mha import MHA'

# A real (non-commented) import of MHA from flash_attn.
# A plain substring test would be fooled by a commented-out
# "# from flash_attn.modules.mha import MHA" line.
_MHA_IMPORT_RE = re.compile(
    r'^[ \t]*from flash_attn\.modules\.mha import .*\bMHA\b', re.MULTILINE
)

# The new import is inserted right after this line.
# Anchored to the start of a line so "from x import torch_scatter" is not matched.
_ANCHOR_RE = re.compile(r'^import torch_scatter\n', re.MULTILINE)

# (patched form, original form) pairs. Patterns omit leading indentation,
# so they match regardless of how deeply the statement is indented;
# the surrounding whitespace is left untouched.
_REVERSIONS = (
    (
        'self.attention = nn.MultiheadAttention('
        'embed_dim=feat_dim, num_heads=num_heads, batch_first=True)',
        'self.attention = MHA('
        'embed_dim=feat_dim, num_heads=num_heads, cross_attn=True)',
    ),
    (
        'attn_output, _ = self.attention(q, kv, kv)',
        'attn_output = self.attention(q, x_kv=kv)',
    ),
)

# Same alternation the original passed to `grep` (BRE `\|` == regex `|`).
_VERIFY_RE = re.compile(r'MHA|flash_attn|Wq|in_proj')


def restore_source(src):
    """Undo the nn.MultiheadAttention patch in *src*.

    Args:
        src: Full text of unirig_skin.py.

    Returns:
        ``(restored_src, warnings)``. ``warnings`` lists every edit that
        could not be applied because neither its patched nor its original
        form was found. An empty list means the text is fully restored.
        Running this on already-restored text returns it unchanged with no
        warnings.
    """
    warnings = []

    if not _MHA_IMPORT_RE.search(src):
        src, n = _ANCHOR_RE.subn(
            lambda m: m.group(0) + MHA_IMPORT + '\n', src, count=1
        )
        if not n:
            warnings.append(
                f'could not add {MHA_IMPORT!r}: no "import torch_scatter" line found'
            )

    for patched, original in _REVERSIONS:
        if patched in src:
            src = src.replace(patched, original)
        elif original not in src:
            # Neither form present: the file differs from what we expect.
            warnings.append(f'pattern not found in either form: {original!r}')

    return src, warnings


def main(path=PATH):
    """Restore *path* in place, then print the lines relevant for verification."""
    with open(path, encoding='utf-8') as f:
        src = f.read()

    restored, warnings = restore_source(src)

    with open(path, 'w', encoding='utf-8') as f:
        f.write(restored)

    for warning in warnings:
        print(f'WARNING: {warning}', file=sys.stderr)
    if warnings:
        print('unirig_skin.py written, but some edits were not applied (see warnings)')
    else:
        print('unirig_skin.py restored to use flash_attn MHA (shim will serve it)')

    # Verify: scan the text just written (identical to the file contents).
    for lineno, line in enumerate(restored.splitlines(), start=1):
        if _VERIFY_RE.search(line):
            print(f'{lineno}:{line}')


if __name__ == '__main__':
    main()