| import torch
|
| import numpy as np
|
| import matplotlib.pyplot as plt
|
| import matplotlib.animation as animation
|
| import torch.nn.functional as F
|
| from pathlib import Path
|
| from safetensors.torch import load_file
|
|
|
|
|
| from rvq_model import MotionRVQ_VAE
|
|
|
|
|
|
|
|
|
# All input files are resolved relative to this script's own directory, so the
# demo does not depend on the current working directory.
BASE_DIR = Path(__file__).resolve().parent
FILE_TO_TEST = BASE_DIR.joinpath("000001.npy")
WEIGHTS_PATH = BASE_DIR.joinpath("motion_rvq_weights.safetensors")
|
|
|
|
|
|
|
|
|
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MotionRVQ_VAE().to(device)

# Fail fast on a missing checkpoint. Checking explicitly avoids depending on
# which exception type safetensors raises for a missing path, and raising
# SystemExit with a message prints it to stderr and exits with a non-zero
# status (a bare `exit()` would report success).
if not WEIGHTS_PATH.is_file():
    raise SystemExit(f"ERROR: Could not find file {WEIGHTS_PATH}.")

state_dict = load_file(str(WEIGHTS_PATH), device=str(device))
model.load_state_dict(state_dict)
print(f"Successfully loaded weights from {WEIGHTS_PATH}!")

# Evaluation mode for inference (matters for train/eval-dependent layers such
# as dropout or batch norm, if the model contains any).
model.eval()
|
|
|
|
|
|
|
|
|
# Ground-truth motion clip to reconstruct; axis 0 is treated as time below.
original_data = np.load(FILE_TO_TEST)
T_orig = original_data.shape[0]  # number of real (unpadded) frames

# Normalization statistics used to standardize the features before encoding
# and to undo that standardization after decoding.
mean, std = (np.load(BASE_DIR / stats_file) for stats_file in ("Mean.npy", "Std.npy"))
|
|
|
|
|
# Pad the clip to a whole number of 4-frame groups by repeating the final frame.
# NOTE(review): 4 is presumably the encoder's temporal downsampling factor —
# confirm against MotionRVQ_VAE.
pad_len = -T_orig % 4  # same as (4 - T_orig % 4) % 4 for non-negative T_orig
pad_width = [(0, pad_len)] + [(0, 0)] * (original_data.ndim - 1)
padded_data = np.pad(original_data, pad_width, mode="edge")

# Standardize with the dataset statistics, then move from (frames, features)
# to the (batch=1, features, frames) layout fed to the encoder.
padded_data = (padded_data - mean) / std
input_tensor = (
    torch.from_numpy(padded_data)
    .float()
    .unsqueeze(0)
    .permute(0, 2, 1)
    .to(device)
)

# Inference only: encode and quantize without tracking gradients.
with torch.no_grad():
    z = model.encoder(input_tensor)
    # `tokens` is later indexed as tokens[:, level, :] by decode_from_levels.
    _, tokens, _ = model.rvq(z)
|
|
|
|
|
@torch.no_grad()
def decode_from_levels(num_levels, codes=None, rvq_model=None):
    """Decode motion from only the first ``num_levels`` residual-VQ levels.

    The latent is rebuilt as the sum of the codebook vectors selected at each
    of the leading ``num_levels`` levels, then passed to the decoder. This
    shows how much detail each additional quantizer level contributes.
    Gradients are disabled, so the result can be converted with ``.numpy()``.

    Args:
        num_levels: How many leading levels to sum; must be in
            ``1..codes.shape[1]``.
        codes: Integer token tensor indexed as ``codes[:, level, :]``.
            Defaults to the module-level ``tokens`` from the encoder pass.
        rvq_model: Object exposing ``rvq.quantizers[i].embedding`` and
            ``decoder``. Defaults to the module-level ``model``.

    Returns:
        Whatever ``rvq_model.decoder`` returns for the partial latent.

    Raises:
        ValueError: If ``num_levels`` is outside ``1..codes.shape[1]``.
    """
    codes = tokens if codes is None else codes
    rvq_model = model if rvq_model is None else rvq_model

    available = codes.shape[1]
    if not 1 <= num_levels <= available:
        raise ValueError(
            f"num_levels must be between 1 and {available}, got {num_levels}"
        )

    z_q_partial = None
    for level in range(num_levels):
        # NOTE(review): assumes `.embedding` is a (num_codes, dim) weight tensor
        # usable directly by F.embedding — confirm against the quantizer class.
        codebook = rvq_model.rvq.quantizers[level].embedding
        # (batch, time) indices -> (batch, time, dim) -> (batch, dim, time).
        level_z_q = F.embedding(codes[:, level, :], codebook).permute(0, 2, 1)
        z_q_partial = level_z_q if z_q_partial is None else z_q_partial + level_z_q

    return rvq_model.decoder(z_q_partial)
|
|
|
|
|
# Decode with autograd disabled: `.numpy()` raises on tensors that require
# grad, and nothing in this script back-propagates.
with torch.no_grad():
    recon_tensor_1_lvl = decode_from_levels(1)
    recon_tensor_4_lvl = decode_from_levels(4)


def _to_motion_frames(recon_tensor, num_frames, feat_mean, feat_std):
    """Convert a (1, features, frames) decoder output to unnormalized frames.

    Drops the batch axis, puts frames first, trims the padding frames added
    before encoding (keeping ``num_frames``), and undoes the mean/std
    standardization.
    """
    frames = recon_tensor.squeeze(0).permute(1, 0).cpu().numpy()[:num_frames, :]
    return frames * feat_std + feat_mean


recon_1_lvl = _to_motion_frames(recon_tensor_1_lvl, T_orig, mean, std)
recon_4_lvl = _to_motion_frames(recon_tensor_4_lvl, T_orig, mean, std)
|
|
|
|
|
|
|
|
|
def get_3d_joints(data_263, num_joints=22):
    """Extract per-frame xyz joint positions from HumanML3D-style features.

    The root joint is placed at the horizontal origin with the height stored in
    feature 3. The remaining ``num_joints - 1`` joints are read as xyz triples
    starting at feature 4. Root rotation and the root's XZ trajectory are not
    applied, so the skeleton is shown in place.

    Args:
        data_263: Array indexed as (frames, features). With the default 22
            joints, features 3 and 4..66 are read.
        num_joints: Total joint count including the root. The default of 22
            matches the kinematic tree used for plotting.

    Returns:
        Float array of shape ``(frames, num_joints, 3)``.
    """
    data = np.asarray(data_263)
    num_frames = data.shape[0]
    joints = np.zeros((num_frames, num_joints, 3))

    # Root: height only (x = z = 0 keeps the body centered in the plot).
    joints[:, 0, 1] = data[:, 3]

    # The stored joint positions already carry absolute height. In HumanML3D's
    # rotation-invariant features, only X/Z are made root-relative
    # (cf. recover_from_ric). So the root height must NOT be added again;
    # doing so lifts every non-root joint by the root height.
    end = 4 + (num_joints - 1) * 3
    joints[:, 1:] = data[:, 4:end].reshape(num_frames, num_joints - 1, 3)
    return joints
|
|
|
# Skeletons for the three panels: ground truth, then the 1-level and 4-level
# reconstructions (same order as the subplots below).
joints_orig, joints_1_lvl, joints_4_lvl = (
    get_3d_joints(motion) for motion in (original_data, recon_1_lvl, recon_4_lvl)
)
|
|
|
|
|
|
|
|
|
# Joint-index chains drawn as connected polylines (indices into the joint axis
# of get_3d_joints' output). Three chains start at joint 0 and two branch from
# joint 9. Presumably this is the HumanML3D/SMPL 22-joint layout (legs, spine,
# arms) — confirm against the dataset's skeleton definition.
kinematic_tree = [
    [0, 1, 4, 7, 10],
    [0, 2, 5, 8, 11],
    [0, 3, 6, 9, 12, 15],
    [9, 13, 16, 18, 20],
    [9, 14, 17, 19, 21],
]
|
|
|
# One row of three 3D panels: original | 1-level | 4-level reconstruction.
fig = plt.figure(figsize=(15, 5))
ax1, ax2, ax3 = (
    fig.add_subplot(1, 3, position, projection='3d') for position in range(1, 4)
)
|
|
|
def update(frame):
    """Redraw all three skeleton panels for animation frame ``frame``.

    Each 3D axis is cleared and given the same fixed limits, camera angle and
    hidden axes. Every kinematic chain is then drawn, passing the joint data's
    second coordinate as the plot's z value so it is shown vertically.
    """
    panels = (
        (ax1, joints_orig, 'blue'),
        (ax2, joints_1_lvl, 'orange'),
        (ax3, joints_4_lvl, 'red'),
    )

    for panel_ax, _, _ in panels:
        panel_ax.clear()
        panel_ax.set_xlim(-1, 1)
        panel_ax.set_ylim(-1, 1)
        panel_ax.set_zlim(0, 2)
        panel_ax.view_init(elev=10., azim=-90)
        panel_ax.axis('off')

    ax1.set_title(f"ORIGINAL\n(Frame {frame})")
    ax2.set_title("RECONSTRUCTION: 1 LEVEL\n(Coarse tokens only)")
    ax3.set_title("RECONSTRUCTION: 4 LEVELS\n(Full RVQ detail)")

    for chain in kinematic_tree:
        for panel_ax, skeleton, color in panels:
            pose = skeleton[frame, chain]  # (len(chain), 3) positions
            panel_ax.plot(
                pose[:, 0], pose[:, 2], pose[:, 1],
                linewidth=3, marker='o', markersize=4, color=color,
            )
|
|
|
# Keep the animation bound to a module-level name for the lifetime of the
# window; matplotlib's documentation warns that an unreferenced animation
# object may be garbage-collected before it runs.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=T_orig,
    interval=50,  # delay between frames, in milliseconds
    repeat=True,
)
plt.tight_layout()
plt.show()
|
|
|