# Image2Model / Retarget / search.py
# Ported MeshForge features to ZeroGPU Space: FireRed, PSHuman, Motion Search
# (commit 8f1bcd9, author: Daankular)
"""
search.py
Stream TeoGchx/HumanML3D from HuggingFace and match motions by keyword.
Dataset: https://huggingface.co/datasets/TeoGchx/HumanML3D
Format: motion column is [T, 263] inline in parquet (standard HumanML3D)
Splits: train (23 384), val (1 460), test (4 384)
Usage
-----
from Retarget.search import search_motions
results = search_motions("a person walks forward", top_k=5)
for r in results:
print(r["caption"], r["frames"], "frames")
# r["motion"] β†’ np.ndarray [T, 263]
"""
from __future__ import annotations
import re
from typing import List, Optional
import numpy as np
# ─────────────────────────────────────────────────────────────────────────────
# Caption cleaning
# ─────────────────────────────────────────────────────────────────────────────
# Captions join multiple sentences with '#' (and occasionally '|').
_SEP = re.compile(r'#|\|')
# Matches a string made only of short all-caps tokens (a POS-tag line).
# NOTE(review): _POS_TAG appears unused — _clean_caption re-implements the
# POS heuristic inline with a per-token ratio check; confirm before removing.
_POS_TAG = re.compile(r'^(?:[A-Z]{1,4}\s*)+$') # lines that look like POS tags
def _clean_caption(raw: str) -> str:
"""
HumanML3D captions are stored as multiple sentences joined by '#',
sometimes followed by POS tag strings. Return the first human-readable
sentence.
"""
parts = _SEP.split(raw)
for part in parts:
part = part.strip()
if not part:
continue
words = part.split()
# Skip if >50 % of tokens look like POS tags (all-caps, ≀4 chars)
pos_count = sum(1 for w in words if w.isupper() and len(w) <= 4)
if len(words) > 0 and pos_count / len(words) < 0.5:
return part
return parts[0].strip() if parts else raw.strip()
# ─────────────────────────────────────────────────────────────────────────────
# Search
# ─────────────────────────────────────────────────────────────────────────────
def search_motions(
    query: str,
    top_k: int = 8,
    split: str = "test",
    max_scan: int = 4384,
    cached: bool = False,
) -> List[dict]:
    """
    Stream TeoGchx/HumanML3D and return up to top_k motions matching query.

    Scanning stops as soon as top_k matches are collected (or max_scan rows
    have been examined), so the result is "the first top_k hits in scan
    order", not a global best-of-split ranking.

    Parameters
    ----------
    query    Natural-language description, e.g. "a person walks forward"
    top_k    Maximum number of results to return
    split    Dataset split — "test" (4 384 rows) is fastest to stream
    max_scan Hard cap on rows examined before returning
    cached   If True, download the split once (~400MB) to the local
             HuggingFace cache and reuse it on later calls (local dev /
             testing). If False, stream with no disk cache (server use).

    Returns
    -------
    List of dicts, sorted by relevance score (descending):
        caption   str         clean human-readable description
        motion    np.ndarray  shape [T, 263], standard HumanML3D features
        frames    int         number of frames (T)
        duration  float       duration in seconds (at 20 fps)
        name      str         original clip ID from dataset
        score     int         keyword match score

    Raises
    ------
    ImportError  if the HuggingFace `datasets` library is not installed.
    """
    try:
        from datasets import load_dataset
    except ImportError as exc:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "pip install datasets (HuggingFace datasets library required)"
        ) from exc

    if cached:
        # Downloads the split once and caches to ~/.cache/huggingface.
        ds = load_dataset("TeoGchx/HumanML3D", split=split)
    else:
        # Streaming: no disk cache, re-downloads each run.
        ds = load_dataset("TeoGchx/HumanML3D", split=split, streaming=True)

    # Tokenise query; strip punctuation so "walks!" still matches "walks".
    query_words = re.sub(r"[^\w\s]", "", query.lower()).split()
    if not query_words:
        return []

    # Hoist one compiled word-boundary pattern per keyword out of the scan
    # loop; the cheap substring pre-test below remains a plain `in` check.
    word_patterns = [
        (kw, re.compile(r"\b" + re.escape(kw) + r"\b")) for kw in query_words
    ]

    results: List[dict] = []
    scanned = 0
    for row in ds:
        if scanned >= max_scan:
            break
        scanned += 1

        caption_clean = _clean_caption(row.get("caption", "") or "")
        caption_lower = caption_clean.lower()

        # Score: word-boundary matches count 2, bare substring matches 1.
        score = 0
        for kw, pattern in word_patterns:
            if kw in caption_lower:
                score += 2 if pattern.search(caption_lower) else 1
        if score == 0:
            continue

        motion_raw = row.get("motion")
        if motion_raw is None:
            continue
        motion = np.array(motion_raw, dtype=np.float32)  # [T, 263]

        meta = row.get("meta_data") or {}
        T = motion.shape[0]
        results.append({
            "caption": caption_clean,
            "motion": motion,
            # Trust dataset metadata when present; fall back to array shape
            # and the standard HumanML3D 20 fps frame rate.
            "frames": int(meta.get("num_frames", T)),
            "duration": float(meta.get("duration", T / 20.0)),
            "name": str(meta.get("name", "")),
            "score": score,
        })
        # Stop as soon as we have top_k results (first-hits semantics).
        if len(results) >= top_k:
            break

    results.sort(key=lambda x: -x["score"])
    return results[:top_k]
def format_choice_label(result: dict) -> str:
    """Compact one-line label for a search result (Gradio Radio choice)."""
    text = result["caption"]
    # Keep the label readable: cut long captions at 72 chars + ellipsis.
    if len(text) > 72:
        text = text[:72] + "…"
    return "{} ({} frames, {:.1f}s)".format(
        text, result["frames"], result["duration"]
    )