# Wojtekb30's picture
# Upload 11 files
# 316a030 verified
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch.nn.functional as F
from pathlib import Path
from safetensors.torch import load_file
# Import model architecture
from rvq_model import MotionRVQ_VAE
# ==========================================
# 1. Configuration
# ==========================================
# All assets are resolved relative to the directory containing this script,
# not the current working directory.
BASE_DIR = Path(__file__).resolve().parents[0]
FILE_TO_TEST = BASE_DIR.joinpath("000001.npy")
WEIGHTS_PATH = BASE_DIR.joinpath("motion_rvq_weights.safetensors")
# ==========================================
# 2. Model Initialization
# ==========================================
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MotionRVQ_VAE().to(device)

# Only the file read is guarded: a missing checkpoint is the one failure this
# script reports itself. Anything raised by load_state_dict (for example a
# key mismatch) propagates with its full traceback.
try:
    state_dict = load_file(str(WEIGHTS_PATH), device=str(device))
except FileNotFoundError:
    print(f"ERROR: Could not find file {WEIGHTS_PATH}.")
    # Raise SystemExit directly instead of calling exit(): exit() is a helper
    # installed by the site module and is not guaranteed to exist, and it
    # would report status 0 (success) for what is a failed run.
    raise SystemExit(1)

model.load_state_dict(state_dict)
print(f"Successfully loaded weights from {WEIGHTS_PATH}!")

# Switch the model to evaluation mode before inference.
model.eval()
# ==========================================
# 3. Forward Pass (With Padding and Normalization)
# ==========================================
# Load the clip and remember its real length so the padded frames can be
# cropped away again after decoding.
original_data = np.load(FILE_TO_TEST)
T_orig = original_data.shape[0]

# Normalization statistics stored next to the script.
mean = np.load(BASE_DIR / "Mean.npy")
std = np.load(BASE_DIR / "Std.npy")

# Padding for stride=4 compression: extend the clip to the next multiple of
# four frames by repeating its last frame.
pad_len = -T_orig % 4
padded_data = original_data
if pad_len:
    tail = np.repeat(original_data[-1:], pad_len, axis=0)
    padded_data = np.concatenate([original_data, tail], axis=0)

# Normalize, add a leading batch axis, then swap the last two axes before
# moving the tensor to the model's device.
padded_data = (padded_data - mean) / std
input_tensor = (
    torch.from_numpy(padded_data)
    .float()
    .unsqueeze(0)
    .permute(0, 2, 1)
    .to(device)
)
def decode_from_levels(num_levels):
    """Decode motion using only the first ``num_levels`` RVQ levels.

    Reads the module-level ``tokens`` and ``model``. For every selected level
    the token ids are looked up in that level's codebook; the per-level
    vectors are summed and the sum is handed to the decoder.
    """
    z_q_partial = 0
    for level in range(num_levels):
        level_tokens = tokens[:, level, :]
        codebook = model.rvq.quantizers[level].embedding
        # Look up each token id in this level's codebook, then swap the last
        # two axes into the layout the decoder expects.
        level_z_q = F.embedding(level_tokens, codebook).permute(0, 2, 1)
        # Each level contributes a residual on top of the running latent.
        z_q_partial = z_q_partial + level_z_q
    return model.decoder(z_q_partial)


def _to_motion_array(recon_tensor):
    """Turn a decoder output into a de-normalized NumPy array of T_orig frames."""
    frames_first = recon_tensor.squeeze(0).permute(1, 0).cpu().numpy()
    # Crop the padded frames, then undo the (x - mean) / std normalization.
    return (frames_first[:T_orig, :] * std) + mean


with torch.no_grad():
    # Encode once and keep the token ids produced by the quantizer.
    z = model.encoder(input_tensor)
    _, tokens, _ = model.rvq(z)

    # Reconstruct from the first level only, and from four levels.
    recon_tensor_1_lvl = decode_from_levels(1)
    recon_tensor_4_lvl = decode_from_levels(4)

recon_1_lvl = _to_motion_array(recon_tensor_1_lvl)
recon_4_lvl = _to_motion_array(recon_tensor_4_lvl)
# ==========================================
# 4. Skeleton extraction utility
# ==========================================
def get_3d_joints(data_263):
    """Build a ``(num_frames, 22, 3)`` joint array from a motion feature matrix.

    Args:
        data_263: 2-D array indexed as ``[frame, feature]``. Column 3 is read
            as the root height and columns 4..66 as 21 joints x 3 coordinates.

    Returns:
        A float64 array of shape ``(num_frames, 22, 3)``. Joint 0 is the root,
        placed at ``(0, root_y, 0)``; joints 1..21 hold the values of columns
        4..66 with the root height added to their second coordinate.

    Raises:
        ValueError: If columns 4..66 cannot be reshaped to
            ``(num_frames, 21, 3)``, e.g. when there are too few columns.
    """
    num_frames = data_263.shape[0]
    joints = np.zeros((num_frames, 22, 3))

    # All frames are filled in one vectorized assignment instead of a
    # per-frame Python loop; joint 0 stays at zero for now.
    joints[:, 1:, :] = data_263[:, 4:67].reshape(num_frames, 21, 3)

    # Adding the root height to the second coordinate of every joint also
    # turns the zero-initialized root into (0, root_y, 0).
    # NOTE(review): if these features follow the HumanML3D layout, the second
    # coordinate of columns 4..66 may already be an absolute height - confirm
    # that adding root_y to joints 1..21 is intended.
    joints[:, :, 1] += data_263[:, 3][:, np.newaxis]
    return joints
# Convert the loaded clip and both reconstructions to joint positions, in
# that order.
joints_orig, joints_1_lvl, joints_4_lvl = (
    get_3d_joints(motion)
    for motion in (original_data, recon_1_lvl, recon_4_lvl)
)
# ==========================================
# 5. Three-panel visualization
# ==========================================
# Chains of joint indices; each chain is drawn as one connected polyline.
kinematic_tree = [
    [0, 1, 4, 7, 10],
    [0, 2, 5, 8, 11],
    [0, 3, 6, 9, 12, 15],
    [9, 13, 16, 18, 20],
    [9, 14, 17, 19, 21],
]

# One wide figure holding three side-by-side 3D axes (1 row x 3 columns).
fig = plt.figure(figsize=(15, 5))
ax1, ax2, ax3 = (
    fig.add_subplot(1, 3, column, projection='3d') for column in (1, 2, 3)
)
def update(frame):
    """Redraw the three skeleton panels for animation frame ``frame``."""
    panels = (
        (ax1, joints_orig, 'blue', f"ORIGINAL\n(Frame {frame})"),
        (ax2, joints_1_lvl, 'orange',
         "RECONSTRUCTION: 1 LEVEL\n(Coarse tokens only)"),
        (ax3, joints_4_lvl, 'red',
         "RECONSTRUCTION: 4 LEVELS\n(Full RVQ detail)"),
    )
    for ax, joints, color, title in panels:
        # Each panel is wiped and set up from scratch on every frame, so the
        # limits and camera angle are applied again after clear().
        ax.clear()
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(0, 2)
        ax.view_init(elev=10., azim=-90)
        ax.axis('off')
        ax.set_title(title)

        pose = joints[frame]
        for chain in kinematic_tree:
            # Coordinates 1 and 2 are swapped so that the data's second
            # coordinate is drawn along the plot's vertical (z) axis.
            ax.plot(pose[chain, 0], pose[chain, 2], pose[chain, 1],
                    linewidth=3, marker='o', markersize=4, color=color)
# The animation object is bound to a name so that it stays referenced while
# the window is open. One frame is drawn per original (unpadded) time step.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=T_orig,
    interval=50,
    repeat=True,
)

plt.tight_layout()
plt.show()