File size: 5,770 Bytes
8f1bcd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
search.py
Stream TeoGchx/HumanML3D from HuggingFace and match motions by keyword.

Dataset: https://huggingface.co/datasets/TeoGchx/HumanML3D
  Format: motion column is [T, 263] inline in parquet (standard HumanML3D)
  Splits: train (23 384), val (1 460), test (4 384)

Usage
-----
    from Retarget.search import search_motions

    results = search_motions("a person walks forward", top_k=5)
    for r in results:
        print(r["caption"], r["frames"], "frames")
        # r["motion"]  →  np.ndarray [T, 263]
"""
from __future__ import annotations
import re
from typing import List, Optional
import numpy as np


# ─────────────────────────────────────────────────────────────────────────────
# Caption cleaning
# ─────────────────────────────────────────────────────────────────────────────

# Field separator between the packed entries of a raw caption ('#' or '|').
_SEP = re.compile(r"#|\|")
# Matches a string made up solely of 1-4 letter upper-case tokens, i.e. text
# that looks like a line of POS tags.
# NOTE(review): not used by _clean_caption, which applies its own per-token
# heuristic — confirm whether anything else imports this before removing it.
_POS_TAG = re.compile(r"^(?:[A-Z]{1,4}\s*)+$")


def _looks_like_pos_tag(token: str) -> bool:
    """
    Return True if *token* looks like part-of-speech annotation.

    Two shapes are recognised: a bare short all-caps tag such as "NN" or
    "VERB" (1-4 characters), and a slash-tagged word such as "walk/VERB"
    whose part after the last '/' is made only of upper-case letters.
    """
    if token.isupper() and len(token) <= 4:
        return True
    word, sep, tag = token.rpartition("/")
    return bool(sep and word and tag.isalpha() and tag.isupper())


def _clean_caption(raw: str) -> str:
    """
    Return the first human-readable sentence from a packed caption string.

    HumanML3D captions are stored as multiple fields joined by '#' (or '|').
    The sentence may be followed by POS-tag strings and numeric fields.

    A field is skipped when it is empty or contains no letters (e.g. "0.0").
    It is also skipped when at least half of its whitespace-separated tokens
    look like POS tags (bare "NN" or slash form "walk/VERB"). If every field
    is skipped, the first field is returned stripped.
    """
    parts = _SEP.split(raw)
    for part in parts:
        part = part.strip()
        # Empty or letter-free fields (e.g. numeric "0.0") are never captions.
        if not any(ch.isalpha() for ch in part):
            continue
        words = part.split()
        # Skip the field if >= 50 % of tokens look like POS annotation.
        pos_count = sum(1 for w in words if _looks_like_pos_tag(w))
        if pos_count / len(words) < 0.5:
            return part
    # re.split always returns at least one element, so parts[0] exists.
    return parts[0].strip()


# ─────────────────────────────────────────────────────────────────────────────
# Search
# ─────────────────────────────────────────────────────────────────────────────

# Query tokens ignored when ranking captions: English function words plus the
# generic subjects "person"/"someone". Motion captions presumably use these in
# most rows (cf. the example query "a person walks forward"). Matching on them
# by substring would make almost every row a hit.
_STOP_WORDS = frozenset({
    "a", "an", "the", "and", "or", "then", "while", "of", "to", "in", "on",
    "at", "by", "for", "from", "with", "is", "are", "his", "her", "their",
    "its", "person", "someone",
})


def _query_keywords(query: str) -> List[str]:
    """
    Lower-case *query*, strip punctuation, de-duplicate and drop stop words.

    Falls back to the full de-duplicated token list when every token is a
    stop word, so a query like "a person" still has something to match.
    """
    tokens = list(dict.fromkeys(re.sub(r"[^\w\s]", "", query.lower()).split()))
    content = [tok for tok in tokens if tok not in _STOP_WORDS]
    return content or tokens


def _score_caption(caption_lower: str, patterns: List[tuple]) -> int:
    """Score a lower-cased caption: 2 per whole-word hit, 1 per substring-only hit."""
    score = 0
    for kw, word_re in patterns:
        if kw in caption_lower:
            score += 2 if word_re.search(caption_lower) else 1
    return score


def search_motions(
    query: str,
    top_k: int = 8,
    split: str = "test",
    max_scan: int = 4384,
    cached: bool = False,
) -> List[dict]:
    """
    Stream TeoGchx/HumanML3D and return up to top_k motions matching query.

    Parameters
    ----------
    query       Natural-language description, e.g. "a person walks forward".
                Punctuation is stripped and repeated words count once. Stop
                words ("a", "the", "person", ...) are ignored unless the
                query contains nothing else.
    top_k       Maximum number of results to return (<= 0 returns []).
    split       Dataset split — "test" (4 384 rows) is fastest to stream.
    max_scan    Hard cap on rows examined before returning.
    cached      If True, download and cache the whole split locally instead
                of streaming it (see the comments at the load_dataset call).

    Returns
    -------
    List of dicts, sorted by relevance score (descending):
        caption  str           clean human-readable description
        motion   np.ndarray    shape [T, 263], standard HumanML3D features
        frames   int           number of frames (T)
        duration float         duration in seconds (at 20 fps)
        name     str           original clip ID from dataset ("" if absent)
        score    int           keyword match score

    Scanning stops as soon as top_k matches are collected. The results are
    therefore the first top_k matching rows in dataset order, re-ordered by
    score. They are not necessarily the best matches in the whole split.

    Raises
    ------
    ImportError  if the HuggingFace ``datasets`` package is not installed.
    """
    keywords = _query_keywords(query)
    # Cheap checks first so trivial calls never import datasets or hit the network.
    if not keywords or top_k <= 0:
        return []

    try:
        from datasets import load_dataset
    except ImportError as err:
        raise ImportError(
            "pip install datasets   (HuggingFace datasets library required)"
        ) from err

    if cached:
        # Downloads the split once (~400MB) and caches to ~/.cache/huggingface.
        # Subsequent calls are instant.  Use for local dev / testing.
        ds = load_dataset("TeoGchx/HumanML3D", split=split)
    else:
        # Streaming: no disk cache, re-downloads each run. Good for server use.
        ds = load_dataset("TeoGchx/HumanML3D", split=split, streaming=True)

    # Compile each whole-word pattern once instead of once per row.
    patterns = [
        (kw, re.compile(r"\b" + re.escape(kw) + r"\b")) for kw in keywords
    ]

    results: List[dict] = []
    for scanned, row in enumerate(ds):
        if scanned >= max_scan:
            break

        caption_clean = _clean_caption(row.get("caption", "") or "")
        score = _score_caption(caption_clean.lower(), patterns)
        if score == 0:
            continue

        motion_raw = row.get("motion")
        if motion_raw is None:
            continue

        motion = np.array(motion_raw, dtype=np.float32)  # expected [T, 263]
        meta = row.get("meta_data") or {}
        n_frames = motion.shape[0]

        # Prefer dataset metadata. Fall back to the array length (at 20 fps)
        # when a field is absent or null, instead of crashing in int()/float().
        num_frames = meta.get("num_frames")
        duration = meta.get("duration")
        name = meta.get("name")

        results.append({
            "caption":  caption_clean,
            "motion":   motion,
            "frames":   int(num_frames) if num_frames is not None else n_frames,
            "duration": (
                float(duration) if duration is not None else n_frames / 20.0
            ),
            "name":     "" if name is None else str(name),
            "score":    score,
        })

        # Stop as soon as we have top_k results (keeps streaming cheap).
        if len(results) >= top_k:
            break

    # Stable sort: equal scores keep dataset order.
    results.sort(key=lambda r: r["score"], reverse=True)
    return results[:top_k]


def format_choice_label(result: dict) -> str:
    """
    Build a compact label for a Gradio Radio choice.

    Captions longer than 72 characters are cut to their first 72 characters
    and suffixed with an ellipsis. The frame count and the duration (one
    decimal, in seconds) follow in parentheses.
    """
    limit = 72
    text = result["caption"]
    shown = text[:limit] + "…" if len(text) > limit else text
    n_frames, seconds = result["frames"], result["duration"]
    return f"{shown}  ({n_frames} frames, {seconds:.1f}s)"