| """ |
| momask_server.py |
| ──────────────────────────────────────────────────────────────────────────── |
| Lightweight Flask inference server wrapping MoMask text-to-motion generation. |
| Runs on the Vast.ai instance. Exposes POST /generate → [T, 263] JSON. |
| |
| Does NOT require SMPL body models — only the MoMask VQ-VAE checkpoints. |
| |
| Deploy |
| ────── |
| 1. Upload this file to /root/momask_server.py on the instance |
| 2. Install deps (see deploy_momask.sh) |
| 3. Run: python /root/momask_server.py --port 8765 |
| |
| Endpoint |
| ──────── |
| POST /generate |
| Body: {"prompt": str, "num_frames": int, "seed": int} |
| Reply: {"motion": [[T, 263] as nested list], "num_frames": T, "fps": 20, "prompt": str} |
| """ |
| from __future__ import annotations |
| import argparse |
| import json |
| import os |
| import sys |
|
|
| import numpy as np |
|
|
| |
| try: |
| from flask import Flask, request, jsonify |
| except ImportError: |
| sys.exit("pip install flask") |
|
|
# Flask application object; the route handlers in this file register on it.
app = Flask(__name__)

# Inference state filled in by _load_model(). Everything stays None until the
# checkpoints have been loaded, which is what /health and /generate test for.
_model = _mean = _std = None

# Upper bound applied to the requested frame count in _generate().
_max_len = 196
|
|
|
|
def _load_model(momask_root: str, device: str = "cuda"):
    """Load the MoMask networks and normalisation stats into module globals.

    Args:
        momask_root: Path to the MoMask checkout. It is prepended to
            ``sys.path`` so its ``models`` and ``options`` packages import,
            and its ``checkpoints/t2m`` tree is read for weights and stats.
        device: Device string handed to ``get_opt`` and used to place both
            networks.

    Side effects:
        Sets ``_mean`` / ``_std`` (arrays read from ``meta/mean.npy`` and
        ``meta/std.npy``) and ``_model``, the tuple
        ``(net, trans, opt, device)``. ``_model`` is assigned last, so it
        remains ``None`` if any earlier step raises.

    Raises:
        Whatever the imports, ``np.load``, ``torch.load`` or
        ``load_state_dict(strict=True)`` raise, e.g. when a checkpoint file
        is missing or its keys do not match the network.
    """
    global _model, _mean, _std

    # The MoMask packages are not installed; they are imported straight from
    # the checkout, so it has to be on the import path first.
    sys.path.insert(0, momask_root)

    import torch
    from models.mask_transformer.transformer import MaskTransformer
    from models.vq.model import RVQVAE
    from options.get_eval_option import get_opt
    import options.option_transformer as option_trans

    t2m_dir = os.path.join(momask_root, "checkpoints", "t2m")
    # One directory holds the transformer options, its weights and the
    # mean/std statistics; build the path once instead of three times.
    trans_dir = os.path.join(
        t2m_dir, "t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns"
    )

    # Transformer hyper-parameters saved alongside the checkpoint.
    opt = get_opt(os.path.join(trans_dir, "opt.txt"), device=device)

    # Normalisation statistics used by _generate() to de-normalise output.
    _mean = np.load(os.path.join(trans_dir, "meta", "mean.npy"))
    _std = np.load(os.path.join(trans_dir, "meta", "std.npy"))

    # VQ-VAE configuration: start from the parser defaults, then pin the
    # values this server relies on.
    args = option_trans.get_args_parser().parse_args([])
    vq_settings = {
        "dataname": "t2m",
        "res_name": "ter1",
        "nb_code": 512,
        "code_dim": 512,
        "output_emb_width": 512,
        "nb_joints": 22,
        "window_size": 64,
        "down_t": 2,
        "stride_t": 2,
        "width": 512,
        "depth": 3,
        "dilation_growth_rate": 3,
        "vq_act": "relu",
        "vq_norm": None,
        "num_quantizers": 6,
    }
    for key, value in vq_settings.items():
        setattr(args, key, value)

    net = RVQVAE(
        args,
        263,  # input feature width, matching the [T, 263] output format
        args.nb_code,
        args.code_dim,
        args.output_emb_width,
        args.down_t,
        args.stride_t,
        args.width,
        args.depth,
        args.dilation_growth_rate,
        args.vq_act,
        args.vq_norm,
    )

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider weights_only=True if the installed
    # torch version supports it.
    vqvae_ckpt = os.path.join(t2m_dir, "Comp_v6_KLD005", "net_last.pth")
    ckpt = torch.load(vqvae_ckpt, map_location="cpu")
    net.load_state_dict(ckpt["net"], strict=True)
    net.eval().to(device)

    trans = MaskTransformer(
        code_dim=opt.code_dim,
        cond_mode="text",
        latent_dim=opt.latent_dim,
        ff_size=opt.ff_size,
        num_layers=opt.num_layers,
        num_heads=opt.num_heads,
        dropout=opt.dropout,
        clip_dim=512,
        cond_drop_prob=opt.cond_drop_prob,
        clip_version=opt.clip_version,
        opt=opt,
    )
    trans_ckpt = torch.load(
        os.path.join(trans_dir, "net_last.pth"), map_location="cpu"
    )
    trans.load_state_dict(trans_ckpt["trans"], strict=True)
    trans.eval().to(device)

    # Publish last so that `_model is not None` implies a complete load.
    _model = (net, trans, opt, device)
    print(f"[momask_server] Model loaded on {device}")
|
|
|
|
def _generate(prompt: str, num_frames: int, seed: int) -> np.ndarray:
    """Run MoMask inference for one prompt and return the de-normalised motion.

    Args:
        prompt: Text description passed to the transformer's ``encode_text``.
        num_frames: Requested length in frames; capped at ``_max_len``.
        seed: When ``>= 0``, seeds both torch and NumPy before sampling.
            A negative value leaves the RNG state untouched.

    Returns:
        ``float32`` array holding the decoder output multiplied by ``_std``
        and shifted by ``_mean`` (expected to be ``[T, 263]`` per the module
        docstring; the exact T is decided by the decoder).

    Raises:
        RuntimeError: If ``_load_model`` has not populated ``_model`` yet.
        ValueError: If ``num_frames`` is too small to yield a single token.
    """
    import torch

    if _model is None:
        raise RuntimeError("model not loaded; call _load_model() first")
    # The tuple also carries opt and device, which inference does not need.
    net, trans, *_ = _model

    # The generator is asked for frames // 4 tokens.
    # NOTE(review): 4 presumably mirrors the VQ-VAE temporal downsampling
    # (down_t=2 with stride_t=2) - confirm against the MoMask checkout.
    frames_per_token = 4
    frames = min(int(num_frames), _max_len)
    token_len = frames // frames_per_token
    if token_len < 1:
        raise ValueError(
            f"num_frames must be at least {frames_per_token}, got {num_frames}"
        )

    if seed >= 0:
        torch.manual_seed(seed)
        np.random.seed(seed)

    with torch.no_grad():
        # Text -> conditioning vector.
        cond_vector = trans.encode_text([prompt])

        # Conditioning vector -> motion token ids.
        mids = trans.generate(
            cond_vector,
            token_len,
            temperature=1.0,
            topk_filter_thres=0.9,
            gsample=True,
            force_mask=False,
        )

        # Token ids -> normalised motion features; keep the single batch item.
        motion = net.forward_decoder(mids)
        motion = motion[0].cpu().numpy()

    # Undo the training-time normalisation.
    motion = motion * _std + _mean
    return motion.astype(np.float32)
|
|
|
|
| |
|
|
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: report that the server is up and whether the model is loaded."""
    model_ready = _model is not None
    payload = {"status": "ok", "model_loaded": model_ready}
    return jsonify(payload)
|
|
|
|
@app.route("/generate", methods=["POST"])
def generate():
    """Handle ``POST /generate``: turn a text prompt into a motion array.

    Request body (JSON object, all keys optional):
        prompt (str): Text description; defaults to "a person walks forward".
        num_frames (int): Requested frame count; defaults to 120.
        seed (int): RNG seed; defaults to -1, which ``_generate`` treats as
            "do not seed".

    Responses (always JSON):
        200: ``{"motion", "num_frames", "fps", "prompt"}``.
        400: ``{"error": ...}`` when the body is not a JSON object or a field
            has an unusable type or value.
        503: ``{"error": "model not loaded"}`` before ``_load_model`` has run.
        500: ``{"error": str(exc)}`` when inference itself fails.
    """
    # silent=True makes malformed JSON come back as None, so the client gets a
    # JSON error from us rather than a framework-generated error page.
    body = request.get_json(force=True, silent=True)
    if not isinstance(body, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    prompt = body.get("prompt", "a person walks forward")
    if not isinstance(prompt, str):
        return jsonify({"error": "prompt must be a string"}), 400

    try:
        num_frames = int(body.get("num_frames", 120))
        seed = int(body.get("seed", -1))
    except (TypeError, ValueError, OverflowError):
        return jsonify({"error": "num_frames and seed must be integers"}), 400

    if num_frames < 1:
        return jsonify({"error": "num_frames must be positive"}), 400

    if _model is None:
        return jsonify({"error": "model not loaded"}), 503

    try:
        motion = _generate(prompt, num_frames, seed)
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log and still
        # return the message to the caller as before.
        app.logger.exception("[momask_server] generation failed")
        return jsonify({"error": str(e)}), 500

    return jsonify({
        "motion": motion.tolist(),
        "num_frames": int(motion.shape[0]),
        "fps": 20,
        "prompt": prompt,
    })
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Command-line flags, each paired with its add_argument keyword arguments.
    cli = argparse.ArgumentParser()
    for flag, spec in (
        ("--momask-root", {"default": "/root/momask-codes"}),
        ("--port", {"type": int, "default": 8765}),
        ("--device", {"default": "cuda"}),
        ("--host", {"default": "0.0.0.0"}),
    ):
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Load the checkpoints before accepting requests, so /generate never sees
    # a half-initialised model.
    print(f"[momask_server] Loading model from {opts.momask_root} ...")
    _load_model(opts.momask_root, opts.device)

    print(f"[momask_server] Listening on {opts.host}:{opts.port}")
    # threaded=False is passed through to Flask's built-in server.
    # NOTE(review): presumably chosen so requests do not run inference
    # concurrently on the shared model - confirm.
    app.run(host=opts.host, port=opts.port, threaded=False)
|
|