Daankular commited on
Commit
b44f5ff
Β·
verified Β·
1 Parent(s): 56b5a11

Upload pipeline/rig_stage.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline/rig_stage.py +1282 -0
pipeline/rig_stage.py ADDED
@@ -0,0 +1,1282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stage 7 β€” Multi-view pose estimation + mesh rigging
3
+
4
+ Three progressive phases, each feeding the next:
5
+
6
+ Phase 1 (Easy) β€” Multi-view beta averaging
7
+ Run HMR 2.0 on front / 3q_front / side renders + reference photo
8
+ Average shape betas weighted by detection confidence
9
+
10
+ Phase 2 (Better) β€” Silhouette fitting
11
+ Project SMPL mesh orthographically into each of the 5 views
12
+ Optimise betas so the SMPL silhouette matches the TripoSG render mask
13
+ Uses known orthographic camera matrices (exact same params as nvdiffrast)
14
+
15
+ Phase 3 (Best) β€” Multi-view joint triangulation
16
+ For each view where HMR 2.0 fired, project its 2D keypoints back to 3D
17
+ using the known orthographic camera β†’ set up linear system per joint
18
+ Least-squares triangulation gives world-space joint positions used
19
+ directly as the skeleton, overriding the regressed SMPL joints
20
+
21
+ Output: rigged GLB (SMPL 24-joint skeleton + skin weights) + FBX via Blender
22
+ """
23
+
24
+ import os, sys, json, struct, traceback, subprocess, tempfile
25
+ # Must be set before any OpenGL/pyrender import (triggered by hmr2)
26
+ os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
27
+ import numpy as np
28
+
29
+ # ── SMPL constants ────────────────────────────────────────────────────────────
30
+ SMPL_JOINT_NAMES = [
31
+ "pelvis","left_hip","right_hip","spine1",
32
+ "left_knee","right_knee","spine2",
33
+ "left_ankle","right_ankle","spine3",
34
+ "left_foot","right_foot","neck",
35
+ "left_collar","right_collar","head",
36
+ "left_shoulder","right_shoulder",
37
+ "left_elbow","right_elbow",
38
+ "left_wrist","right_wrist",
39
+ "left_hand","right_hand",
40
+ ]
41
+ SMPL_PARENTS = [-1,0,0,0,1,2,3,4,5,6,7,8,9,9,9,
42
+ 12,13,14,16,17,18,19,20,21]
43
+
44
+ # Orthographic camera parameters β€” must match render_views in triposg_app.py
45
+ ORTHO_LEFT, ORTHO_RIGHT = -0.55, 0.55
46
+ ORTHO_BOT, ORTHO_TOP = -0.55, 0.55
47
+ RENDER_W, RENDER_H = 768, 1024
48
+
49
+ # Azimuths passed to get_orthogonal_camera: [x-90 for x in [0,45,90,180,315]]
50
+ VIEW_AZIMUTHS_DEG = [-90.0, -45.0, 0.0, 90.0, 225.0]
51
+ VIEW_NAMES = ["front", "3q_front", "side", "back", "3q_back"]
52
+ VIEW_PATHS = [f"/tmp/render_{n}.png" for n in VIEW_NAMES]
53
+
54
+ # Views with a clearly visible front body (used for Phase 1 beta averaging)
55
+ FRONT_VIEW_INDICES = [0, 1, 2] # front, 3q_front, side
56
+
57
+
58
+ # ══════════════════════════════════════════════════════════════════════════════
59
+ # Camera utilities
60
+ # ══════════════════════════════════════════════════════════════════════════════
61
+
62
+ def _R_y(deg: float) -> np.ndarray:
63
+ """Rotation matrix around Y axis (right-hand, degrees)."""
64
+ t = np.radians(deg)
65
+ c, s = np.cos(t), np.sin(t)
66
+ return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]], dtype=np.float64)
67
+
68
+
69
+ def world_to_cam(pts: np.ndarray, azimuth_deg: float) -> np.ndarray:
70
+ """
71
+ Orthographic projection: world (N,3) β†’ camera (N,2) in world-unit space.
72
+ Convention: camera right = (cos ΞΈ, 0, -sin ΞΈ), up = (0,1,0)
73
+ """
74
+ t = np.radians(azimuth_deg)
75
+ right = np.array([np.cos(t), 0.0, -np.sin(t)])
76
+ up = np.array([0.0, 1.0, 0.0 ])
77
+ return np.stack([pts @ right, pts @ up], axis=-1) # (N, 2)
78
+
79
+
80
+ def cam_to_pixel(cam_xy: np.ndarray) -> np.ndarray:
81
+ """Camera world-unit coords β†’ pixel coords (u, v) in 768Γ—1024 image."""
82
+ u = (cam_xy[:, 0] - ORTHO_LEFT) / (ORTHO_RIGHT - ORTHO_LEFT) * RENDER_W
83
+ v = (ORTHO_TOP - cam_xy[:, 1]) / (ORTHO_TOP - ORTHO_BOT ) * RENDER_H
84
+ return np.stack([u, v], axis=-1)
85
+
86
+
87
+ def pixel_to_cam(uv: np.ndarray) -> np.ndarray:
88
+ """Pixel coords β†’ camera world-unit coords."""
89
+ cx = uv[:, 0] / RENDER_W * (ORTHO_RIGHT - ORTHO_LEFT) + ORTHO_LEFT
90
+ cy = ORTHO_TOP - uv[:, 1] / RENDER_H * (ORTHO_TOP - ORTHO_BOT)
91
+ return np.stack([cx, cy], axis=-1)
92
+
93
+
94
+ def triangulate_joint(obs: list[tuple]) -> np.ndarray:
95
+ """
96
+ Triangulate a single joint from multi-view 2D observations.
97
+ obs: list of (azimuth_deg, pixel_u, pixel_v)
98
+ Returns world (x, y, z).
99
+
100
+ For orthographic cameras, Y is directly measured; X and Z satisfy:
101
+ px*cos(ΞΈ) - pz*sin(ΞΈ) = cx for each view
102
+ β†’ overdetermined linear system solved with lstsq.
103
+ """
104
+ ys, rows_A, rhs = [], [], []
105
+ for az_deg, pu, pv in obs:
106
+ cx, cy = pixel_to_cam(np.array([[pu, pv]]))[0]
107
+ ys.append(cy)
108
+ t = np.radians(az_deg)
109
+ rows_A.append([np.cos(t), -np.sin(t)])
110
+ rhs.append(cx)
111
+
112
+ A = np.array(rows_A, dtype=np.float64)
113
+ b = np.array(rhs, dtype=np.float64)
114
+ wy = float(np.mean(ys))
115
+
116
+ if len(obs) >= 2:
117
+ xz, _, _, _ = np.linalg.lstsq(A, b, rcond=None)
118
+ wx, wz = xz
119
+ else:
120
+ wx, wz = 0.0, 0.0
121
+
122
+ return np.array([wx, wy, wz], dtype=np.float32)
123
+
124
+
125
+ # ══════════════════════════════════════════════════════════════════════════════
126
+ # Phase 1 β€” Multi-view HMR 2.0 + beta averaging
127
+ # ══════════════════════════════════════════════════════════════════════════════
128
+
129
+ def _load_hmr2(device):
130
+ from hmr2.models import download_models, load_hmr2, DEFAULT_CHECKPOINT
131
+ download_models() # downloads to CACHE_DIR_4DHUMANS (no-op if already done)
132
+ model, cfg = load_hmr2(DEFAULT_CHECKPOINT)
133
+ return model.to(device).eval(), cfg
134
+
135
+
136
+ def _load_detector():
137
+ from detectron2.config import LazyConfig
138
+ from hmr2.utils.utils_detectron2 import DefaultPredictor_Lazy
139
+ import hmr2
140
+ cfg = LazyConfig.load(str(os.path.join(
141
+ os.path.dirname(hmr2.__file__),
142
+ "configs/cascade_mask_rcnn_vitdet_h_75ep.py")))
143
+ cfg.train.init_checkpoint = (
144
+ "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/"
145
+ "cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl")
146
+ for i in range(3):
147
+ cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25
148
+ return DefaultPredictor_Lazy(cfg)
149
+
150
+
151
+ def _run_hmr2_on_image(img_bgr, model, model_cfg, detector, device):
152
+ """
153
+ Run HMR 2.0 on a BGR image. Returns dict or None.
154
+ Keys: betas (10,), body_pose (23,3,3), global_orient (1,3,3),
155
+ kp2d (44,2) in [0,1] normalised, kp3d (44,3), score (float)
156
+ """
157
+ import torch
158
+ from hmr2.utils import recursive_to
159
+ from hmr2.datasets.vitdet_dataset import ViTDetDataset
160
+
161
+ det_out = detector(img_bgr)
162
+ instances = det_out["instances"]
163
+ valid = (instances.pred_classes == 0) & (instances.scores > 0.5)
164
+ if not valid.any():
165
+ return None
166
+
167
+ boxes = instances.pred_boxes.tensor[valid].cpu().numpy()
168
+ score = float(instances.scores[valid].max().cpu())
169
+ best = boxes[np.argmax((boxes[:,2]-boxes[:,0]) * (boxes[:,3]-boxes[:,1]))]
170
+
171
+ ds = ViTDetDataset(model_cfg, img_bgr, [best])
172
+ dl = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False)
173
+ batch = recursive_to(next(iter(dl)), device)
174
+
175
+ with torch.no_grad():
176
+ out = model(batch)
177
+
178
+ p = out["pred_smpl_params"]
179
+ return {
180
+ "betas": p["betas"][0].cpu().numpy(),
181
+ "body_pose": p["body_pose"][0].cpu().numpy(),
182
+ "global_orient": p["global_orient"][0].cpu().numpy(),
183
+ "kp2d": out["pred_keypoints_2d"][0].cpu().numpy(), # (44,2) [-1,1]
184
+ "kp3d": out.get("pred_keypoints_3d", [None]*1)[0],
185
+ "score": score,
186
+ "detected": True,
187
+ }
188
+
189
+
190
+ def estimate_betas_multiview(view_paths: list[str],
191
+ ref_path: str,
192
+ device: str = "cuda") -> tuple[np.ndarray, list]:
193
+ """
194
+ Phase 1: run HMR 2.0 on reference photo + front/3q/side renders.
195
+ Returns (averaged_betas [10,], list_of_all_results).
196
+ Falls back to zero betas (average body shape) if HMR2 is unavailable.
197
+ """
198
+ import cv2
199
+ print("[rig P1] Loading HMR2 + detector...")
200
+ try:
201
+ model, model_cfg = _load_hmr2(device)
202
+ detector = _load_detector()
203
+ except Exception as e:
204
+ print(f"[rig P1] HMR2 unavailable ({e}) β€” using zero betas (average body shape)")
205
+ return np.zeros(10, dtype=np.float32), []
206
+
207
+ sources = [(ref_path, None)] # (path, azimuth_deg_or_None)
208
+ for idx in FRONT_VIEW_INDICES:
209
+ if idx < len(view_paths) and os.path.exists(view_paths[idx]):
210
+ sources.append((view_paths[idx], VIEW_AZIMUTHS_DEG[idx]))
211
+
212
+ results = []
213
+ weighted_betas, total_w = np.zeros(10, dtype=np.float64), 0.0
214
+
215
+ for path, az in sources:
216
+ img = cv2.imread(path)
217
+ if img is None:
218
+ continue
219
+ r = _run_hmr2_on_image(img, model, model_cfg, detector, device)
220
+ if r is None:
221
+ print(f"[rig P1] {os.path.basename(path)}: no person detected")
222
+ continue
223
+ r["azimuth_deg"] = az
224
+ r["path"] = path
225
+ results.append(r)
226
+ w = r["score"]
227
+ weighted_betas += r["betas"] * w
228
+ total_w += w
229
+ print(f"[rig P1] {os.path.basename(path)}: detected (score={w:.2f}), "
230
+ f"betas[:3]={r['betas'][:3]}")
231
+
232
+ avg_betas = (weighted_betas / total_w).astype(np.float32) if total_w > 0 \
233
+ else np.zeros(10, dtype=np.float32)
234
+ print(f"[rig P1] Averaged betas over {len(results)} detections.")
235
+ return avg_betas, results
236
+
237
+
238
+ # ═════════════════════════════════════════════════��════════════════════════════
239
+ # SMPL helpers
240
+ # ══════════════════════════════════════════════════════════════════════════════
241
+
242
+ def get_smpl_tpose(betas: np.ndarray, smpl_dir: str = "/root/smpl_models"):
243
+ """Returns (verts [N,3], faces [M,3], joints [24,3], lbs_weights [N,24]).
244
+ Uses smplx if SMPL_NEUTRAL.pkl is available, else falls back to a synthetic
245
+ proxy skeleton with proximity-based skinning weights."""
246
+ import torch
247
+
248
+ model_path = os.path.join(smpl_dir, "SMPL_NEUTRAL.pkl")
249
+ if not os.path.exists(model_path) or os.path.getsize(model_path) < 1000:
250
+ # Try download first, silently fall through to synthetic on failure
251
+ try:
252
+ _download_smpl_neutral(smpl_dir)
253
+ except Exception:
254
+ pass
255
+
256
+ if os.path.exists(model_path) and os.path.getsize(model_path) > 100_000:
257
+ import smplx
258
+ smpl = smplx.create(smpl_dir, model_type="smpl", gender="neutral", num_betas=10)
259
+ betas_t = torch.tensor(betas[:10], dtype=torch.float32).unsqueeze(0)
260
+ with torch.no_grad():
261
+ out = smpl(betas=betas_t, return_verts=True)
262
+ verts = out.vertices[0].numpy().astype(np.float32)
263
+ joints = out.joints[0, :24].numpy().astype(np.float32)
264
+ faces = smpl.faces.astype(np.int32)
265
+ weights = smpl.lbs_weights.numpy().astype(np.float32)
266
+ return verts, faces, joints, weights
267
+
268
+ print("[rig] SMPL_NEUTRAL.pkl unavailable β€” using synthetic proxy skeleton")
269
+ return _synthetic_smpl_tpose()
270
+
271
+
272
+ def _synthetic_smpl_tpose():
273
+ """Synthetic SMPL substitute: hardcoded T-pose joint positions + proximity weights.
274
+ Gives a rough but functional rig for pipeline testing when SMPL is unavailable.
275
+ For production, provide SMPL_NEUTRAL.pkl from https://smpl.is.tue.mpg.de/."""
276
+ # 24 SMPL T-pose joint positions (metres, Y-up, facing +Z)
277
+ joints = np.array([
278
+ [ 0.00, 0.92, 0.00], # 0 pelvis
279
+ [-0.09, 0.86, 0.00], # 1 left_hip
280
+ [ 0.09, 0.86, 0.00], # 2 right_hip
281
+ [ 0.00, 1.05, 0.00], # 3 spine1
282
+ [-0.09, 0.52, 0.00], # 4 left_knee
283
+ [ 0.09, 0.52, 0.00], # 5 right_knee
284
+ [ 0.00, 1.17, 0.00], # 6 spine2
285
+ [-0.09, 0.10, 0.00], # 7 left_ankle
286
+ [ 0.09, 0.10, 0.00], # 8 right_ankle
287
+ [ 0.00, 1.29, 0.00], # 9 spine3
288
+ [-0.09, 0.00, 0.07], # 10 left_foot
289
+ [ 0.09, 0.00, 0.07], # 11 right_foot
290
+ [ 0.00, 1.46, 0.00], # 12 neck
291
+ [-0.07, 1.42, 0.00], # 13 left_collar
292
+ [ 0.07, 1.42, 0.00], # 14 right_collar
293
+ [ 0.00, 1.62, 0.00], # 15 head
294
+ [-0.17, 1.40, 0.00], # 16 left_shoulder
295
+ [ 0.17, 1.40, 0.00], # 17 right_shoulder
296
+ [-0.42, 1.40, 0.00], # 18 left_elbow
297
+ [ 0.42, 1.40, 0.00], # 19 right_elbow
298
+ [-0.65, 1.40, 0.00], # 20 left_wrist
299
+ [ 0.65, 1.40, 0.00], # 21 right_wrist
300
+ [-0.72, 1.40, 0.00], # 22 left_hand
301
+ [ 0.72, 1.40, 0.00], # 23 right_hand
302
+ ], dtype=np.float32)
303
+
304
+ # Build synthetic proxy vertices: ~300 points clustered around each joint
305
+ rng = np.random.default_rng(42)
306
+ n_per_joint = 300
307
+ proxy_v = []
308
+ proxy_w = []
309
+ for ji, jpos in enumerate(joints):
310
+ pts = jpos + rng.normal(0, 0.06, (n_per_joint, 3)).astype(np.float32)
311
+ proxy_v.append(pts)
312
+ w = np.zeros((n_per_joint, 24), np.float32)
313
+ w[:, ji] = 1.0
314
+ proxy_w.append(w)
315
+
316
+ proxy_v = np.concatenate(proxy_v, axis=0) # (7200, 3)
317
+ proxy_w = np.concatenate(proxy_w, axis=0) # (7200, 24)
318
+ proxy_f = np.zeros((0, 3), dtype=np.int32) # no faces needed for KNN transfer
319
+ return proxy_v, proxy_f, joints, proxy_w
320
+
321
+
322
+ def _download_smpl_neutral(out_dir: str):
323
+ os.makedirs(out_dir, exist_ok=True)
324
+ url = ("https://huggingface.co/spaces/TMElyralab/MusePose/resolve/main"
325
+ "/models/smpl/SMPL_NEUTRAL.pkl")
326
+ dest = os.path.join(out_dir, "SMPL_NEUTRAL.pkl")
327
+ print("[rig] Downloading SMPL_NEUTRAL.pkl...")
328
+ subprocess.run(["wget", "-q", url, "-O", dest], check=True)
329
+
330
+
331
+ def _smpl_to_render_space(verts: np.ndarray, joints: np.ndarray):
332
+ """
333
+ Normalise SMPL vertices to fit inside the [-0.55, 0.55] orthographic
334
+ frustum used by the nvdiffrast renders (same as align_mesh_to_smpl).
335
+ Returns (verts_norm, joints_norm, scale, offset).
336
+ """
337
+ ymin, ymax = verts[:, 1].min(), verts[:, 1].max()
338
+ height = ymax - ymin
339
+ scale = (ORTHO_TOP - ORTHO_BOT) / max(height, 1e-6)
340
+
341
+ # Centre on pelvis (joint 0) horizontally, floor-align vertically
342
+ v = verts * scale
343
+ j = joints * scale
344
+ cx = (v[:, 0].max() + v[:, 0].min()) * 0.5
345
+ cz = (v[:, 2].max() + v[:, 2].min()) * 0.5
346
+ v[:, 0] -= cx; j[:, 0] -= cx
347
+ v[:, 2] -= cz; j[:, 2] -= cz
348
+ v[:, 1] -= v[:, 1].min() + ORTHO_BOT # floor at ORTHO_BOT
349
+ j[:, 1] -= (verts[:, 1].min() * scale) - ORTHO_BOT
350
+ return v, j, scale, np.array([-cx, -v[:,1].min() + ORTHO_BOT, -cz])
351
+
352
+
353
+ # ══════════════════════════════════════════════════════════════════════════════
354
+ # Phase 2 β€” Silhouette fitting
355
+ # ══════════════════════════════════════════════════════════════════════════════
356
+
357
+ def _extract_silhouette(render_path: str, threshold: int = 20) -> np.ndarray:
358
+ """Binary mask (HΓ—W bool) from a render: foreground = any channel > threshold."""
359
+ import cv2
360
+ img = cv2.imread(render_path)
361
+ if img is None:
362
+ return np.zeros((RENDER_H, RENDER_W), dtype=bool)
363
+ return img.max(axis=2) > threshold
364
+
365
+
366
+ def _render_smpl_silhouette(verts_norm: np.ndarray, faces: np.ndarray,
367
+ azimuth_deg: float) -> np.ndarray:
368
+ """
369
+ Rasterise SMPL mesh silhouette for given azimuth (orthographic).
370
+ Returns binary mask (HΓ—W bool).
371
+ """
372
+ from PIL import Image, ImageDraw
373
+
374
+ cam_xy = world_to_cam(verts_norm, azimuth_deg)
375
+ pix = cam_to_pixel(cam_xy) # (N, 2)
376
+
377
+ img = Image.new("L", (RENDER_W, RENDER_H), 0)
378
+ draw = ImageDraw.Draw(img)
379
+ for f in faces:
380
+ pts = [(float(pix[i, 0]), float(pix[i, 1])) for i in f]
381
+ draw.polygon(pts, fill=255)
382
+ return np.array(img) > 0
383
+
384
+
385
+ def _sil_loss(betas: np.ndarray, target_masks: list,
386
+ valid_views: list[int], faces: np.ndarray) -> float:
387
+ """1 - mean IoU between SMPL silhouettes and TripoSG render masks."""
388
+ try:
389
+ verts, _, _, _ = get_smpl_tpose(betas.astype(np.float32))
390
+ verts_n, _, _, _ = _smpl_to_render_space(verts, verts.copy())
391
+ iou_sum = 0.0
392
+ for i in valid_views:
393
+ pred = _render_smpl_silhouette(verts_n, faces, VIEW_AZIMUTHS_DEG[i])
394
+ tgt = target_masks[i]
395
+ inter = (pred & tgt).sum()
396
+ union = (pred | tgt).sum()
397
+ iou_sum += inter / max(union, 1)
398
+ return 1.0 - iou_sum / len(valid_views)
399
+ except Exception:
400
+ return 1.0
401
+
402
+
403
+ def fit_betas_silhouette(betas_init: np.ndarray, view_paths: list[str],
404
+ max_iter: int = 60) -> np.ndarray:
405
+ """
406
+ Phase 2: optimise SMPL betas to match TripoSG render silhouettes.
407
+ Only uses views whose render file exists.
408
+ """
409
+ from scipy.optimize import minimize
410
+
411
+ valid = [i for i, p in enumerate(view_paths) if os.path.exists(p)]
412
+ if not valid:
413
+ print("[rig P2] No render files found β€” skipping silhouette fit")
414
+ return betas_init
415
+
416
+ print(f"[rig P2] Extracting silhouettes from {len(valid)} views...")
417
+ masks = [_extract_silhouette(view_paths[i]) if i in valid
418
+ else np.zeros((RENDER_H, RENDER_W), bool)
419
+ for i in range(len(VIEW_NAMES))]
420
+
421
+ # Use only back-facing views for shape, not back (which shows less shape info)
422
+ fit_views = [i for i in valid if i in [0, 1, 2]]
423
+ if not fit_views:
424
+ fit_views = valid
425
+
426
+ # Pre-fetch faces (constant across iterations)
427
+ verts0, faces0, _, _ = get_smpl_tpose(betas_init)
428
+
429
+ loss0 = _sil_loss(betas_init, masks, fit_views, faces0)
430
+ print(f"[rig P2] Initial silhouette loss: {loss0:.4f}")
431
+
432
+ result = minimize(
433
+ fun=lambda b: _sil_loss(b, masks, fit_views, faces0),
434
+ x0=betas_init.astype(np.float64),
435
+ method="L-BFGS-B",
436
+ bounds=[(-3.0, 3.0)] * 10,
437
+ options={"maxiter": max_iter, "ftol": 1e-4, "gtol": 1e-3},
438
+ )
439
+
440
+ refined = result.x.astype(np.float32)
441
+ loss1 = _sil_loss(refined, masks, fit_views, faces0)
442
+ print(f"[rig P2] Silhouette fit done: loss {loss0:.4f} β†’ {loss1:.4f} "
443
+ f"({result.nit} iters, {'converged' if result.success else 'stopped'})")
444
+ return refined
445
+
446
+
447
+ # ══════════════════════════════════════════════════════════════════════════════
448
+ # Phase 3 β€” Multi-view joint triangulation
449
+ # ══════════════════════════════════════════════════════════════════════════════
450
+
451
+ # HMR 2.0 outputs 44 keypoints; first 24 map to SMPL joints
452
+ HMR2_TO_SMPL = list(range(24))
453
+
454
+ def triangulate_joints_multiview(hmr2_results: list) -> np.ndarray | None:
455
+ """
456
+ Phase 3: triangulate world-space SMPL joints from multi-view HMR 2.0 2D keypoints.
457
+
458
+ hmr2_results: list of dicts from _run_hmr2_on_image, each with
459
+ kp2d (44,2) in [-1,1] normalised NDC and azimuth_deg (float or None).
460
+
461
+ Only uses results from rendered views (azimuth_deg is not None).
462
+ Returns (24,3) world joint positions, or None if < 2 valid views.
463
+ """
464
+ view_results = [r for r in hmr2_results
465
+ if r.get("azimuth_deg") is not None and r.get("kp2d") is not None]
466
+
467
+ if len(view_results) < 2:
468
+ print(f"[rig P3] Only {len(view_results)} render views with detections "
469
+ "β€” need β‰₯2 for triangulation, skipping")
470
+ return None
471
+
472
+ print(f"[rig P3] Triangulating from {len(view_results)} views: "
473
+ + ", ".join(os.path.basename(r["path"]) for r in view_results))
474
+
475
+ # Convert HMR2 NDC keypoints β†’ pixel coords
476
+ # kp2d is (44,2) in [-1,1]; pixel = (kp+1)/2 * [W, H]
477
+ joints_world = np.zeros((24, 3), dtype=np.float32)
478
+
479
+ for j in range(24):
480
+ obs = []
481
+ for r in view_results:
482
+ kp = r["kp2d"][j] # (2,) in [-1,1]
483
+ pu = (kp[0] + 1.0) / 2.0 * RENDER_W
484
+ pv = (kp[1] + 1.0) / 2.0 * RENDER_H
485
+ obs.append((r["azimuth_deg"], pu, pv))
486
+ joints_world[j] = triangulate_joint(obs)
487
+
488
+ print(f"[rig P3] Triangulated 24 joints. "
489
+ f"Pelvis: {joints_world[0].round(3)}, "
490
+ f"Head: {joints_world[15].round(3)}")
491
+ return joints_world
492
+
493
+
494
+ # ══════════════════════════════════════════════════════════════════════════════
495
+ # Skinning weight transfer
496
+ # ══════════════════════════════════════════════════════════════════════════════
497
+
498
+ def transfer_skinning(smpl_verts: np.ndarray, smpl_weights: np.ndarray,
499
+ target_verts: np.ndarray, k: int = 4) -> np.ndarray:
500
+ from scipy.spatial import cKDTree
501
+ tree = cKDTree(smpl_verts)
502
+ dists, idxs = tree.query(target_verts, k=k, workers=-1)
503
+ dists = np.maximum(dists, 1e-8)
504
+ inv_d = 1.0 / dists
505
+ inv_d /= inv_d.sum(axis=1, keepdims=True)
506
+ transferred = np.einsum("nk,nkj->nj", inv_d, smpl_weights[idxs])
507
+ row_sums = transferred.sum(axis=1, keepdims=True)
508
+ transferred /= np.where(row_sums > 0, row_sums, 1.0)
509
+ return transferred.astype(np.float32)
510
+
511
+
512
+ def align_mesh_to_smpl(mesh_verts: np.ndarray, smpl_verts: np.ndarray,
513
+ smpl_joints: np.ndarray) -> np.ndarray:
514
+ smpl_h = smpl_verts[:, 1].max() - smpl_verts[:, 1].min()
515
+ mesh_h = mesh_verts[:, 1].max() - mesh_verts[:, 1].min()
516
+ scale = smpl_h / max(mesh_h, 1e-6)
517
+ v = mesh_verts * scale
518
+ cx = (v[:, 0].max() + v[:, 0].min()) * 0.5
519
+ cz = (v[:, 2].max() + v[:, 2].min()) * 0.5
520
+ v[:, 0] += smpl_joints[0, 0] - cx
521
+ v[:, 2] += smpl_joints[0, 2] - cz
522
+ v[:, 1] -= v[:, 1].min()
523
+ return v
524
+
525
+
526
+ # ══════════════════════════════════════════════════════════════════════════════
527
+ # GLB export
528
+ # ══════════════════════════════════════════════════════════════════════════════
529
+
530
+ def export_rigged_glb(verts, faces, uv, texture_img, joints, skin_weights, out_path):
531
+ import pygltflib
532
+ from pygltflib import (GLTF2, Scene, Node, Mesh, Primitive, Accessor,
533
+ BufferView, Buffer, Material, Texture,
534
+ Image as GImage, Sampler, Skin, Asset)
535
+ from pygltflib import (ARRAY_BUFFER, ELEMENT_ARRAY_BUFFER, FLOAT,
536
+ UNSIGNED_INT, UNSIGNED_SHORT, LINEAR,
537
+ LINEAR_MIPMAP_LINEAR, REPEAT, SCALAR, VEC2,
538
+ VEC3, VEC4, MAT4)
539
+
540
+ gltf = GLTF2()
541
+ gltf.asset = Asset(version="2.0", generator="rig_stage.py")
542
+ blobs = []
543
+
544
+ def _add(data: np.ndarray, comp, acc_type, target=None):
545
+ b = data.tobytes()
546
+ pad = (4 - len(b) % 4) % 4
547
+ off = sum(len(x) for x in blobs)
548
+ blobs.append(b + b"\x00" * pad)
549
+ bv = len(gltf.bufferViews)
550
+ gltf.bufferViews.append(BufferView(buffer=0, byteOffset=off,
551
+ byteLength=len(b), target=target))
552
+ ac = len(gltf.accessors)
553
+ flat = data.flatten()
554
+ gltf.accessors.append(Accessor(
555
+ bufferView=bv, byteOffset=0, componentType=comp,
556
+ type=acc_type, count=len(data),
557
+ min=[float(flat.min())], max=[float(flat.max())]))
558
+ return ac
559
+
560
+ pos_acc = _add(verts.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
561
+
562
+ v0,v1,v2 = verts[faces[:,0]], verts[faces[:,1]], verts[faces[:,2]]
563
+ fn = np.cross(v1-v0, v2-v0); fn /= (np.linalg.norm(fn,axis=1,keepdims=True)+1e-8)
564
+ vn = np.zeros_like(verts)
565
+ for i in range(3): np.add.at(vn, faces[:,i], fn)
566
+ vn /= (np.linalg.norm(vn,axis=1,keepdims=True)+1e-8)
567
+ nor_acc = _add(vn.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
568
+
569
+ if uv is None: uv = np.zeros((len(verts),2), np.float32)
570
+ uv_acc = _add(uv.astype(np.float32), FLOAT, VEC2, ARRAY_BUFFER)
571
+ idx_acc = _add(faces.astype(np.uint32).flatten(), UNSIGNED_INT, SCALAR, ELEMENT_ARRAY_BUFFER)
572
+
573
+ top4_idx = np.argsort(-skin_weights, axis=1)[:,:4].astype(np.uint16)
574
+ top4_w = np.take_along_axis(skin_weights, top4_idx.astype(np.int64), axis=1).astype(np.float32)
575
+ top4_w /= top4_w.sum(axis=1,keepdims=True).clip(1e-8,None)
576
+ j_acc = _add(top4_idx, UNSIGNED_SHORT, "VEC4", ARRAY_BUFFER)
577
+ w_acc = _add(top4_w, FLOAT, "VEC4", ARRAY_BUFFER)
578
+
579
+ if texture_img is not None:
580
+ import io
581
+ buf = io.BytesIO(); texture_img.save(buf, format="PNG"); ib = buf.getvalue()
582
+ off = sum(len(x) for x in blobs); pad = (4-len(ib)%4)%4
583
+ blobs.append(ib + b"\x00"*pad)
584
+ gltf.bufferViews.append(BufferView(buffer=0,byteOffset=off,byteLength=len(ib)))
585
+ gltf.images.append(GImage(mimeType="image/png",bufferView=len(gltf.bufferViews)-1))
586
+ gltf.samplers.append(Sampler(magFilter=LINEAR,minFilter=LINEAR_MIPMAP_LINEAR,
587
+ wrapS=REPEAT,wrapT=REPEAT))
588
+ gltf.textures.append(Texture(sampler=0,source=0))
589
+ gltf.materials.append(Material(name="body",
590
+ pbrMetallicRoughness={"baseColorTexture":{"index":0},
591
+ "metallicFactor":0.0,"roughnessFactor":0.8},
592
+ doubleSided=True))
593
+ else:
594
+ gltf.materials.append(Material(name="body",doubleSided=True))
595
+
596
+ prim = Primitive(attributes={"POSITION":pos_acc,"NORMAL":nor_acc,
597
+ "TEXCOORD_0":uv_acc,"JOINTS_0":j_acc,"WEIGHTS_0":w_acc},
598
+ indices=idx_acc, material=0)
599
+ gltf.meshes.append(Mesh(name="body",primitives=[prim]))
600
+
601
+ jnodes = []
602
+ for i,(name,parent) in enumerate(zip(SMPL_JOINT_NAMES,SMPL_PARENTS)):
603
+ t = joints[i].tolist() if parent==-1 else (joints[i]-joints[parent]).tolist()
604
+ n = Node(name=name,translation=t,children=[])
605
+ jnodes.append(len(gltf.nodes)); gltf.nodes.append(n)
606
+ for i,p in enumerate(SMPL_PARENTS):
607
+ if p!=-1: gltf.nodes[jnodes[p]].children.append(jnodes[i])
608
+
609
+ ibms = np.stack([np.eye(4,dtype=np.float32) for _ in range(len(joints))])
610
+ for i in range(len(joints)): ibms[i,:3,3] = -joints[i]
611
+ ibm_acc = _add(ibms.astype(np.float32), FLOAT, MAT4)
612
+ skin_idx = len(gltf.skins)
613
+ gltf.skins.append(Skin(name="smpl_skin",skeleton=jnodes[0],
614
+ joints=jnodes,inverseBindMatrices=ibm_acc))
615
+
616
+ mesh_node = len(gltf.nodes)
617
+ gltf.nodes.append(Node(name="body_mesh",mesh=0,skin=skin_idx))
618
+ root_node = len(gltf.nodes)
619
+ gltf.nodes.append(Node(name="root",children=[jnodes[0],mesh_node]))
620
+ gltf.scenes.append(Scene(name="Scene",nodes=[root_node]))
621
+ gltf.scene = 0
622
+
623
+ bin_data = b"".join(blobs)
624
+ gltf.buffers.append(Buffer(byteLength=len(bin_data)))
625
+ gltf.set_binary_blob(bin_data)
626
+ gltf.save_binary(out_path)
627
+ print(f"[rig] Rigged GLB β†’ {out_path} ({os.path.getsize(out_path)//1024} KB)")
628
+
629
+
630
+ # ══════════════════════════════════════════════════════════════════════════════
631
+ # FBX export via Blender headless
632
+ # ══════════════════════════════════════════════════════════════════════════════
633
+
634
+ _BLENDER_SCRIPT = """\
635
+ import bpy, sys
636
+ args = sys.argv[sys.argv.index('--') + 1:]
637
+ glb_in, fbx_out = args[0], args[1]
638
+ bpy.ops.wm.read_factory_settings(use_empty=True)
639
+ bpy.ops.import_scene.gltf(filepath=glb_in)
640
+ bpy.ops.export_scene.fbx(
641
+ filepath=fbx_out, use_selection=False,
642
+ add_leaf_bones=False, bake_anim=False,
643
+ path_mode='COPY', embed_textures=True,
644
+ )
645
+ print('FBX OK:', fbx_out)
646
+ """
647
+
648
+ def export_fbx(rigged_glb: str, out_path: str) -> bool:
649
+ blender = next((c for c in ["/usr/bin/blender","/usr/local/bin/blender"]
650
+ if os.path.exists(c)), None)
651
+ if blender is None:
652
+ r = subprocess.run(["which","blender"],capture_output=True,text=True)
653
+ blender = r.stdout.strip() or None
654
+ if blender is None:
655
+ print("[rig] Blender not found β€” skipping FBX")
656
+ return False
657
+ try:
658
+ with tempfile.NamedTemporaryFile("w",suffix=".py",delete=False) as f:
659
+ f.write(_BLENDER_SCRIPT); script = f.name
660
+ r = subprocess.run([blender,"--background","--python",script,
661
+ "--",rigged_glb,out_path],
662
+ capture_output=True,text=True,timeout=120)
663
+ ok = os.path.exists(out_path)
664
+ if not ok: print(f"[rig] Blender stderr:\n{r.stderr[-800:]}")
665
+ return ok
666
+ except Exception:
667
+ print(f"[rig] export_fbx:\n{traceback.format_exc()}")
668
+ return False
669
+ finally:
670
+ try: os.unlink(script)
671
+ except: pass
672
+
673
+
674
+ # ══════════════════════════════════════════════════════════════════════════════
675
+ # MDM β€” Motion Diffusion Model
676
+ # ══════════════════════════════════════════════════════════════════════════════
677
+
678
+ MDM_DIR = "/root/MDM"
679
+ MDM_CKPT = f"{MDM_DIR}/save/humanml_trans_enc_512/model000200000.pt"
680
+
681
+ # HumanML3D 22-joint parent array (matches SMPL joints 0-21)
682
+ _MDM_PARENTS = [-1,0,0,0,1,2,3,4,5,6,7,8,9,9,9,12,13,14,16,17,18,19]
683
+
684
+ def setup_mdm() -> bool:
685
+ """Clone MDM repo, install deps, download checkpoint. Idempotent."""
686
+ if os.path.exists(MDM_CKPT):
687
+ return True
688
+ print("[MDM] First-time setup...")
689
+
690
+ if not os.path.exists(MDM_DIR):
691
+ r = subprocess.run(
692
+ ["git", "clone", "--depth=1",
693
+ "https://github.com/GuyTevet/motion-diffusion-model.git", MDM_DIR],
694
+ capture_output=True, text=True, timeout=120)
695
+ if r.returncode != 0:
696
+ print(f"[MDM] git clone failed:\n{r.stderr}")
697
+ return False
698
+
699
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q",
700
+ "git+https://github.com/openai/CLIP.git",
701
+ "einops", "rotary-embedding-torch", "gdown"], check=False, timeout=300)
702
+
703
+ # HumanML3D normalisation stats (small .npy files needed for inference)
704
+ stats_dir = f"{MDM_DIR}/dataset/HumanML3D"
705
+ os.makedirs(stats_dir, exist_ok=True)
706
+ base = "https://github.com/EricGuo5513/HumanML3D/raw/main/HumanML3D"
707
+ for fn in ["Mean.npy", "Std.npy"]:
708
+ dest = f"{stats_dir}/{fn}"
709
+ if not os.path.exists(dest):
710
+ subprocess.run(["wget", "-q", f"{base}/{fn}", "-O", dest],
711
+ check=False, timeout=60)
712
+
713
+ # Checkpoint (~1.3 GB) β€” try HuggingFace mirror first, then gdown
714
+ ckpt_dir = os.path.dirname(MDM_CKPT)
715
+ os.makedirs(ckpt_dir, exist_ok=True)
716
+ hf = ("https://huggingface.co/Mathux/motion-diffusion-model/resolve/main/"
717
+ "humanml_trans_enc_512/model000200000.pt")
718
+ r = subprocess.run(["wget", "-q", "--show-progress", hf, "-O", MDM_CKPT],
719
+ capture_output=True, timeout=3600)
720
+ if r.returncode != 0 or not os.path.exists(MDM_CKPT) or \
721
+ os.path.getsize(MDM_CKPT) < 10_000_000:
722
+ print("[MDM] HF download failed β€” trying gdown (official Google Drive)...")
723
+ subprocess.run([sys.executable, "-m", "gdown",
724
+ "--id", "1PE0PK8e5a5j-7-Xhs5YET5U5pGh0c821",
725
+ "-O", MDM_CKPT], check=False, timeout=3600)
726
+
727
+ ok = os.path.exists(MDM_CKPT) and os.path.getsize(MDM_CKPT) > 10_000_000
728
+ print(f"[MDM] Setup {'OK' if ok else 'FAILED'}")
729
+ return ok
730
+
731
+
732
+ def generate_motion_mdm(text_prompt: str, n_frames: int = 120,
733
+ fps: int = 20, device: str = "cuda") -> dict | None:
734
+ """
735
+ Run MDM text-to-motion. Returns {'positions': (n_frames,22,3), 'fps': fps}
736
+ or None on failure. First call runs setup_mdm() which may take ~10 min.
737
+ """
738
+ if not setup_mdm():
739
+ return None
740
+
741
+ out_dir = tempfile.mkdtemp(prefix="mdm_")
742
+ motion_len = round(n_frames / fps, 2)
743
+
744
+ # Minimal inline driver β€” avoids MDM's argparse setup entirely
745
+ driver_src = f"""
746
+ import sys, os
747
+ sys.path.insert(0, {repr(MDM_DIR)})
748
+ os.chdir({repr(MDM_DIR)})
749
+ import numpy as np, torch
750
+
751
+ from utils.fixseed import fixseed
752
+ from utils.model_util import create_model_and_diffusion
753
+ from utils import dist_util
754
+ from data_loaders.humanml.utils.paramUtil import t2m_kinematic_chain
755
+ from data_loaders.humanml.scripts.motion_process import recover_from_ric
756
+ import clip as clip_lib
757
+
758
+ fixseed(42)
759
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
760
+ dist_util.dev = lambda: device
761
+
762
+ import argparse
763
+ args = argparse.Namespace(
764
+ arch='trans_enc', emb_trans_dec=False,
765
+ layers=8, latent_dim=512, ff_size=1024, num_heads=4,
766
+ dropout=0.1, activation='gelu', data_rep='rot6d',
767
+ dataset='humanml', cond_mode='text', cond_mask_prob=0.1,
768
+ lambda_rcxyz=0, lambda_vel=0, lambda_fc=0,
769
+ njoints=263, nfeats=1,
770
+ num_actions=1, translation=True, pose_rep='rot6d',
771
+ glob=True, glob_rot=True, npose=315,
772
+ device=0, seed=42, batch_size=1, num_samples=1,
773
+ num_repetitions=1, motion_length={motion_len!r},
774
+ input_text='', text_prompt='', action_file='', action_name='',
775
+ output_dir={repr(out_dir)}, guidance_param=2.5,
776
+ unconstrained=False,
777
+ # additional args required by get_model_args / create_gaussian_diffusion
778
+ text_encoder_type='clip',
779
+ pos_embed_max_len=5000,
780
+ mask_frames=False,
781
+ pred_len=0,
782
+ context_len=0,
783
+ diffusion_steps=1000,
784
+ noise_schedule='cosine',
785
+ sigma_small=True,
786
+ lambda_target_loc=0,
787
+ )
788
+
789
+ class _MockData:
790
+ class dataset:
791
+ pass
792
+ model, diffusion = create_model_and_diffusion(args, _MockData())
793
+ state = torch.load({repr(MDM_CKPT)}, map_location='cpu', weights_only=False)
794
+ missing, unexpected = model.load_state_dict(state, strict=False)
795
+ model.eval().to(device)
796
+
797
+ max_frames = int({n_frames})
798
+ shape = (1, model.njoints, model.nfeats, max_frames)
799
+ clip_model, _ = clip_lib.load('ViT-B/32', device=device, jit=False)
800
+ clip_model.eval()
801
+ tokens = clip_lib.tokenize([{repr(text_prompt)}]).to(device)
802
+ with torch.no_grad():
803
+ text_emb = clip_model.encode_text(tokens).float()
804
+
805
+ model_kwargs = {{
806
+ 'y': {{
807
+ 'mask': torch.ones(1, 1, 1, max_frames).to(device),
808
+ 'lengths': torch.tensor([max_frames]).to(device),
809
+ 'text': [{repr(text_prompt)}],
810
+ 'tokens': [''],
811
+ 'scale': torch.ones(1).to(device) * 2.5,
812
+ }}
813
+ }}
814
+
815
+ with torch.no_grad():
816
+ sample = diffusion.p_sample_loop(
817
+ model, shape, clip_denoised=False,
818
+ model_kwargs=model_kwargs, skip_timesteps=0,
819
+ init_image=None, progress=False, dump_steps=None,
820
+ noise=None, const_noise=False,
821
+ ) # (1, 263, 1, n_frames)
822
+
823
+ # Convert HumanML3D features β†’ joint XYZ using recover_from_ric (no SMPL needed)
824
+ # sample: (1, 263, 1, n_frames) β†’ (1, n_frames, 263)
825
+ sample_ric = sample[:, :, 0, :].permute(0, 2, 1)
826
+ xyz = recover_from_ric(sample_ric, 22) # (1, n_frames, 22, 3)
827
+ positions = xyz[0].cpu().numpy() # (n_frames, 22, 3)
828
+ np.save(os.path.join({repr(out_dir)}, 'positions.npy'), positions)
829
+ print('MDM_DONE')
830
+ """
831
+ driver_f = None
832
+ try:
833
+ with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as f:
834
+ f.write(driver_src)
835
+ driver_f = f.name
836
+
837
+ r = subprocess.run(
838
+ [sys.executable, driver_f],
839
+ capture_output=True, text=True, timeout=600,
840
+ env={**os.environ, "PYTHONPATH": MDM_DIR, "CUDA_VISIBLE_DEVICES": "0"},
841
+ )
842
+ print(f"[MDM] stdout: {r.stdout[-400:]}")
843
+ if r.returncode != 0:
844
+ print(f"[MDM] FAILED:\n{r.stderr[-600:]}")
845
+ return None
846
+
847
+ npy = os.path.join(out_dir, "positions.npy")
848
+ if not os.path.exists(npy):
849
+ print("[MDM] positions.npy not found")
850
+ return None
851
+
852
+ arr = np.load(npy) # (n_frames, 22, 3)
853
+ positions = arr # already (n_frames, 22, 3)
854
+ print(f"[MDM] Motion: {positions.shape}, fps={fps}")
855
+ return {"positions": positions, "fps": fps, "n_frames": positions.shape[0]}
856
+
857
+ except Exception:
858
+ print(f"[MDM] Exception:\n{traceback.format_exc()}")
859
+ return None
860
+ finally:
861
+ if driver_f:
862
+ try: os.unlink(driver_f)
863
+ except: pass
864
+
865
+
866
+ # ══════════════════════════════════════════════════════════════════════════════
867
+ # FK Inversion β€” joint world-positions β†’ local quaternions per frame
868
+ # ══════════════════════════════════════════════════════════════════════════════
869
+
870
+ def _quat_between(v0: np.ndarray, v1: np.ndarray) -> np.ndarray:
871
+ """Shortest-arc quaternion [x,y,z,w] that rotates unit vector v0 β†’ v1."""
872
+ cross = np.cross(v0, v1)
873
+ dot = float(np.clip(np.dot(v0, v1), -1.0, 1.0))
874
+ cn = np.linalg.norm(cross)
875
+ if cn < 1e-8:
876
+ return np.array([0., 0., 0., 1.], np.float32) if dot > 0 \
877
+ else np.array([1., 0., 0., 0.], np.float32)
878
+ axis = cross / cn
879
+ angle = np.arctan2(cn, dot)
880
+ s = np.sin(angle * 0.5)
881
+ return np.array([axis[0]*s, axis[1]*s, axis[2]*s, np.cos(angle*0.5)], np.float32)
882
+
883
+
884
+ def _quat_mul(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
885
+ """Hamilton product of two [x,y,z,w] quaternions."""
886
+ x1,y1,z1,w1 = q1; x2,y2,z2,w2 = q2
887
+ return np.array([
888
+ w1*x2 + x1*w2 + y1*z2 - z1*y2,
889
+ w1*y2 - x1*z2 + y1*w2 + z1*x2,
890
+ w1*z2 + x1*y2 - y1*x2 + z1*w2,
891
+ w1*w2 - x1*x2 - y1*y2 - z1*z2,
892
+ ], np.float32)
893
+
894
+
895
+ def _quat_inv(q: np.ndarray) -> np.ndarray:
896
+ return np.array([-q[0], -q[1], -q[2], q[3]], np.float32)
897
+
898
+
899
+ def _quat_rotate(q: np.ndarray, v: np.ndarray) -> np.ndarray:
900
+ """Rotate vector v by quaternion q."""
901
+ qv = np.array([v[0], v[1], v[2], 0.], np.float32)
902
+ return _quat_mul(_quat_mul(q, qv), _quat_inv(q))[:3]
903
+
904
+
905
+ def positions_to_local_quats(positions: np.ndarray,
906
+ t_pose_joints: np.ndarray,
907
+ parents: list) -> np.ndarray:
908
+ """
909
+ Derive per-joint local quaternions from world-space joint positions.
910
+ positions : (n_frames, n_joints, 3)
911
+ t_pose_joints : (n_joints, 3) β€” SMPL T-pose joints in same scale/space
912
+ parents : list of length n_joints, parent index (-1 for root)
913
+ Returns : (n_frames, n_joints, 4) XYZW local quaternions
914
+ """
915
+ n_frames, n_joints, _ = positions.shape
916
+ quats = np.zeros((n_frames, n_joints, 4), np.float32)
917
+ quats[:, :, 3] = 1.0 # default identity
918
+
919
+ # Compute global quats first, then convert to local
920
+ global_quats = np.zeros_like(quats)
921
+ global_quats[:, :, 3] = 1.0
922
+
923
+ for j in range(n_joints):
924
+ p = parents[j]
925
+ if p < 0:
926
+ # Root: no rotation relative to world (translation handles it)
927
+ global_quats[:, j] = [0, 0, 0, 1]
928
+ continue
929
+
930
+ # T-pose parent→child bone direction
931
+ tp_dir = t_pose_joints[j] - t_pose_joints[p]
932
+ tp_len = np.linalg.norm(tp_dir)
933
+ if tp_len < 1e-6:
934
+ continue
935
+ tp_dir /= tp_len
936
+
937
+ for f in range(n_frames):
938
+ an_dir = positions[f, j] - positions[f, p]
939
+ an_len = np.linalg.norm(an_dir)
940
+ if an_len < 1e-6:
941
+ global_quats[f, j] = global_quats[f, p]
942
+ continue
943
+ an_dir /= an_len
944
+ # Global rotation = parent_global ∘ local
945
+ # We want global bone direction to match an_dir
946
+ # global_bone_tpose = rotate(global_parent, tp_dir_in_parent_space)
947
+ # For SMPL T-pose, bone dirs are in world space already
948
+ gq = _quat_between(tp_dir, an_dir)
949
+ global_quats[f, j] = gq
950
+
951
+ # Convert global β†’ local (local = inv_parent_global ∘ global)
952
+ for j in range(n_joints):
953
+ p = parents[j]
954
+ if p < 0:
955
+ quats[:, j] = global_quats[:, j]
956
+ else:
957
+ for f in range(n_frames):
958
+ quats[f, j] = _quat_mul(_quat_inv(global_quats[f, p]),
959
+ global_quats[f, j])
960
+
961
+ return quats
962
+
963
+
964
+ # ══════════════════════════════════════════════════════════════════════════════
965
+ # Animated GLB export
966
+ # ══════════════════════════════════════════════════════════════════════════════
967
+
968
+ def export_animated_glb(verts, faces, uv, texture_img,
969
+ joints, # (24, 3) T-pose joint world positions
970
+ skin_weights, # (N_verts, 24)
971
+ joint_quats, # (n_frames, 24, 4) XYZW local quaternions
972
+ root_trans, # (n_frames, 3) world translation of root
973
+ fps: int,
974
+ out_path: str):
975
+ """
976
+ Export fully animated rigged GLB.
977
+ Skeleton + skin weights identical to export_rigged_glb;
978
+ adds a GLTF animation with per-joint rotation channels + root translation.
979
+ """
980
+ import pygltflib
981
+ from pygltflib import (GLTF2, Scene, Node, Mesh, Primitive, Accessor,
982
+ BufferView, Buffer, Material, Texture,
983
+ Image as GImage, Sampler, Skin, Asset,
984
+ Animation, AnimationChannel, AnimationChannelTarget,
985
+ AnimationSampler)
986
+ from pygltflib import (ARRAY_BUFFER, ELEMENT_ARRAY_BUFFER, FLOAT,
987
+ UNSIGNED_INT, UNSIGNED_SHORT, LINEAR,
988
+ LINEAR_MIPMAP_LINEAR, REPEAT, SCALAR, VEC2,
989
+ VEC3, VEC4, MAT4)
990
+
991
+ n_frames, n_joints_anim, _ = joint_quats.shape
992
+ n_joints = len(joints)
993
+
994
+ gltf = GLTF2()
995
+ gltf.asset = Asset(version="2.0", generator="rig_stage.py/animated")
996
+ blobs = []
997
+
998
+ def _add(data: np.ndarray, comp, acc_type, target=None,
999
+ set_min_max=False):
1000
+ b = data.tobytes()
1001
+ pad = (4 - len(b) % 4) % 4
1002
+ off = sum(len(x) for x in blobs)
1003
+ blobs.append(b + b"\x00" * pad)
1004
+ bv = len(gltf.bufferViews)
1005
+ gltf.bufferViews.append(BufferView(buffer=0, byteOffset=off,
1006
+ byteLength=len(b), target=target))
1007
+ ac = len(gltf.accessors)
1008
+ flat = data.flatten().astype(np.float32)
1009
+ kw = {}
1010
+ if set_min_max:
1011
+ kw = {"min": [float(flat.min())], "max": [float(flat.max())]}
1012
+ gltf.accessors.append(Accessor(
1013
+ bufferView=bv, byteOffset=0, componentType=comp,
1014
+ type=acc_type, count=len(data), **kw))
1015
+ return ac
1016
+
1017
+ # ── Mesh geometry ──────────────────────────────────────────────────────
1018
+ pos_acc = _add(verts.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
1019
+
1020
+ v0,v1,v2 = verts[faces[:,0]], verts[faces[:,1]], verts[faces[:,2]]
1021
+ fn = np.cross(v1-v0, v2-v0)
1022
+ fn /= (np.linalg.norm(fn, axis=1, keepdims=True) + 1e-8)
1023
+ vn = np.zeros_like(verts)
1024
+ for i in range(3): np.add.at(vn, faces[:,i], fn)
1025
+ vn /= (np.linalg.norm(vn, axis=1, keepdims=True) + 1e-8)
1026
+ nor_acc = _add(vn.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
1027
+
1028
+ if uv is None: uv = np.zeros((len(verts), 2), np.float32)
1029
+ uv_acc = _add(uv.astype(np.float32), FLOAT, VEC2, ARRAY_BUFFER)
1030
+ idx_acc = _add(faces.astype(np.uint32).flatten(), UNSIGNED_INT,
1031
+ SCALAR, ELEMENT_ARRAY_BUFFER)
1032
+
1033
+ top4_idx = np.argsort(-skin_weights, axis=1)[:, :4].astype(np.uint16)
1034
+ top4_w = np.take_along_axis(skin_weights, top4_idx.astype(np.int64), axis=1).astype(np.float32)
1035
+ top4_w /= top4_w.sum(axis=1, keepdims=True).clip(1e-8, None)
1036
+ j_acc = _add(top4_idx, UNSIGNED_SHORT, "VEC4", ARRAY_BUFFER)
1037
+ w_acc = _add(top4_w, FLOAT, "VEC4", ARRAY_BUFFER)
1038
+
1039
+ # ── Texture ────────────────────────────────────────────────────────────
1040
+ if texture_img is not None:
1041
+ import io
1042
+ buf = io.BytesIO(); texture_img.save(buf, format="PNG"); ib = buf.getvalue()
1043
+ off = sum(len(x) for x in blobs); pad2 = (4 - len(ib) % 4) % 4
1044
+ blobs.append(ib + b"\x00" * pad2)
1045
+ gltf.bufferViews.append(BufferView(buffer=0, byteOffset=off, byteLength=len(ib)))
1046
+ gltf.images.append(GImage(mimeType="image/png", bufferView=len(gltf.bufferViews)-1))
1047
+ gltf.samplers.append(Sampler(magFilter=LINEAR, minFilter=LINEAR_MIPMAP_LINEAR,
1048
+ wrapS=REPEAT, wrapT=REPEAT))
1049
+ gltf.textures.append(Texture(sampler=0, source=0))
1050
+ gltf.materials.append(Material(name="body",
1051
+ pbrMetallicRoughness={"baseColorTexture": {"index": 0},
1052
+ "metallicFactor": 0.0, "roughnessFactor": 0.8},
1053
+ doubleSided=True))
1054
+ else:
1055
+ gltf.materials.append(Material(name="body", doubleSided=True))
1056
+
1057
+ prim = Primitive(
1058
+ attributes={"POSITION": pos_acc, "NORMAL": nor_acc,
1059
+ "TEXCOORD_0": uv_acc, "JOINTS_0": j_acc, "WEIGHTS_0": w_acc},
1060
+ indices=idx_acc, material=0)
1061
+ gltf.meshes.append(Mesh(name="body", primitives=[prim]))
1062
+
1063
+ # ── Skeleton nodes ─────────────────────────────────────────────────────
1064
+ jnodes = []
1065
+ for i, (name, parent) in enumerate(zip(SMPL_JOINT_NAMES, SMPL_PARENTS)):
1066
+ t = joints[i].tolist() if parent == -1 else (joints[i] - joints[parent]).tolist()
1067
+ n = Node(name=name, translation=t, children=[])
1068
+ jnodes.append(len(gltf.nodes)); gltf.nodes.append(n)
1069
+ for i, p in enumerate(SMPL_PARENTS):
1070
+ if p != -1: gltf.nodes[jnodes[p]].children.append(jnodes[i])
1071
+
1072
+ ibms = np.stack([np.eye(4, dtype=np.float32) for _ in range(n_joints)])
1073
+ for i in range(n_joints): ibms[i, :3, 3] = -joints[i]
1074
+ ibm_acc = _add(ibms.astype(np.float32), FLOAT, MAT4)
1075
+ skin_idx = len(gltf.skins)
1076
+ gltf.skins.append(Skin(name="smpl_skin", skeleton=jnodes[0],
1077
+ joints=jnodes, inverseBindMatrices=ibm_acc))
1078
+
1079
+ mesh_node = len(gltf.nodes)
1080
+ gltf.nodes.append(Node(name="body_mesh", mesh=0, skin=skin_idx))
1081
+ root_node = len(gltf.nodes)
1082
+ gltf.nodes.append(Node(name="root", children=[jnodes[0], mesh_node]))
1083
+ gltf.scenes.append(Scene(name="Scene", nodes=[root_node]))
1084
+ gltf.scene = 0
1085
+
1086
+ # ── Animation ──────────────────────────────────────────────────────────
1087
+ dt = 1.0 / fps
1088
+ times = np.arange(n_frames, dtype=np.float32) * dt # (n_frames,)
1089
+ time_acc = _add(times, FLOAT, SCALAR, set_min_max=True)
1090
+
1091
+ channels, samplers = [], []
1092
+
1093
+ # Per-joint rotation tracks
1094
+ for j in range(min(n_joints_anim, n_joints)):
1095
+ q = joint_quats[:, j, :].astype(np.float32) # (n_frames, 4) XYZW
1096
+ q_acc = _add(q, FLOAT, VEC4)
1097
+ si = len(samplers)
1098
+ samplers.append(AnimationSampler(input=time_acc, output=q_acc,
1099
+ interpolation="LINEAR"))
1100
+ channels.append(AnimationChannel(
1101
+ sampler=si,
1102
+ target=AnimationChannelTarget(node=jnodes[j], path="rotation")))
1103
+
1104
+ # Root translation track
1105
+ if root_trans is not None:
1106
+ tr = root_trans.astype(np.float32) # (n_frames, 3)
1107
+ tr_acc = _add(tr, FLOAT, VEC3)
1108
+ si = len(samplers)
1109
+ samplers.append(AnimationSampler(input=time_acc, output=tr_acc,
1110
+ interpolation="LINEAR"))
1111
+ channels.append(AnimationChannel(
1112
+ sampler=si,
1113
+ target=AnimationChannelTarget(node=jnodes[0], path="translation")))
1114
+
1115
+ gltf.animations.append(Animation(name="mdm_motion",
1116
+ channels=channels, samplers=samplers))
1117
+
1118
+ # ── Finalise ───────────────────────────────────────────────────────────
1119
+ bin_data = b"".join(blobs)
1120
+ gltf.buffers.append(Buffer(byteLength=len(bin_data)))
1121
+ gltf.set_binary_blob(bin_data)
1122
+ gltf.save_binary(out_path)
1123
+ dur = times[-1] if len(times) else 0
1124
+ print(f"[rig] Animated GLB β†’ {out_path} "
1125
+ f"({os.path.getsize(out_path)//1024} KB, {n_frames} frames @ {fps}fps = {dur:.1f}s)")
1126
+
1127
+
1128
+ # ══════════════════════════════════════════════════════════════════════════════
1129
+ # Main pipeline
1130
+ # ══════════════════════════════════════════════════════════════════════════════
1131
+
1132
+ def run_rig_pipeline(glb_path: str, reference_image_path: str,
1133
+ out_dir: str, device: str = "cuda",
1134
+ export_fbx_flag: bool = True,
1135
+ mdm_prompt: str = "",
1136
+ mdm_n_frames: int = 120,
1137
+ mdm_fps: int = 20) -> dict:
1138
+ import trimesh
1139
+ os.makedirs(out_dir, exist_ok=True)
1140
+ result = {"rigged_glb": None, "animated_glb": None, "fbx": None,
1141
+ "smpl_params": None, "status": "", "phases": {}}
1142
+
1143
+ try:
1144
+ # ── load TripoSG mesh ─────────────────────────────────────────────
1145
+ print("[rig] Loading TripoSG mesh...")
1146
+ scene = trimesh.load(glb_path, force="scene")
1147
+ if isinstance(scene, trimesh.Scene):
1148
+ geom = list(scene.geometry.values())
1149
+ mesh = trimesh.util.concatenate(geom) if len(geom)>1 else geom[0]
1150
+ else:
1151
+ mesh = scene
1152
+ verts = np.array(mesh.vertices, dtype=np.float32)
1153
+ faces = np.array(mesh.faces, dtype=np.int32)
1154
+
1155
+ # UV + texture: try source geoms before concatenation (more reliable)
1156
+ uv, tex = None, None
1157
+ src_geoms = list(scene.geometry.values()) if isinstance(scene, trimesh.Scene) else [scene]
1158
+ for g in src_geoms:
1159
+ if not hasattr(g.visual, "uv") or g.visual.uv is None:
1160
+ continue
1161
+ try:
1162
+ candidate_uv = np.array(g.visual.uv, dtype=np.float32)
1163
+ if len(candidate_uv) == len(verts):
1164
+ uv = candidate_uv
1165
+ mat = getattr(g.visual, "material", None)
1166
+ if mat is not None:
1167
+ for attr in ("image", "baseColorTexture", "diffuse"):
1168
+ img = getattr(mat, attr, None)
1169
+ if img is not None:
1170
+ from PIL import Image as _PILImage
1171
+ tex = img if isinstance(img, _PILImage.Image) else None
1172
+ break
1173
+ break
1174
+ except Exception:
1175
+ pass
1176
+ if uv is None:
1177
+ print("[rig] WARNING: UV not found or vertex count mismatch β€” mesh will be untextured")
1178
+ print(f"[rig] Mesh: {len(verts)} verts, {len(faces)} faces, "
1179
+ f"UV={'yes' if uv is not None else 'no'}, "
1180
+ f"texture={'yes' if tex is not None else 'no'}")
1181
+
1182
+ # ── Phase 1: multi-view beta averaging ───────────────────────────
1183
+ print("\n[rig] ── Phase 1: multi-view beta averaging ──")
1184
+ betas, hmr2_results = estimate_betas_multiview(VIEW_PATHS, reference_image_path, device)
1185
+ result["phases"]["p1_betas"] = betas.tolist()
1186
+
1187
+ # ── Phase 2: silhouette fitting ───────────────────────────────────
1188
+ print("\n[rig] ── Phase 2: silhouette fitting ──")
1189
+ betas = fit_betas_silhouette(betas, VIEW_PATHS)
1190
+ result["phases"]["p2_betas"] = betas.tolist()
1191
+
1192
+ # ── Phase 3: multi-view joint triangulation ───────────────────────
1193
+ print("\n[rig] ── Phase 3: multi-view joint triangulation ──")
1194
+ tri_joints = triangulate_joints_multiview(hmr2_results)
1195
+ result["phases"]["p3_triangulated"] = tri_joints is not None
1196
+
1197
+ # ── build SMPL T-pose with refined betas ──────────────────────────
1198
+ print("\n[rig] Building SMPL T-pose...")
1199
+ smpl_v, smpl_f, smpl_j, smpl_w = get_smpl_tpose(betas)
1200
+
1201
+ # Override with triangulated joints if available
1202
+ if tri_joints is not None:
1203
+ # Triangulated joints are in render-normalised space; convert to SMPL scale
1204
+ _, _, scale, _ = _smpl_to_render_space(smpl_v.copy(), smpl_j.copy())
1205
+ smpl_j = tri_joints / scale # back to SMPL metric space
1206
+ print("[rig] Using triangulated skeleton joints.")
1207
+
1208
+ # ── align TripoSG mesh to SMPL ────────────────────────────────────
1209
+ verts_aligned = align_mesh_to_smpl(verts, smpl_v, smpl_j)
1210
+
1211
+ # ── skinning weight transfer ──────────────────────────────────────
1212
+ print("[rig] Transferring skinning weights...")
1213
+ skin_w = transfer_skinning(smpl_v, smpl_w, verts_aligned)
1214
+
1215
+ # ── export rigged GLB ─────────────────────────────────────────────
1216
+ rigged_glb = os.path.join(out_dir, "rigged.glb")
1217
+ export_rigged_glb(verts_aligned, faces, uv, tex, smpl_j, skin_w, rigged_glb)
1218
+ result["rigged_glb"] = rigged_glb
1219
+
1220
+ # ── export FBX ────────────────────────────────────────────────────
1221
+ if export_fbx_flag:
1222
+ fbx = os.path.join(out_dir, "rigged.fbx")
1223
+ result["fbx"] = fbx if export_fbx(rigged_glb, fbx) else None
1224
+
1225
+ # ── MDM animation ─────────────────────────────────────────────────
1226
+ if mdm_prompt.strip():
1227
+ print(f"\n[rig] ── MDM animation: {mdm_prompt!r} ({mdm_n_frames} frames) ──")
1228
+ mdm_out = generate_motion_mdm(mdm_prompt, n_frames=mdm_n_frames,
1229
+ fps=mdm_fps, device=device)
1230
+ if mdm_out is not None:
1231
+ pos = mdm_out["positions"] # (n_frames, 22, 3)
1232
+ actual_frames = pos.shape[0]
1233
+
1234
+ # Align MDM joint positions to SMPL scale/space
1235
+ # MDM outputs in metres roughly matching SMPL metric
1236
+ # Scale so pelvis height matches our SMPL pelvis
1237
+ mdm_pelvis_h = float(np.median(pos[:, 0, 1]))
1238
+ smpl_pelvis_h = float(smpl_j[0, 1])
1239
+ if abs(mdm_pelvis_h) > 1e-4:
1240
+ pos = pos * (smpl_pelvis_h / mdm_pelvis_h)
1241
+
1242
+ # FK inversion: positions β†’ local quaternions for joints 0-21
1243
+ t_pose_22 = smpl_j[:22]
1244
+ quats_22 = positions_to_local_quats(pos, t_pose_22, _MDM_PARENTS)
1245
+ # Pad to 24 joints (SMPL hands = identity)
1246
+ quats_24 = np.zeros((actual_frames, 24, 4), np.float32)
1247
+ quats_24[:, :, 3] = 1.0
1248
+ quats_24[:, :22, :] = quats_22
1249
+
1250
+ # Root translation: MDM root XZ + SMPL Y offset
1251
+ root_trans = pos[:, 0, :].copy() # (n_frames, 3)
1252
+
1253
+ anim_glb = os.path.join(out_dir, "animated.glb")
1254
+ export_animated_glb(
1255
+ verts_aligned, faces, uv, tex,
1256
+ smpl_j, skin_w,
1257
+ quats_24, root_trans, mdm_fps, anim_glb
1258
+ )
1259
+ result["animated_glb"] = anim_glb
1260
+ print(f"[rig] MDM animation complete β†’ {anim_glb}")
1261
+ else:
1262
+ print("[rig] MDM generation failed β€” static GLB only")
1263
+
1264
+ result["smpl_params"] = {
1265
+ "betas": betas.tolist(),
1266
+ "p1_sources": len(hmr2_results),
1267
+ "p3_triangulated": tri_joints is not None,
1268
+ }
1269
+ p3_note = " + triangulated skeleton" if tri_joints is not None else ""
1270
+ fbx_note = " + FBX" if result["fbx"] else ""
1271
+ anim_note = f" + MDM({mdm_n_frames}f)" if result.get("animated_glb") else ""
1272
+ result["status"] = (
1273
+ f"Rigged ({len(hmr2_results)} views used{p3_note}{fbx_note}{anim_note}). "
1274
+ f"{len(verts)} verts, 24 joints."
1275
+ )
1276
+
1277
+ except Exception:
1278
+ err = traceback.format_exc()
1279
+ print(f"[rig] FAILED:\n{err}")
1280
+ result["status"] = f"Rigging failed:\n{err[-600:]}"
1281
+
1282
+ return result