"""Visual comparison of RVQ-VAE motion reconstructions.

Loads one motion clip, encodes it with a pretrained ``MotionRVQ_VAE``,
decodes it once from the first quantizer level only and once from the
first four levels, and plays the original next to both reconstructions
as animated 3D stick figures.
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch.nn.functional as F
from pathlib import Path
from safetensors.torch import load_file

# Import model architecture
from rvq_model import MotionRVQ_VAE

# ==========================================
# 1. Configuration
# ==========================================
BASE_DIR = Path(__file__).resolve().parent
FILE_TO_TEST = BASE_DIR / "000001.npy"
WEIGHTS_PATH = BASE_DIR / "motion_rvq_weights.safetensors"

# ==========================================
# 2. Model Initialization
# ==========================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MotionRVQ_VAE().to(device)

# Only the file read is expected to raise FileNotFoundError, so only that
# call sits in the ``try`` body.
try:
    state_dict = load_file(str(WEIGHTS_PATH), device=str(device))
except FileNotFoundError:
    print(f"ERROR: Could not find file {WEIGHTS_PATH}.")
    # ``exit()`` is an interactive helper injected by ``site`` and would
    # report success (status 0); raise SystemExit with a failure code.
    raise SystemExit(1)
else:
    model.load_state_dict(state_dict)
    print(f"Successfully loaded weights from {WEIGHTS_PATH}!")

model.eval()

# ==========================================
# 3. Forward Pass (With Padding and Normalization)
# ==========================================
original_data = np.load(FILE_TO_TEST)
T_orig = original_data.shape[0]

# Load normalization vectors
mean = np.load(BASE_DIR / "Mean.npy")
std = np.load(BASE_DIR / "Std.npy")

# Padding for stride=4 compression: repeat the last frame until the frame
# count is a multiple of 4.
pad_len = (4 - (T_orig % 4)) % 4
if pad_len > 0:
    last_frame = original_data[-1:]
    padded_data = np.concatenate(
        [original_data, np.repeat(last_frame, pad_len, axis=0)],
        axis=0,
    )
else:
    padded_data = original_data

padded_data = (padded_data - mean) / std

# (frames, features) -> (1, features, frames)
input_tensor = (
    torch.from_numpy(padded_data)
    .float()
    .unsqueeze(0)
    .permute(0, 2, 1)
    .to(device)
)

with torch.no_grad():
    # 1. Get tokens from all levels
    z = model.encoder(input_tensor)
    _, tokens, _ = model.rvq(z)

    # 2. Function for "partial" decoding
    def decode_from_levels(num_levels):
        """Decode the clip using only the first ``num_levels`` RVQ levels.

        Looks up the codebook vectors selected by ``tokens`` for each of
        the first ``num_levels`` quantizers, sums them into one latent and
        runs that latent through ``model.decoder``.

        Args:
            num_levels: How many quantizer levels to accumulate, starting
                from level 0. Must be between 1 and ``tokens.shape[1]``.

        Returns:
            Whatever ``model.decoder`` returns for the summed latent.

        Raises:
            ValueError: If ``num_levels`` is outside the valid range.
        """
        available = tokens.shape[1]
        if not 1 <= num_levels <= available:
            raise ValueError(
                f"num_levels must be in [1, {available}], got {num_levels}"
            )

        z_q_partial = 0
        for i in range(num_levels):
            # Get indices only for level "i"
            indices = tokens[:, i, :]
            quantizer = model.rvq.quantizers[i]

            # Convert each token id into its codebook vector
            level_z_q = F.embedding(indices, quantizer.embedding)
            level_z_q = level_z_q.permute(0, 2, 1)  # Shape expected by decoder

            # Add residual vector to the running latent
            z_q_partial = z_q_partial + level_z_q

        return model.decoder(z_q_partial)

    # Generate motion using only level 1 and all 4 levels
    recon_tensor_1_lvl = decode_from_levels(1)
    recon_tensor_4_lvl = decode_from_levels(4)

# Back to NumPy: (1, features, frames) -> (frames, features), dropping the
# padded frames added before encoding.
recon_1_lvl = recon_tensor_1_lvl.squeeze(0).permute(1, 0).cpu().numpy()[:T_orig, :]
recon_4_lvl = recon_tensor_4_lvl.squeeze(0).permute(1, 0).cpu().numpy()[:T_orig, :]

# De-normalization
recon_1_lvl = (recon_1_lvl * std) + mean
recon_4_lvl = (recon_4_lvl * std) + mean


# ==========================================
# 4. Skeleton extraction utility
# ==========================================
def get_3d_joints(data_263):
    """Build per-frame 3D positions for 22 joints from feature vectors.

    Joint 0 is placed at ``(0, root_y, 0)`` where ``root_y`` is feature
    column 3. Joints 1..21 are read from feature columns 4..66 (21 triples)
    and shifted by ``root_y`` along the second (y) coordinate.

    Args:
        data_263: Array indexed as ``[frame, feature]`` with at least 67
            feature columns.

    Returns:
        A float64 array of shape ``(num_frames, 22, 3)``.
    """
    data_263 = np.asarray(data_263)
    num_frames = data_263.shape[0]
    root_y = data_263[:, 3]

    joints = np.zeros((num_frames, 22, 3))
    joints[:, 0, 1] = root_y
    joints[:, 1:] = data_263[:, 4:67].reshape(num_frames, 21, 3)
    # NOTE(review): root_y is added to the y component of every non-root
    # joint. If the feature layout already stores absolute joint heights
    # this double-counts the root height -- confirm against the dataset.
    joints[:, 1:, 1] += root_y[:, None]
    return joints


joints_orig = get_3d_joints(original_data)
joints_1_lvl = get_3d_joints(recon_1_lvl)
joints_4_lvl = get_3d_joints(recon_4_lvl)

# ==========================================
# 5. Three-panel visualization
# ==========================================
kinematic_tree = [
    [0, 1, 4, 7, 10],
    [0, 2, 5, 8, 11],
    [0, 3, 6, 9, 12, 15],
    [9, 13, 16, 18, 20],
    [9, 14, 17, 19, 21],
]

fig = plt.figure(figsize=(15, 5))
ax1 = fig.add_subplot(131, projection='3d')
ax2 = fig.add_subplot(132, projection='3d')
ax3 = fig.add_subplot(133, projection='3d')


def update(frame):
    """Redraw all three skeleton panels for animation frame ``frame``.

    Each axis is cleared and reconfigured, then every chain of
    ``kinematic_tree`` is drawn as a polyline. The data's y coordinate is
    passed as the plot's z argument so it appears as the vertical axis.

    Args:
        frame: Index of the frame to draw.
    """
    for ax in (ax1, ax2, ax3):
        ax.clear()
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(0, 2)
        ax.view_init(elev=10., azim=-90)
        ax.axis('off')

    ax1.set_title(f"ORIGINAL\n(Frame {frame})")
    ax2.set_title("RECONSTRUCTION: 1 LEVEL\n(Coarse tokens only)")
    ax3.set_title("RECONSTRUCTION: 4 LEVELS\n(Full RVQ detail)")

    # (axis, joint positions, line color) for each panel
    panels = (
        (ax1, joints_orig, 'blue'),     # Original
        (ax2, joints_1_lvl, 'orange'),  # 1 Level
        (ax3, joints_4_lvl, 'red'),     # 4 Levels
    )
    for chain in kinematic_tree:
        for ax, joints, color in panels:
            ax.plot(
                joints[frame, chain, 0],
                joints[frame, chain, 2],
                joints[frame, chain, 1],
                linewidth=3,
                marker='o',
                markersize=4,
                color=color,
            )


# Keep a reference to the animation object so it is not garbage collected
# while the window is open.
ani = animation.FuncAnimation(fig, update, frames=T_orig, interval=50, repeat=True)
plt.tight_layout()
plt.show()