/*
 * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

#include "AnimProcessing/Utility.h"

#ifdef _WIN32
#pragma warning(push)
#pragma warning(disable : 4623 4191 4686 4868 5219 4191 4355)
#endif
#include
#include
#include
#ifdef _WIN32
#pragma warning(pop)
#endif

namespace py = pybind11;

/// Replace a non-finite value (NaN, +inf or -inf) with 0; finite values are
/// returned unchanged. Used to sanitize every float read from Python before it
/// is handed to the solver.
float strip_nan_inf(float x) noexcept
{
    return std::isfinite(x) ? x : 0.0f;
}

/// Python entry point: validates the inputs, builds per-frame poses, runs
/// Animation::CorrectMotion, and writes the corrected motion back in place.
///
/// All numeric arrays are indexed with a single flat index, so they are
/// expected to be 1-D (flattened) arrays:
///   - rootTranslations / rootTranslationsTarget: 3 floats per frame (x, y, z).
///     rootTranslations is overwritten with the corrected result.
///   - jointRotations / jointRotationsTarget: 4 floats per joint per frame,
///     stored w, x, y, z. jointRotations is overwritten with the result.
///   - fullPoseMask: one value per frame; its size defines the frame count.
///   - leftHandMask, rightHandMask, leftFootMask, rightFootMask, rootMask: one
///     value per frame.
///   - contacts: 4 values per frame. Channels 0/1 feed the left foot/toe
///     contacts and channels 2/3 the right foot/toe contacts (1 and 3 are
///     treated as the toe channels).
///
/// @param joint_parents           parent index for each joint (Python ints).
/// @param joint_ref_translations  per-joint [x, y, z] rest-pose translation.
/// @param joint_ref_rotations     per-joint 4-float rest-pose rotation, passed
///                                to Math::Quaternion's constructor in list
///                                order without the w-first reordering applied
///                                to jointRotations. NOTE(review): confirm the
///                                Python side supplies these as x, y, z, w.
/// @param left_hand_idx, right_hand_idx, left_foot_idx, right_foot_idx
///                                joint indices used for end-effector pins and
///                                foot contacts.
/// @param contact_threshold, root_margin, has_double_ankle_joints
///                                forwarded unchanged to Animation::CorrectMotion.
/// @throws std::runtime_error if any input has the wrong size, type or range.
void correct_motion(
    py::array_t &rootTranslations,
    py::array_t &jointRotations,
    const py::array_t& rootTranslationsTarget,
    const py::array_t& jointRotationsTarget,
    const py::array_t& fullPoseMask,
    const py::array_t& leftHandMask,
    const py::array_t& rightHandMask,
    const py::array_t& leftFootMask,
    const py::array_t& rightFootMask,
    const py::array_t& rootMask,
    const py::array_t& contacts,
    const py::list& joint_parents,
    const py::list& joint_ref_translations,
    const py::list& joint_ref_rotations,
    int left_hand_idx,
    int right_hand_idx,
    int left_foot_idx,
    int right_foot_idx,
    float contact_threshold,
    float root_margin,
    bool has_double_ankle_joints
)
{
    // --- Skeleton description validation ---------------------------------
    if(joint_parents.size() != joint_ref_translations.size()) {
        throw std::runtime_error("correct_motion python bindings: joint_parents and joint_ref_translations must have the same size");
    }
    if(joint_parents.size() != joint_ref_rotations.size()) {
        throw std::runtime_error("correct_motion python bindings: joint_parents and joint_ref_rotations must have the same size");
    }
    if(left_hand_idx < 0 || right_hand_idx < 0 || left_foot_idx < 0 || right_foot_idx < 0) {
        throw std::runtime_error("correct_motion python bindings: left_hand_idx, right_hand_idx, left_foot_idx, and right_foot_idx must be non-negative");
    }
    if(left_hand_idx >= joint_parents.size() || right_hand_idx >= joint_parents.size() || left_foot_idx >= joint_parents.size() || right_foot_idx >= joint_parents.size()) {
        throw std::runtime_error("correct_motion python bindings: left_hand_idx, right_hand_idx, left_foot_idx, and right_foot_idx must be less than the number of joints");
    }

    // Build the rest (reference) pose from the per-joint Python lists.
    std::vector defaultPose(joint_parents.size());
    for (size_t i = 0; i < joint_ref_translations.size(); ++i) {
        if (!py::isinstance(joint_ref_translations[i])) {
            throw std::runtime_error("correct_motion python bindings: Expected joint_ref_translations to be a list of lists");
        }
        py::list inner_list = joint_ref_translations[i].cast();
        if (inner_list.size() != 3) {
            throw std::runtime_error("correct_motion python bindings: Expected joint_ref_translations to be a list of lists, length 3");
        }
        if (
            !py::isinstance(inner_list[0]) ||
            !py::isinstance(inner_list[1]) ||
            !py::isinstance(inner_list[2])
        ) {
            throw std::runtime_error("correct_motion python bindings: Expected joint_ref_translations to be a list of lists, length 3, float values");
        }
        if (!py::isinstance(joint_ref_rotations[i])) {
            throw std::runtime_error("correct_motion python bindings: Expected joint_ref_rotations to be a list of lists");
        }
        py::list inner_list_rot = joint_ref_rotations[i].cast();
        if (inner_list_rot.size() != 4) {
            throw std::runtime_error("correct_motion python bindings: Expected joint_ref_rotations to be a list of lists, length 4");
        }
        if (
            !py::isinstance(inner_list_rot[0]) ||
            !py::isinstance(inner_list_rot[1]) ||
            !py::isinstance(inner_list_rot[2]) ||
            !py::isinstance(inner_list_rot[3])
        ) {
            throw std::runtime_error("correct_motion python bindings: Expected joint_ref_rotations to be a list of lists, length 4, float values");
        }
        defaultPose[i].SetTranslation(Math::Vector(
            inner_list[0].cast(),
            inner_list[1].cast(),
            inner_list[2].cast()));
        defaultPose[i].SetRotation(Math::Quaternion(
            inner_list_rot[0].cast(),
            inner_list_rot[1].cast(),
            inner_list_rot[2].cast(),
            inner_list_rot[3].cast()
        ));
    }

    std::vector joint_parents_vec(joint_parents.size());
    for (size_t i = 0; i < joint_parents.size(); ++i) {
        if (!py::isinstance(joint_parents[i])) {
            throw std::runtime_error("correct_motion python bindings: Expected joint_parents to be a list of ints");
        }
        joint_parents_vec[i] = joint_parents[i].cast();
        // NOTE(review): only the upper bound is checked; negative values are
        // passed through (presumably -1 marks the root) - confirm whether
        // values below -1 should also be rejected.
        if (joint_parents_vec[i] >= (int)joint_parents.size()) {
            throw std::runtime_error("correct_motion python bindings: joint_parents must be a list of ints, and all values must be less than the number of joints");
        }
    }

    // --- Per-frame array validation ----------------------------------------
    size_t num_joints = defaultPose.size();
    size_t gen_length = fullPoseMask.size();
    if(
        leftHandMask.size() != (int)gen_length ||
        rightHandMask.size() != (int)gen_length ||
        leftFootMask.size() != (int)gen_length ||
        rightFootMask.size() != (int)gen_length ||
        rootMask.size() != (int)gen_length
    ) {
        throw std::runtime_error("correct_motion python bindings: all masks must have the same size");
    }
    if(rootTranslations.size() != 3 * (int)gen_length) {
        throw std::runtime_error("correct_motion python bindings: rootTranslations has the wrong size");
    }
    if(jointRotations.size() != 4 * (int)num_joints * (int)gen_length) {
        throw std::runtime_error("correct_motion python bindings: jointRotations has the wrong size");
    }
    if(rootTranslationsTarget.size() != 3 * (int)gen_length) {
        throw std::runtime_error("correct_motion python bindings: rootTranslationsTarget has the wrong size");
    }
    if(jointRotationsTarget.size() != 4 * (int)num_joints * (int)gen_length) {
        throw std::runtime_error("correct_motion python bindings: jointRotationsTarget has the wrong size");
    }
    // contacts is read at 4*i + {0,1,2,3} for every frame below, so validate it
    // like the other per-frame arrays instead of failing mid-loop.
    if(contacts.size() != 4 * (int)gen_length) {
        throw std::runtime_error("correct_motion python bindings: contacts has the wrong size");
    }

    // --- End-effector pins: left hand, right hand, left foot, right foot ----
    std::vector endEffectorPins(4);
    endEffectorPins[0].jointIndex = left_hand_idx;
    endEffectorPins[0].hintOffset = Math::Vector(0.0f, 0.0f, -0.1f);
    endEffectorPins[1].jointIndex = right_hand_idx;
    endEffectorPins[1].hintOffset = Math::Vector(0.0f, 0.0f, -0.1f);
    endEffectorPins[2].jointIndex = left_foot_idx;
    endEffectorPins[2].hintOffset = Math::Vector(0.0f, 0.0f, 0.1f);
    endEffectorPins[3].jointIndex = right_foot_idx;
    endEffectorPins[3].hintOffset = Math::Vector(0.0f, 0.0f, 0.1f);
    endEffectorPins[0].contactMask.reserve(gen_length);
    endEffectorPins[1].contactMask.reserve(gen_length);
    endEffectorPins[2].contactMask.reserve(gen_length);
    endEffectorPins[3].contactMask.reserve(gen_length);
    for(size_t i = 0; i < gen_length; ++i) {
        // A pin's weight is its own mask, scaled down wherever the full pose is masked.
        endEffectorPins[0].contactMask.push_back((1.0f - fullPoseMask.at(i)) * leftHandMask.at(i));
        endEffectorPins[1].contactMask.push_back((1.0f - fullPoseMask.at(i)) * rightHandMask.at(i));
        endEffectorPins[2].contactMask.push_back((1.0f - fullPoseMask.at(i)) * leftFootMask.at(i));
        endEffectorPins[3].contactMask.push_back((1.0f - fullPoseMask.at(i)) * rightFootMask.at(i));
    }

    // --- Foot contacts: [0] = right foot, [1] = left foot -------------------
    // minHeight is each foot joint's global Y in the rest pose.
    std::vector contactInfo(2);
    auto footTranslation = Animation::JointLocalToGlobal(
        joint_parents_vec,
        right_foot_idx,
        defaultPose
    ).GetTranslation();
    contactInfo[0].jointIndex = right_foot_idx;
    contactInfo[0].hintOffset = Math::Vector(0.0f, 0.0f, 0.1f);
    contactInfo[0].minHeight = footTranslation.GetY();
    footTranslation = Animation::JointLocalToGlobal(
        joint_parents_vec,
        left_foot_idx,
        defaultPose
    ).GetTranslation();
    contactInfo[1].jointIndex = left_foot_idx;
    contactInfo[1].hintOffset = Math::Vector(0.0f, 0.0f, 0.1f);
    contactInfo[1].minHeight = footTranslation.GetY();

    {
        // Scoped: these references point into contactInfo, which may be
        // resized (and reallocated) below when toe contacts are added.
        auto& rContacts = contactInfo[0].contactMask;
        auto& lContacts = contactInfo[1].contactMask;
        rContacts.resize(gen_length);
        lContacts.resize(gen_length);
        for (size_t i = 0; i < gen_length; ++i) {
            // don't flag it as a contact if it's been masked:
            rContacts[i] = rightFootMask.at(i) ? 0 : contacts.at(4 * i + 2);
            lContacts[i] = leftFootMask.at(i) ? 0 : contacts.at(4 * i + 0);

            // Flag the heel as a contact if the toe is a contact:
            rContacts[i] = std::min((rightFootMask.at(i) ? 0 : contacts.at(4 * i + 3)) + rContacts[i], 1.0f);
            lContacts[i] = std::min((leftFootMask.at(i) ? 0 : contacts.at(4 * i + 1)) + lContacts[i], 1.0f);
        }
    }

    // --- Toe contacts: [2] = right toe, [3] = left toe ----------------------
    // A "toe" is a child of the foot joint; if several children exist the one
    // with the highest index wins.
    int left_toe_idx = -1;
    int right_toe_idx = -1;
    for(int i = 0; i < (int)num_joints; ++i) {
        if(joint_parents_vec[i] == left_foot_idx) {
            left_toe_idx = i;
        }
        if(joint_parents_vec[i] == right_foot_idx) {
            right_toe_idx = i;
        }
    }

    if(left_toe_idx != -1 && right_toe_idx != -1) {
        // Each toe gets its own rest-pose height (previously the left toe
        // reused the right toe's height, which is wrong for asymmetric skeletons).
        auto rightToeTranslation = Animation::JointLocalToGlobal(
            joint_parents_vec,
            right_toe_idx,
            defaultPose
        ).GetTranslation();
        auto leftToeTranslation = Animation::JointLocalToGlobal(
            joint_parents_vec,
            left_toe_idx,
            defaultPose
        ).GetTranslation();

        contactInfo.resize(4);
        contactInfo[2].jointIndex = right_toe_idx;
        contactInfo[2].contactType = Animation::kOneBone;
        contactInfo[2].minHeight = rightToeTranslation.GetY();
        contactInfo[3].jointIndex = left_toe_idx;
        contactInfo[3].contactType = Animation::kOneBone;
        contactInfo[3].minHeight = leftToeTranslation.GetY();

        auto& rToeContacts = contactInfo[2].contactMask;
        auto& lToeContacts = contactInfo[3].contactMask;

        // fill up the toe contacts:
        rToeContacts.resize(gen_length);
        lToeContacts.resize(gen_length);
        for (size_t i = 0; i < gen_length; ++i) {
            // don't flag it as a contact if it's been masked:
            rToeContacts[i] = rightFootMask.at(i) ? 0 : contacts.at(4 * i + 3);
            lToeContacts[i] = leftFootMask.at(i) ? 0 : contacts.at(4 * i + 1);
        }
    }

    // Copy flat Python arrays into per-frame poses: root translation goes to
    // joint 0 (assumed to be the root), rotations to every joint. Non-finite
    // values are zeroed and each quaternion is normalized.
    auto setTransforms = [gen_length, num_joints](
        std::vector< std::vector > &poses,
        const py::array_t &rootTranslations,
        const py::array_t &jointRotations
    ) {
        for (size_t f = 0; f < gen_length; ++f) {
            poses[f][0].SetTranslation({
                strip_nan_inf(rootTranslations.at(3*f+0)),
                strip_nan_inf(rootTranslations.at(3*f+1)),
                strip_nan_inf(rootTranslations.at(3*f+2))
            });
        }
        for (size_t f = 0; f < gen_length; ++f) {
            for (size_t j = 0; j < num_joints; ++j) {
                // Array stores w x y z; the constructor takes x y z w order:
                Math::Quaternion q(
                    strip_nan_inf(jointRotations.at(4 * (num_joints * f + j) + 1)),
                    strip_nan_inf(jointRotations.at(4 * (num_joints * f + j) + 2)),
                    strip_nan_inf(jointRotations.at(4 * (num_joints * f + j) + 3)),
                    strip_nan_inf(jointRotations.at(4 * (num_joints * f + j) + 0))
                );
                q.Normalize();
                poses[f][j].SetRotation(q);
            }
        }
    };

    // Every frame starts from the rest pose, so non-root joint translations
    // come from joint_ref_translations.
    std::vector< std::vector > posesFixed(gen_length, defaultPose);
    setTransforms(posesFixed, rootTranslations, jointRotations);

    std::vector< std::vector > posesTarget(gen_length, defaultPose);
    setTransforms(posesTarget, rootTranslationsTarget, jointRotationsTarget);

    std::vector fullPoseMask_vec;
    std::vector rootMask_vec;
    fullPoseMask_vec.reserve(gen_length);
    rootMask_vec.reserve(gen_length);
    for (size_t f = 0; f < gen_length; ++f) {
        fullPoseMask_vec.push_back(fullPoseMask.at(f));
        rootMask_vec.push_back(rootMask.at(f));
    }

    Animation::CorrectMotion(
        posesFixed,
        posesTarget,
        fullPoseMask_vec,
        rootMask_vec,
        contactInfo,
        endEffectorPins,
        joint_parents_vec,
        defaultPose,
        contact_threshold,
        root_margin,
        has_double_ankle_joints
    );

    // --- Write the corrected motion back into the caller's arrays ----------
    for (size_t f = 0; f < gen_length; ++f) {
        auto t = posesFixed[f][0].GetTranslation();
        rootTranslations.mutable_at(3*f+0) = t.GetX();
        rootTranslations.mutable_at(3*f+1) = t.GetY();
        rootTranslations.mutable_at(3*f+2) = t.GetZ();
    }
    for (size_t f = 0; f < gen_length; ++f) {
        for (size_t j = 0; j < num_joints; ++j) {
            auto q = posesFixed[f][j].GetRotation();
            // w x y z order.
            // NOTE(review): reads the quaternion's storage directly, assuming
            // Math::Quaternion lays out its components as x, y, z, w in memory.
            jointRotations.mutable_at(4 * (num_joints * f + j) + 0) = ((float*)&q)[3];
            jointRotations.mutable_at(4 * (num_joints * f + j) + 1) = ((float*)&q)[0];
            jointRotations.mutable_at(4 * (num_joints * f + j) + 2) = ((float*)&q)[1];
            jointRotations.mutable_at(4 * (num_joints * f + j) + 3) = ((float*)&q)[2];
        }
    }
}

PYBIND11_MODULE(_motion_correction, m)
{
    m.doc() = "Motion Correction Python bindings";
    // Registered without py::arg names, so Python callers pass arguments positionally.
    m.def("correct_motion", &correct_motion);
}