MeshForge / scripts /momask_server.py
Daankular's picture
Upload scripts/momask_server.py with huggingface_hub
cf582e0 verified
"""
momask_server.py
────────────────────────────────────────────────────────────────────────────
Lightweight Flask inference server wrapping MoMask text-to-motion generation.
Runs on the Vast.ai instance. Exposes POST /generate → [T, 263] JSON.
Does NOT require SMPL body models — only the MoMask checkpoints (residual VQ-VAE and masked transformer).
Deploy
──────
1. Upload this file to /root/momask_server.py on the instance
2. Install deps (see deploy_momask.sh)
3. Run: python /root/momask_server.py --port 8765
Endpoint
────────
POST /generate
Body: {"prompt": str, "num_frames": int, "seed": int}
Reply: {"motion": [[T, 263] as nested list], "num_frames": T, "fps": 20, "prompt": str}
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import numpy as np
# ── Flask ──────────────────────────────────────────────────────────────────
try:
from flask import Flask, request, jsonify
except ImportError:
sys.exit("pip install flask")
app = Flask(__name__)
# ── Global model state ──────────────────────────────────────────────────────
# Set by _load_model() to the tuple (net, trans, opt, device). It stays None
# until loading succeeds; /health and /generate both check for that.
_model = None
# Per-feature normalisation statistics (mean.npy / std.npy) that _generate()
# uses to map decoder output back to the original feature scale.
_mean, _std = None, None
# Hard cap on generated length: 196 frames is the HumanML3D maximum
# (~9.8 s at 20 fps).
_max_len = 196
def _load_model(momask_root: str, device: str = "cuda"):
    """Load the MoMask networks and normalisation stats into module globals.

    Builds the residual VQ-VAE (``RVQVAE``) and the masked transformer
    (``MaskTransformer``) from checkpoints under
    ``<momask_root>/checkpoints/t2m/``. It puts both in eval mode on *device*
    and publishes them as ``_model = (net, trans, opt, device)``. The
    ``meta/mean.npy`` / ``meta/std.npy`` statistics go into ``_mean`` /
    ``_std``.

    The globals are assigned only after every checkpoint has loaded. A failure
    part-way through therefore leaves the previous state untouched.

    Args:
        momask_root: Path to a checkout of the MoMask code base. It is put on
            ``sys.path`` so its ``models`` / ``options`` packages import.
        device: Torch device string the networks are moved to.
    """
    global _model, _mean, _std

    # Make the MoMask packages importable. Avoid stacking duplicate entries
    # if this function is called more than once.
    if momask_root not in sys.path:
        sys.path.insert(0, momask_root)

    import torch
    import options.option_transformer as option_trans
    from models.mask_transformer.transformer import MaskTransformer
    from models.vq.model import RVQVAE
    from options.get_eval_option import get_opt

    ckpt_root = os.path.join(momask_root, "checkpoints", "t2m")
    # Masked-transformer run directory. It holds opt.txt, the transformer
    # weights and meta/{mean,std}.npy.
    trans_dir = os.path.join(ckpt_root, "t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns")

    # Training options saved alongside the transformer checkpoint.
    opt = get_opt(os.path.join(trans_dir, "opt.txt"), device=device)

    # Normalisation stats; _generate() uses them to denormalise its output.
    mean = np.load(os.path.join(trans_dir, "meta", "mean.npy"))
    std = np.load(os.path.join(trans_dir, "meta", "std.npy"))

    # ── Residual VQ-VAE ──────────────────────────────────────────────────
    # Hyper-parameters are hard-coded rather than read from a saved options
    # file. They must match the checkpoint, or the strict load below fails.
    args = option_trans.get_args_parser()
    args = args.parse_args([])
    args.dataname = "t2m"
    args.res_name = "ter1"
    args.nb_code = 512
    args.code_dim = 512
    args.output_emb_width = 512
    args.nb_joints = 22
    args.window_size = 64
    args.down_t = 2
    args.stride_t = 2
    args.width = 512
    args.depth = 3
    args.dilation_growth_rate = 3
    args.vq_act = "relu"
    args.vq_norm = None
    args.num_quantizers = 6
    n_feats = 263  # per-frame feature size of the [T, 263] motion representation
    net = RVQVAE(args,
                 n_feats,
                 args.nb_code,
                 args.code_dim,
                 args.output_emb_width,
                 args.down_t,
                 args.stride_t,
                 args.width,
                 args.depth,
                 args.dilation_growth_rate,
                 args.vq_act,
                 args.vq_norm)
    # NOTE(review): the VQ-VAE weights come from a different run directory
    # (Comp_v6_KLD005) than the transformer and stats. Confirm this is the
    # intended checkpoint and that it stores the weights under "net".
    vq_ckpt_path = os.path.join(ckpt_root, "Comp_v6_KLD005", "net_last.pth")
    vq_ckpt = torch.load(vq_ckpt_path, map_location="cpu")
    net.load_state_dict(vq_ckpt["net"], strict=True)
    net.eval().to(device)

    # ── Masked transformer ───────────────────────────────────────────────
    trans = MaskTransformer(code_dim=opt.code_dim,
                            cond_mode="text",
                            latent_dim=opt.latent_dim,
                            ff_size=opt.ff_size,
                            num_layers=opt.num_layers,
                            num_heads=opt.num_heads,
                            dropout=opt.dropout,
                            clip_dim=512,
                            cond_drop_prob=opt.cond_drop_prob,
                            clip_version=opt.clip_version,
                            opt=opt)
    trans_ckpt = torch.load(os.path.join(trans_dir, "net_last.pth"), map_location="cpu")
    trans.load_state_dict(trans_ckpt["trans"], strict=True)
    trans.eval().to(device)

    # Publish everything together, only once all loading has succeeded.
    _mean, _std = mean, std
    _model = (net, trans, opt, device)
    print(f"[momask_server] Model loaded on {device}")
def _generate(prompt: str, num_frames: int, seed: int) -> np.ndarray:
    """Run MoMask inference for *prompt* and return a denormalised motion.

    Args:
        prompt: Text description of the motion.
        num_frames: Requested length in frames. It is clamped to
            ``[4, _max_len]``. Generation is requested as ``T // 4`` tokens,
            so fewer than 4 frames would ask for an empty token sequence.
        seed: Seed for torch and NumPy RNGs. Negative values leave the RNGs
            unseeded.

    Returns:
        float32 array ``[frames, 263]`` in the original (unnormalised)
        feature space, i.e. ``decoder_output * _std + _mean``.

    Raises:
        RuntimeError: if ``_load_model`` has not populated ``_model`` yet.
    """
    import torch

    if _model is None:
        raise RuntimeError("model not loaded; call _load_model() first")
    net, trans, _opt, _device = _model

    if seed >= 0:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Clamp to [4, _max_len]; see the docstring for the lower bound.
    T = max(4, min(int(num_frames), _max_len))
    with torch.no_grad():
        # CLIP text encoding of the prompt.
        cond_vector = trans.encode_text([prompt])
        # MoMask iterative masked decoding over T // 4 token positions.
        # NOTE(review): confirm MaskTransformer.generate's expected arguments
        # (text vs. pre-encoded condition, length as int vs. tensor) and the
        # returned token shape, presumably [1, T // 4] ids. TODO verify.
        mids = trans.generate(cond_vector, T // 4, temperature=1.0, topk_filter_thres=0.9,
                              gsample=True, force_mask=False)
        # Decode token sequence → motion features via the RVQ-VAE decoder.
        motion = net.forward_decoder(mids)  # assumed [1, frames, 263]
        motion = motion[0].cpu().numpy()    # [frames, 263]
    # Undo the dataset normalisation applied during training.
    motion = motion * _std + _mean
    return motion.astype(np.float32)
# ── Routes ────────────────────────────────────────────────────────────────────
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe.

    Always reports ``"status": "ok"``. ``"model_loaded"`` tells the caller
    whether the global model tuple has been populated yet.
    """
    payload = {"status": "ok", "model_loaded": _model is not None}
    return jsonify(payload)
@app.route("/generate", methods=["POST"])
def generate():
    """Generate a motion clip from a JSON request.

    Request body: a JSON object; every field is optional.
        prompt (str): text description, default "a person walks forward".
        num_frames (int): requested length, default 120 (> 0; _generate
            clamps it further).
        seed (int): RNG seed, default -1 (unseeded); must be < 2**32.

    Responses:
        200: {"motion": [[...], ...], "num_frames": int, "fps": 20,
              "prompt": str}
        400: body is not a JSON object, or a field has the wrong type/range.
        503: the model has not been loaded.
        500: inference raised. The message is returned and the traceback is
             logged.
    """
    # silent=True returns None for unparseable JSON, so we can answer with a
    # JSON 400 instead of an HTML error page.
    body = request.get_json(force=True, silent=True)
    if not isinstance(body, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    prompt = body.get("prompt", "a person walks forward")
    if not isinstance(prompt, str):
        return jsonify({"error": "'prompt' must be a string"}), 400
    try:
        num_frames = int(body.get("num_frames", 120))
        seed = int(body.get("seed", -1))
    except (TypeError, ValueError, OverflowError):
        return jsonify({"error": "'num_frames' and 'seed' must be integers"}), 400
    if num_frames <= 0:
        return jsonify({"error": "'num_frames' must be positive"}), 400
    # np.random.seed only accepts seeds in [0, 2**32 - 1].
    if seed >= 2 ** 32:
        return jsonify({"error": "'seed' must be < 2**32"}), 400

    if _model is None:
        return jsonify({"error": "model not loaded"}), 503

    try:
        motion = _generate(prompt, num_frames, seed)
    except Exception as e:  # request boundary: report instead of crashing
        app.logger.exception("generation failed for prompt %r", prompt)
        return jsonify({"error": str(e)}), 500
    return jsonify({
        "motion": motion.tolist(),
        "num_frames": int(motion.shape[0]),
        "fps": 20,
        "prompt": prompt,
    })
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--momask-root", default="/root/momask-codes")
    cli.add_argument("--port", type=int, default=8765)
    cli.add_argument("--device", default="cuda")
    cli.add_argument("--host", default="0.0.0.0")
    opts = cli.parse_args()

    # Load the model before the server starts accepting requests.
    print(f"[momask_server] Loading model from {opts.momask_root} ...")
    _load_model(opts.momask_root, opts.device)

    # threaded=False: the development server handles one request at a time,
    # so the single in-memory model is not used concurrently.
    print(f"[momask_server] Listening on {opts.host}:{opts.port}")
    app.run(host=opts.host, port=opts.port, threaded=False)