Shihao Wang commited on
Commit
60ffd60
·
1 Parent(s): 0027999

Made-with: Cursor

.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.ttf filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
39
+ *.png filter=lfs diff=lfs merge=lfs -text
40
+ *.gif filter=lfs diff=lfs merge=lfs -text
41
+ *.webp filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,15 @@
1
  ---
2
- title: LocateAnything
3
- emoji: 🐨
4
- colorFrom: purple
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
10
- license: other
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Test
3
+ emoji: 💬
4
+ colorFrom: yellow
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 6.5.1
8
  app_file: app.py
9
  pinned: false
10
+ hf_oauth: true
11
+ hf_oauth_scopes:
12
+ - inference-api
13
  ---
14
 
15
+ An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py CHANGED
@@ -1,7 +1,1233 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
  import gradio as gr
4
+ import cv2
5
+ import numpy as np
6
+ import os
7
+ import tempfile
8
+ import re
9
+ import time
10
+ import base64
11
+ import gc
12
+ import io
13
+ import json
14
+ import uuid
15
+ from pathlib import Path
16
 
17
+ import torch
18
+ from PIL import Image, ImageDraw, ImageFont
19
+ from transformers import AutoProcessor, AutoModel, AutoTokenizer
20
+ from huggingface_hub import CommitScheduler
21
 
22
+ import spaces
23
+
24
+ _FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "LXGWWenKai-Bold.ttf")
25
+
26
+
27
+ def _load_font(size=20):
28
+ """加载中文字体(LXGW WenKai),需提前放置到 assets/ 目录"""
29
+ if os.path.exists(_FONT_PATH):
30
+ try:
31
+ return ImageFont.truetype(_FONT_PATH, size)
32
+ except Exception:
33
+ pass
34
+ try:
35
+ return ImageFont.truetype("DejaVuSans-Bold.ttf", size)
36
+ except Exception:
37
+ return ImageFont.load_default()
38
+
39
+
40
+ # ============================================================
41
+ # 颜色 / 解析 / 绘制
42
+ # ============================================================
43
+ def get_color_for_label(label):
44
+ colors = [
45
+ (8, 145, 178), (220, 38, 38), (22, 163, 74), (37, 99, 235),
46
+ (217, 119, 6), (147, 51, 234),
47
+ ]
48
+ idx = sum(ord(c) for c in label)
49
+ return colors[idx % len(colors)]
50
+
51
+
52
+ def parse_mixed_results(text, category_str=""):
53
+ results = []
54
+ expected_cats = [c.strip().lower() for c in category_str.split("</c>") if c.strip()]
55
+
56
+ ref_box_pattern = r"(<ref>.*?</ref>)|(<box>.*?</box>)"
57
+ current_label = None
58
+ found_structured = False
59
+
60
+ for m in re.finditer(ref_box_pattern, text, flags=re.IGNORECASE | re.DOTALL):
61
+ token = m.group(0)
62
+ if token.lower().startswith("<ref>"):
63
+ label_raw = re.sub(r"</?ref>", "", token, flags=re.IGNORECASE).strip()
64
+ if label_raw:
65
+ current_label = label_raw
66
+ else:
67
+ content = re.sub(r"</?box>", "", token, flags=re.IGNORECASE)
68
+ nums = re.findall(r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>", content)
69
+ coords = [float(n) for n in nums]
70
+ if not coords:
71
+ continue
72
+ label = current_label
73
+ if label is None:
74
+ label = expected_cats[0] if expected_cats else "object"
75
+ if len(coords) == 4:
76
+ results.append({"type": "box", "coords": coords, "label": label})
77
+ elif len(coords) == 2:
78
+ results.append({"type": "point", "coords": coords, "label": label})
79
+ found_structured = True
80
+
81
+ if found_structured:
82
+ return results
83
+
84
+ box_pattern = r"<box>(.*?)</box>"
85
+ parts = re.split(box_pattern, text)
86
+ for i in range(1, len(parts), 2):
87
+ preceding_text = parts[i - 1].lower()
88
+ content = parts[i]
89
+ label = expected_cats[0] if expected_cats else "object"
90
+ for cat in expected_cats:
91
+ if cat in preceding_text:
92
+ label = cat
93
+ break
94
+ nums = re.findall(r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>", content)
95
+ coords = [float(n) for n in nums]
96
+ if len(coords) == 4:
97
+ results.append({"type": "box", "coords": coords, "label": label})
98
+ elif len(coords) == 2:
99
+ results.append({"type": "point", "coords": coords, "label": label})
100
+
101
+ return results
102
+
103
+
104
+ def resize_image_short_side(image, short_side_size):
105
+ w, h = image.size
106
+ if w <= h:
107
+ new_w = short_side_size
108
+ scale_factor = new_w / w
109
+ new_h = int(h * scale_factor)
110
+ else:
111
+ new_h = short_side_size
112
+ scale_factor = new_h / h
113
+ new_w = int(w * scale_factor)
114
+ resized_image = image.resize((new_w, new_h), Image.BILINEAR)
115
+ return resized_image, scale_factor
116
+
117
+
118
+ def draw_on_frame(frame_bgr, results, draw_label=True):
119
+ pil_img = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
120
+ img_draw = pil_img.convert("RGBA")
121
+ overlay = Image.new("RGBA", img_draw.size, (255, 255, 255, 0))
122
+ draw = ImageDraw.Draw(overlay)
123
+ font = _load_font(20)
124
+ w_img, h_img = pil_img.size
125
+
126
+ parsed = []
127
+ for res in results:
128
+ label = res.get("label", "object")
129
+ color = get_color_for_label(label)
130
+ if res.get("type") == "point":
131
+ c = res["coords"]
132
+ cx = max(0, min(w_img, c[0] * w_img / 1000))
133
+ cy = max(0, min(h_img, c[1] * h_img / 1000))
134
+ parsed.append(("point", label, color, cx, cy))
135
+ continue
136
+ if "is_pixel" in res:
137
+ x1, y1, bw, bh = res["coords"]
138
+ x2, y2 = x1 + bw, y1 + bh
139
+ else:
140
+ c = res["coords"]
141
+ if len(c) < 4:
142
+ continue
143
+ x1 = c[0] * w_img / 1000
144
+ y1 = c[1] * h_img / 1000
145
+ x2 = c[2] * w_img / 1000
146
+ y2 = c[3] * h_img / 1000
147
+ x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w_img, x2), min(h_img, y2)
148
+ x1, x2 = min(x1, x2), max(x1, x2)
149
+ y1, y2 = min(y1, y2), max(y1, y2)
150
+ parsed.append(("box", label, color, x1, y1, x2, y2))
151
+
152
+ for item in parsed:
153
+ if item[0] == "box":
154
+ _, _, color, x1, y1, x2, y2 = item
155
+ fill_color = color + (65,)
156
+ draw.rectangle([x1, y1, x2, y2], fill=fill_color, outline=color, width=4)
157
+ elif item[0] == "point":
158
+ _, _, color, cx, cy = item
159
+ r = 10
160
+ draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=color, outline="white", width=2)
161
+
162
+ if draw_label:
163
+ for item in parsed:
164
+ if item[0] == "box":
165
+ _, label, color, x1, y1, x2, y2 = item
166
+ if not label:
167
+ continue
168
+ t_box = draw.textbbox((0, 0), label, font=font)
169
+ th = t_box[3] - t_box[1]
170
+ tw = t_box[2] - t_box[0]
171
+ pad_x, pad_y = 7, 4
172
+ tag_h = th + pad_y * 2
173
+ tag_w = tw + pad_x * 2
174
+ tag_y = y1 - tag_h - 2
175
+ if tag_y < 0:
176
+ tag_y = y2 + 2
177
+ draw.rectangle([x1, tag_y, x1 + tag_w, tag_y + tag_h], fill=color)
178
+ draw.text((x1 + pad_x, tag_y + pad_y), label, fill="white", font=font)
179
+ elif item[0] == "point":
180
+ _, label, color, cx, cy = item
181
+ if not label:
182
+ continue
183
+ t_box = draw.textbbox((0, 0), label, font=font)
184
+ th, tw = t_box[3] - t_box[1], t_box[2] - t_box[0]
185
+ tx, ty = cx + 14, cy - th // 2
186
+ draw.rectangle([tx - 2, ty - 2, tx + tw + 6, ty + th + 4], fill=color)
187
+ draw.text((tx + 2, ty), label, fill="white", font=font)
188
+
189
+ combined = Image.alpha_composite(img_draw, overlay).convert("RGB")
190
+ return cv2.cvtColor(np.array(combined), cv2.COLOR_RGB2BGR)
191
+
192
+
193
+ # ============================================================
194
+ # 模型
195
+ # ============================================================
196
+ class EagleWorker:
197
+ def __init__(self, model_path, device="cuda", generation_mode: str = "hybrid"):
198
+ self.model_id = model_path
199
+ self.device = device
200
+ self.dtype = torch.bfloat16
201
+ self.generation_mode = generation_mode
202
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
203
+ self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
204
+ self.model = AutoModel.from_pretrained(
205
+ model_path, torch_dtype=self.dtype,
206
+ _attn_implementation="sdpa", trust_remote_code=True,
207
+ ).to(device).eval()
208
+ print("Model Loaded Successfully!")
209
+
210
+ def build_messages(self, image, categories, question_override=None):
211
+ if question_override is not None:
212
+ user_text = question_override
213
+ else:
214
+ category_set_str = "</c>".join(categories)
215
+ user_text = f"Locate all the instances that matches the following description: {category_set_str}."
216
+ return [{"role": "user", "content": [
217
+ {"type": "image", "image": image},
218
+ {"type": "text", "text": user_text},
219
+ ]}]
220
+
221
+ @torch.no_grad()
222
+ def generate(self, image, categories, generation_mode=None,
223
+ max_new_tokens=4096, temp=0.7, top_p=0.9, top_k=50,
224
+ question_override=None):
225
+ messages = self.build_messages(image, categories, question_override=question_override)
226
+ text = self.processor.py_apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
227
+ images, videos = self.processor.process_vision_info(messages)
228
+ inputs = self.processor(text=[text], images=images, videos=videos, return_tensors="pt").to(self.device)
229
+
230
+ pixel_values = inputs["pixel_values"].to(self.dtype)
231
+ input_ids = inputs["input_ids"]
232
+ attention_mask = inputs["attention_mask"]
233
+ image_grid_hws = inputs.get("image_grid_hws", None)
234
+
235
+ result = self.model.generate(
236
+ pixel_values=pixel_values, input_ids=input_ids,
237
+ attention_mask=attention_mask, image_grid_hws=image_grid_hws,
238
+ tokenizer=self.tokenizer, max_new_tokens=max_new_tokens,
239
+ use_cache=True,
240
+ generation_mode=generation_mode if generation_mode is not None else self.generation_mode,
241
+ temperature=temp, do_sample=True, top_p=top_p,
242
+ repetition_penalty=1.1, verbose=True,
243
+ )
244
+
245
+ token_sequence, out_info, output_text = [], "", ""
246
+ if isinstance(result, tuple) and len(result) >= 3:
247
+ output_text, token_sequence, out_info = result
248
+ if generation_mode == "slow":
249
+ token_sequence[-1] = ("ar", token_sequence[-1][1])
250
+ else:
251
+ output_text = result
252
+ return output_text, token_sequence, out_info
253
+
254
+
255
+ # ============================================================
256
+ # 后处理 / HTML
257
+ # ============================================================
258
+ def _postprocess_detections(detections, w, h):
259
+ valid = []
260
+ for det in detections:
261
+ if det["type"] == "box":
262
+ c = det["coords"]
263
+ rx1 = max(0, min(w - 1, int(c[0] * w / 1000)))
264
+ ry1 = max(0, min(h - 1, int(c[1] * h / 1000)))
265
+ rx2 = max(0, min(w - 1, int(c[2] * w / 1000)))
266
+ ry2 = max(0, min(h - 1, int(c[3] * h / 1000)))
267
+ box_w, box_h = rx2 - rx1, ry2 - ry1
268
+ if box_w <= 0 or box_h <= 0:
269
+ continue
270
+ valid.append({"type": "box", "coords": [rx1, ry1, box_w, box_h],
271
+ "is_pixel": True, "label": det["label"]})
272
+ elif det["type"] == "point":
273
+ valid.append(det)
274
+ return valid
275
+
276
+
277
+ def _parse_out_info_dict(out_info: str) -> dict:
278
+ stats = {}
279
+ if not out_info:
280
+ return stats
281
+ cleaned = re.sub(r"^[Ss]tast?ic\s*[Ii]nfo\s*,?\s*", "", out_info.strip())
282
+ for part in cleaned.split(";"):
283
+ part = part.strip()
284
+ if "=" in part:
285
+ k, v = part.split("=", 1)
286
+ stats[k.strip()] = v.strip()
287
+ return stats
288
+
289
+
290
+ def generate_dynamic_html(token_sequence, out_info, raw_text):
291
+ uid = f"a{int(time.time() * 1000)}"
292
+ css = f"""
293
+ <style>
294
+ .dc-root {{
295
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
296
+ border: 1px solid #cce875; border-radius: 10px; background: #ffffff; overflow: hidden;
297
+ }}
298
+ .dc-header {{
299
+ display: flex; align-items: center; justify-content: space-between;
300
+ padding: 12px 18px;
301
+ background: linear-gradient(135deg, #76b900 0%, #649d00 100%);
302
+ border-bottom: 1px solid #527f00;
303
+ }}
304
+ .dc-header-title {{ font-weight: 700; font-size: 0.95em; color: #ffffff !important; letter-spacing: 0.3px; }}
305
+ .dc-legend {{ display: flex; gap: 16px; align-items: center; }}
306
+ .dc-legend-item {{ display: flex; align-items: center; gap: 5px; font-size: 0.78em; color: rgba(255,255,255,0.92); font-weight: 500; }}
307
+ .dc-legend-dot {{ width: 10px; height: 10px; border-radius: 3px; display: inline-block; border: 1px solid rgba(255,255,255,0.5); }}
308
+ .dc-row {{ display: flex; gap: 10px; padding: 14px 18px; border-bottom: 1px solid #eef7d1; }}
309
+ .dc-row:last-child {{ border-bottom: none; }}
310
+ .dc-val {{ flex: 1; line-height: 2.3; word-wrap: break-word; color: #4b5563; font-size: 0.92em; }}
311
+ @keyframes tk-{uid} {{
312
+ 0% {{ opacity: 0; transform: translateY(8px) scale(0.92); }}
313
+ 60% {{ opacity: 1; transform: translateY(-2px) scale(1.02); }}
314
+ 100% {{ opacity: 1; transform: translateY(0) scale(1); }}
315
+ }}
316
+ .tk-mtp-{uid}, .tk-ar-{uid} {{
317
+ opacity: 0; animation: tk-{uid} 0.35s ease-out forwards;
318
+ border-radius: 5px; padding: 2px 7px; margin: 2px 1px; display: inline-block;
319
+ font-size: 0.80em; font-weight: 600;
320
+ font-family: 'SFMono-Regular', Consolas, 'Courier New', monospace; white-space: nowrap;
321
+ }}
322
+ .tk-mtp-{uid} {{ background: #e8f5e9; border: 2px solid #76b900; color: #2d4400; box-shadow: 0 1px 2px rgba(118,185,0,0.15); }}
323
+ .tk-ar-{uid} {{ background: #fff3e0; border: 2px solid #e65100; color: #bf360c; box-shadow: 0 1px 2px rgba(230,81,0,0.15); }}
324
+ .tk-stat-{uid} {{
325
+ opacity: 0; animation: tk-{uid} 0.4s ease-out forwards;
326
+ background: #f0f9e2; border: 1px solid #a4d422; border-radius: 6px;
327
+ padding: 5px 14px; display: inline-block; font-size: 0.82em; color: #3f6200; font-weight: 600;
328
+ }}
329
+ .dc-raw {{ padding: 0 18px 14px; }}
330
+ .dc-raw summary {{ cursor: pointer; color: #9ca3af; font-size: 0.82em; user-select: none; transition: color .15s; }}
331
+ .dc-raw summary:hover {{ color: #649d00; }}
332
+ .dc-raw-pre {{
333
+ background: #f7fbe8; border: 1px solid #ddf0a3; border-radius: 6px;
334
+ padding: 12px; margin-top: 8px;
335
+ font-family: 'SFMono-Regular', Consolas, 'Courier New', monospace;
336
+ font-size: 0.78em; color: #374151; white-space: pre-wrap; word-break: break-all;
337
+ max-height: 200px; overflow-y: auto;
338
+ }}
339
+ @media (max-width: 640px) {{
340
+ .dc-header {{ flex-direction: column; gap: 8px; align-items: flex-start; }}
341
+ .dc-row {{ flex-direction: column; gap: 4px; }}
342
+ }}
343
+ </style>
344
+ """
345
+ h = css + '<div class="dc-root">'
346
+ h += ('<div class="dc-header">'
347
+ '<span class="dc-header-title">LocateAnything Decoding Trace</span>'
348
+ '<div class="dc-legend">'
349
+ '<div class="dc-legend-item"><span class="dc-legend-dot" style="background:#76b900;"></span>MTP &mdash; Parallel Box Decoding</div>'
350
+ '<div class="dc-legend-item"><span class="dc-legend-dot" style="background:#e65100;"></span>AR &mdash; NTP Fallback (Re-decoding)</div>'
351
+ '</div></div>')
352
+ h += '<div class="dc-row"><div class="dc-val">'
353
+ tok_idx = 0
354
+ if token_sequence:
355
+ for item in token_sequence:
356
+ if not isinstance(item, (list, tuple)) or len(item) < 2:
357
+ continue
358
+ decode_type = str(item[0]).lower()
359
+ text = str(item[1])
360
+ safe = text.replace("<", "&lt;").replace(">", "&gt;")
361
+ delay = f"{tok_idx * 0.06:.2f}s"
362
+ cls = f"tk-ar-{uid}" if decode_type == "ar" else f"tk-mtp-{uid}"
363
+ h += f'<span class="{cls}" style="animation-delay:{delay}">{safe}</span> '
364
+ tok_idx += 1
365
+ h += '</div></div>'
366
+ if out_info:
367
+ stats = _parse_out_info_dict(out_info)
368
+ bits = []
369
+ if "forward_step" in stats: bits.append(f"{stats['forward_step']} steps")
370
+ if "num_tokens" in stats: bits.append(f"{stats['num_tokens']} tokens")
371
+ if "num_boxes" in stats: bits.append(f"{stats['num_boxes']} boxes")
372
+ if "switch_to_ar" in stats:
373
+ n = stats["switch_to_ar"]
374
+ bits.append(f"{n} AR Fallback{'s' if n != '1' else ''}")
375
+ if "ar_step" in stats: bits.append(f"{stats['ar_step']} AR steps")
376
+ if "tps" in stats: bits.append(f"{stats['tps']} tok/s")
377
+ if "bps" in stats: bits.append(f"{stats['bps']} box/s")
378
+ summary = " &middot; ".join(bits) if bits else out_info.strip()
379
+ stat_delay = f"{tok_idx * 0.06 + 0.3:.2f}s"
380
+ h += (f'<div class="dc-row" style="justify-content:flex-end;padding-top:4px;padding-bottom:10px;border-bottom:none;">'
381
+ f'<span class="tk-stat-{uid}" style="animation-delay:{stat_delay}">⚡ {summary}</span></div>')
382
+ if raw_text:
383
+ safe_raw = raw_text.replace("<", "&lt;").replace(">", "&gt;")
384
+ h += (f'<div class="dc-raw"><details><summary>📄 Show Raw Response</summary>'
385
+ f'<div class="dc-raw-pre">{safe_raw}</div></details></div>')
386
+ h += '</div>'
387
+ return h
388
+
389
+
390
+ def generate_raw_prompt(task_type, category):
391
+ if not category:
392
+ category = "objects"
393
+ cats = "</c>".join(c.strip() for c in category.split(",") if c.strip())
394
+ if task_type == "Detection":
395
+ return f"Locate all the instances that matches the following description: {cats}."
396
+ elif task_type == "Grounding":
397
+ return f"Locate all the instances that match the following description: {cats}."
398
+ elif task_type == "OCR":
399
+ return "Detect all the text in box format."
400
+ elif task_type == "GUI":
401
+ return f"Locate the region that matches the following description: {cats}."
402
+ elif task_type == "Pointing":
403
+ return f"Point to: {cats}."
404
+ else:
405
+ return f"Locate all the instances that matches the following description: {cats}."
406
+
407
+
408
+ # ============================================================
409
+ # 模型初始化
410
+ # ============================================================
411
+ try:
412
+ MODEL_PATH = os.environ.get("MODEL_PATH", "woshichaoren123/test001")
413
+ GLOBAL_WORKER = EagleWorker(MODEL_PATH)
414
+ except Exception as e:
415
+ print(f"Failed to load model: {e}. Will run in Mock Mode.")
416
+ GLOBAL_WORKER = None
417
+
418
+
419
+ # ============================================================
420
+ # 用户数据收集(HuggingFace Public Dataset)
421
+ # ============================================================
422
+ LOG_DATASET_REPO = os.environ.get("LOG_DATASET_REPO", "woshichaoren123/log")
423
+ LOG_HF_TOKEN = os.environ.get("LOG_HF_TOKEN")
424
+ _LOG_DIR = Path(tempfile.mkdtemp(prefix="hf_log_"))
425
+ _log_scheduler = None
426
+
427
+ if LOG_DATASET_REPO and LOG_HF_TOKEN:
428
+ _log_scheduler = CommitScheduler(
429
+ repo_id=LOG_DATASET_REPO,
430
+ repo_type="dataset",
431
+ folder_path=str(_LOG_DIR),
432
+ path_in_repo="data",
433
+ every=5,
434
+ token=LOG_HF_TOKEN,
435
+ )
436
+ print(f"[LOG] Dataset logging enabled → {LOG_DATASET_REPO}")
437
+ else:
438
+ print("[LOG] Dataset logging disabled (LOG_HF_TOKEN not set)")
439
+
440
+
441
+ def _pil_to_b64(pil_img):
442
+ """将 PIL 图片无损转为 PNG base64 字符串。"""
443
+ buf = io.BytesIO()
444
+ pil_img.save(buf, "PNG")
445
+ return base64.b64encode(buf.getvalue()).decode("ascii")
446
+
447
+
448
+ def _log_to_dataset(
449
+ input_type, category, model_mode, raw_prompt,
450
+ output_text="", input_image=None, output_image=None,
451
+ extra=None,
452
+ ):
453
+ """将用户 query、输入图片(base64)、推理结果写入按天分片的 JSONL。"""
454
+ if _log_scheduler is None:
455
+ return
456
+ try:
457
+ entry_id = f"{int(time.time())}_{uuid.uuid4().hex[:6]}"
458
+ ts = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
459
+ date_str = time.strftime("%Y-%m-%d", time.gmtime())
460
+
461
+ input_b64 = None
462
+ if input_image is not None and isinstance(input_image, Image.Image):
463
+ input_b64 = _pil_to_b64(input_image)
464
+
465
+ output_b64 = None
466
+ if output_image is not None and isinstance(output_image, Image.Image):
467
+ output_b64 = _pil_to_b64(output_image)
468
+
469
+ record = {
470
+ "id": entry_id,
471
+ "timestamp": ts,
472
+ "input_type": input_type,
473
+ "category": category,
474
+ "model_mode": model_mode,
475
+ "raw_prompt": raw_prompt,
476
+ "output_text": output_text,
477
+ "input_image_b64": input_b64,
478
+ "output_image_b64": output_b64,
479
+ }
480
+ if extra:
481
+ record.update(extra)
482
+
483
+ log_file = _LOG_DIR / f"logs_{date_str}.jsonl"
484
+ with _log_scheduler.lock:
485
+ with open(log_file, "a", encoding="utf-8") as f:
486
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
487
+ except Exception as e:
488
+ print(f"[LOG] Failed to log to dataset: {e}")
489
+
490
+
491
+ # ============================================================
492
+ # 公用预处理
493
+ # ============================================================
494
+ def _prepare_image_for_model(pil_img, short_size):
495
+ process_img = pil_img.copy()
496
+ if short_size is not None and short_size > 0:
497
+ process_img, _ = resize_image_short_side(process_img, min(int(short_size), 1024))
498
+ else:
499
+ if min(process_img.size) > 1024:
500
+ process_img, _ = resize_image_short_side(process_img, 1024)
501
+ return process_img
502
+
503
+
504
+ # ============================================================
505
+ # GPU 时间预算常量(按模式区分)
506
+ # ============================================================
507
+ GPU_HARD_LIMIT_IMAGE = 30 # Image 模式 @spaces.GPU(duration=...)
508
+ GPU_HARD_LIMIT_VIDEO = 240 # Video 模式 @spaces.GPU(duration=...)
509
+ PHASE2_RESERVE = 55 # 留给 Phase 2(绘制 + ffmpeg)的秒数
510
+ SAFETY_MARGIN = 25 # 额外安全裕量,永远不要触碰硬限制
511
+ INFERENCE_BUDGET = GPU_HARD_LIMIT_VIDEO - PHASE2_RESERVE - SAFETY_MARGIN
512
+ EST_SECONDS_PER_FRAME = 20 # 保守估计:每帧推理耗时
513
+
514
+
515
+ # ============================================================
516
+ # ✅ 图像推理(独立函数)
517
+ # ============================================================
518
+ def _run_image_inference(
519
+ image_in, categories_list, category_str,
520
+ model_mode, temp, top_p, top_k, short_size, question_override,
521
+ progress=None, # 接收 progress
522
+ ):
523
+ if image_in is None:
524
+ return (
525
+ gr.update(value=None, visible=True),
526
+ gr.update(value=None, visible=False),
527
+ "<p style='color:#ef4444;padding:12px;'>⚠️ Please upload an image first.</p>",
528
+ )
529
+
530
+ if progress is not None: # 进度提示
531
+ progress(0.1, desc="Preprocessing image ...")
532
+
533
+ process_img = _prepare_image_for_model(image_in, short_size)
534
+
535
+ if progress is not None:
536
+ progress(0.2, desc="Running model inference ...")
537
+
538
+ if GLOBAL_WORKER:
539
+ output_text, token_sequence, out_info = GLOBAL_WORKER.generate(
540
+ process_img, categories_list, model_mode,
541
+ temp=temp, top_p=top_p, top_k=top_k,
542
+ question_override=question_override,
543
+ )
544
+ else:
545
+ output_text, token_sequence, out_info = "", [], ""
546
+
547
+ if progress is not None:
548
+ progress(0.8, desc="Drawing results ...")
549
+
550
+ detections = parse_mixed_results(output_text, category_str)
551
+ frame_bgr = cv2.cvtColor(np.array(image_in), cv2.COLOR_RGB2BGR)
552
+ out_img_bgr = draw_on_frame(frame_bgr, detections, draw_label=True)
553
+ output_image = Image.fromarray(cv2.cvtColor(out_img_bgr, cv2.COLOR_BGR2RGB))
554
+ html = generate_dynamic_html(token_sequence, out_info, output_text)
555
+
556
+ _log_to_dataset(
557
+ input_type="image",
558
+ category=", ".join(categories_list),
559
+ model_mode=model_mode,
560
+ raw_prompt=question_override or category_str,
561
+ output_text=output_text,
562
+ input_image=image_in,
563
+ output_image=output_image,
564
+ )
565
+
566
+ if progress is not None:
567
+ progress(1.0, desc="Done!")
568
+
569
+ return (
570
+ gr.update(value=output_image, visible=True),
571
+ gr.update(value=None, visible=False),
572
+ html,
573
+ )
574
+
575
+
576
+ # ============================================================
577
+ # ✅ 视频推理(独立函数 — 带完整超时保护)
578
+ # ============================================================
579
+ def _run_video_inference(
580
+ video_in, categories_list, category_str,
581
+ model_mode, temp, top_p, top_k, short_size, question_override,
582
+ max_video_frames, # 可调帧数
583
+ progress=None, # 接收 progress
584
+ ):
585
+ import subprocess as _sp
586
+
587
+ if video_in is None:
588
+ return (
589
+ gr.update(value=None, visible=False),
590
+ gr.update(value=None, visible=True),
591
+ "<p style='color:#ef4444;padding:12px;'>⚠️ Please upload a video first.</p>",
592
+ )
593
+
594
+ total_start = time.time()
595
+ max_frames = int(max_video_frames) if max_video_frames else 4
596
+
597
+ if progress is not None:
598
+ progress(0.0, desc="Reading video ...")
599
+
600
+ # ---------- 读取视频 ----------
601
+ t0 = time.time()
602
+ cap = cv2.VideoCapture(video_in)
603
+ fps = cap.get(cv2.CAP_PROP_FPS)
604
+ vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
605
+ vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
606
+
607
+ all_frames = []
608
+ while cap.isOpened():
609
+ ret, frame = cap.read()
610
+ if not ret:
611
+ break
612
+ all_frames.append(frame)
613
+ cap.release()
614
+ total = len(all_frames)
615
+ read_elapsed = time.time() - t0
616
+ print(f"[TIMING] Video read: {read_elapsed:.2f}s, total frames={total}, "
617
+ f"resolution={vid_w}x{vid_h}, fps={fps:.1f}")
618
+
619
+ if total == 0:
620
+ return (
621
+ gr.update(value=None, visible=False),
622
+ gr.update(value=None, visible=True),
623
+ "<p style='color:#ef4444;padding:12px;'>⚠️ Failed to read any frames from the video.</p>",
624
+ )
625
+
626
+ # ---------- 采样帧 ----------
627
+ if total <= max_frames:
628
+ sample_indices = list(range(total))
629
+ else:
630
+ sample_indices = [int(round(i * (total - 1) / (max_frames - 1)))
631
+ for i in range(max_frames)]
632
+
633
+ sampled_frames = [all_frames[i] for i in sample_indices]
634
+ n_sampled = len(sampled_frames)
635
+
636
+ # ============================================================
637
+ # 🛡️ 预估检查:在开跑前判断能不能在 GPU 时间预算内跑完
638
+ # ============================================================
639
+ time_already_used = time.time() - total_start
640
+ available_for_inference = GPU_HARD_LIMIT_VIDEO - time_already_used - PHASE2_RESERVE - SAFETY_MARGIN
641
+ estimated_inference_time = n_sampled * EST_SECONDS_PER_FRAME
642
+
643
+ if estimated_inference_time > available_for_inference:
644
+ # 尝试自动缩减帧数
645
+ max_feasible = max(0, int(available_for_inference // EST_SECONDS_PER_FRAME))
646
+ print(f"[PRE-CHECK] Estimated {estimated_inference_time:.0f}s > budget {available_for_inference:.0f}s, "
647
+ f"reducing from {n_sampled} to {max_feasible} frames")
648
+
649
+ if max_feasible < 1:
650
+ # 连 1 帧都跑不了,直接拒绝
651
+ del all_frames
652
+ gc.collect()
653
+ return (
654
+ gr.update(value=None, visible=False),
655
+ gr.update(value=None, visible=True),
656
+ "<div style='background:#fef2f2;border:1px solid #fca5a5;border-radius:8px;"
657
+ "padding:16px;margin:8px 0;'>"
658
+ "<p style='color:#dc2626;font-weight:700;font-size:1.05em;margin:0 0 8px;'>"
659
+ "⚠️ Video too large to process</p>"
660
+ f"<p style='color:#7f1d1d;margin:0;font-size:0.92em;'>"
661
+ f"This video has <b>{total}</b> frames. "
662
+ f"Even processing <b>1</b> sampled frame (~{EST_SECONDS_PER_FRAME}s) "
663
+ f"would exceed the <b>{GPU_HARD_LIMIT_VIDEO}s</b> GPU time limit.<br><br>"
664
+ "💡 <b>Suggestions:</b> use a shorter / lower-resolution video, "
665
+ "or switch to <b>Image</b> mode with a single frame screenshot.</p></div>",
666
+ )
667
+
668
+ # 用缩减后的帧数重新采样
669
+ if total <= max_feasible:
670
+ sample_indices = list(range(total))
671
+ else:
672
+ sample_indices = [int(round(i * (total - 1) / (max_feasible - 1)))
673
+ for i in range(max_feasible)]
674
+ sampled_frames = [all_frames[i] for i in sample_indices]
675
+ n_sampled = len(sampled_frames)
676
+
677
+ # 释放原始帧列表,节省内存
678
+ out_fps = max(1.0, n_sampled / (total / fps)) if fps > 0 else 5.0
679
+ del all_frames
680
+ gc.collect()
681
+
682
+ print(f"[TIMING] Sampled {n_sampled} frames, output fps: {out_fps:.2f}")
683
+
684
+ # ============================================================
685
+ # 阶段一:推理(逐帧检查剩余时间)
686
+ # ============================================================
687
+ print("=" * 60)
688
+ print("[PHASE 1] Starting model inference ...")
689
+ print("=" * 60)
690
+
691
+ inference_results = []
692
+ phase1_start = time.time()
693
+ processed_count = 0
694
+ early_stopped = False
695
+ early_stop_reason = ""
696
+
697
+ for i, frame in enumerate(sampled_frames):
698
+ # ---- 🛡️ 运行时时间检查:还够不够跑下一帧 + Phase 2?----
699
+ elapsed_since_start = time.time() - total_start
700
+ remaining_total = GPU_HARD_LIMIT_VIDEO - elapsed_since_start
701
+
702
+ if remaining_total < PHASE2_RESERVE + SAFETY_MARGIN:
703
+ early_stopped = True
704
+ early_stop_reason = (
705
+ f"GPU time budget is running out: "
706
+ f"{elapsed_since_start:.0f}s used, only {remaining_total:.0f}s left "
707
+ f"(need ≥{PHASE2_RESERVE}s for video encoding). "
708
+ f"Successfully processed {processed_count}/{n_sampled} frames."
709
+ )
710
+ print(f"[⏰ EARLY STOP] {early_stop_reason}")
711
+ break
712
+
713
+ if progress is not None:
714
+ progress(
715
+ (i / n_sampled) * 0.85,
716
+ desc=f"🧠 Inference: frame {i + 1}/{n_sampled} "
717
+ f"(⏱️ {elapsed_since_start:.0f}s / {GPU_HARD_LIMIT_VIDEO}s) ...",
718
+ )
719
+
720
+ frame_t0 = time.time()
721
+
722
+ # 预处理
723
+ prep_t0 = time.time()
724
+ pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
725
+ process_img = _prepare_image_for_model(pil_img, short_size)
726
+ prep_time = time.time() - prep_t0
727
+
728
+ # 推理
729
+ infer_t0 = time.time()
730
+ if GLOBAL_WORKER:
731
+ output_text, _, _ = GLOBAL_WORKER.generate(
732
+ process_img, categories_list, model_mode,
733
+ temp=temp, top_p=top_p, top_k=top_k,
734
+ question_override=question_override,
735
+ )
736
+ else:
737
+ output_text = ""
738
+ infer_time = time.time() - infer_t0
739
+
740
+ inference_results.append(output_text)
741
+ processed_count += 1
742
+
743
+ # 清理 GPU 缓存
744
+ cleanup_t0 = time.time()
745
+ if torch.cuda.is_available():
746
+ torch.cuda.empty_cache()
747
+ gc.collect()
748
+ cleanup_time = time.time() - cleanup_t0
749
+
750
+ total_frame_time = time.time() - frame_t0
751
+ print(f"[PHASE 1] Frame {i + 1}/{n_sampled} done: "
752
+ f"prep={prep_time:.2f}s, infer={infer_time:.2f}s, "
753
+ f"cleanup={cleanup_time:.2f}s, total={total_frame_time:.2f}s")
754
+ if torch.cuda.is_available():
755
+ allocated = torch.cuda.memory_allocated() / 1024**3
756
+ reserved = torch.cuda.memory_reserved() / 1024**3
757
+ print(f" GPU mem: allocated={allocated:.2f}GB, reserved={reserved:.2f}GB")
758
+
759
+ phase1_time = time.time() - phase1_start
760
+ print(f"[PHASE 1] COMPLETE: {phase1_time:.2f}s for {processed_count} frames "
761
+ f"({phase1_time / max(processed_count, 1):.2f}s/frame)")
762
+
763
+ # 如果 1 帧都没处理完,返回错误
764
+ if processed_count == 0:
765
+ return (
766
+ gr.update(value=None, visible=False),
767
+ gr.update(value=None, visible=True),
768
+ "<div style='background:#fef2f2;border:1px solid #fca5a5;border-radius:8px;"
769
+ "padding:16px;margin:8px 0;'>"
770
+ "<p style='color:#dc2626;font-weight:700;font-size:1.05em;margin:0 0 8px;'>"
771
+ "⚠️ Could not process any frames</p>"
772
+ "<p style='color:#7f1d1d;margin:0;font-size:0.92em;'>"
773
+ "The GPU time limit was reached before even one frame could be processed. "
774
+ "Please try a lower resolution video or use Image mode instead.</p></div>",
775
+ )
776
+
777
+ # 裁剪到实际处理过的帧
778
+ sampled_frames_for_draw = sampled_frames[:processed_count]
779
+ inference_results_for_draw = inference_results[:processed_count]
780
+
781
+ # ============================================================
782
+ # 阶段二:绘制 + 编码(只处理已推理完的帧)
783
+ # ============================================================
784
+ if progress is not None:
785
+ progress(0.88, desc="🎨 Drawing & encoding video ...")
786
+
787
+ print("=" * 60)
788
+ print(f"[PHASE 2] Drawing & video encoding ({processed_count} frames) ...")
789
+ print("=" * 60)
790
+
791
+ phase2_start = time.time()
792
+ tmp_raw = tempfile.mktemp(suffix=".raw.mp4")
793
+ out_video_path = tempfile.mktemp(suffix=".mp4")
794
+ out = cv2.VideoWriter(tmp_raw, cv2.VideoWriter_fourcc(*"mp4v"),
795
+ out_fps, (vid_w, vid_h))
796
+
797
+ for i, (frame, output_text) in enumerate(
798
+ zip(sampled_frames_for_draw, inference_results_for_draw)):
799
+ draw_t0 = time.time()
800
+ detections = parse_mixed_results(output_text, category_str)
801
+ valid_results = _postprocess_detections(detections, vid_w, vid_h)
802
+ frame_to_draw = draw_on_frame(frame, valid_results, draw_label=True)
803
+ out.write(frame_to_draw)
804
+ draw_time = time.time() - draw_t0
805
+ print(f"[PHASE 2] Frame {i + 1}/{processed_count}: "
806
+ f"draw={draw_time:.3f}s, det={len(valid_results)}")
807
+
808
+ out.release()
809
+ phase2_draw_time = time.time() - phase2_start
810
+
811
+ # ---- ffmpeg 重编码(如果还有时间的话) ----
812
+ elapsed_now = time.time() - total_start
813
+ remaining_now = GPU_HARD_LIMIT_VIDEO - elapsed_now
814
+
815
+ if progress is not None:
816
+ progress(0.95, desc="📦 Re-encoding with ffmpeg ...")
817
+
818
+ ffmpeg_t0 = time.time()
819
+ if remaining_now > 15:
820
+ # 还有时间,用 ffmpeg 重编码(兼容性更好)
821
+ try:
822
+ ffmpeg_timeout = max(10, int(remaining_now - 5))
823
+ _sp.run(
824
+ ["ffmpeg", "-y", "-i", tmp_raw, "-c:v", "libx264",
825
+ "-preset", "ultrafast", "-crf", "23", "-pix_fmt", "yuv420p",
826
+ "-movflags", "+faststart", out_video_path],
827
+ check=True, capture_output=True, timeout=ffmpeg_timeout,
828
+ )
829
+ os.remove(tmp_raw)
830
+ except Exception as ffmpeg_err:
831
+ print(f"[PHASE 2] ffmpeg failed or timed out: {ffmpeg_err}, using raw file")
832
+ if os.path.exists(tmp_raw):
833
+ os.replace(tmp_raw, out_video_path)
834
+ else:
835
+ # 时间不够了,直接用 mp4v 原始文件
836
+ os.replace(tmp_raw, out_video_path)
837
+ print("[PHASE 2] Skipped ffmpeg re-encoding due to time constraint")
838
+
839
+ ffmpeg_time = time.time() - ffmpeg_t0
840
+ total_time = time.time() - total_start
841
+
842
+ print("=" * 60)
843
+ print(f"[TOTAL] {total_time:.2f}s | inference={phase1_time:.2f}s "
844
+ f"draw={phase2_draw_time:.2f}s ffmpeg={ffmpeg_time:.2f}s "
845
+ f"frames_done={processed_count}/{n_sampled}")
846
+ print("=" * 60)
847
+
848
+ # ---- 构建结果 HTML ----
849
+ warning_html = ""
850
+ if early_stopped:
851
+ warning_html = (
852
+ "<div style='background:#fefce8;border:1px solid #fde047;border-radius:8px;"
853
+ "padding:14px;margin-bottom:12px;'>"
854
+ "<p style='color:#a16207;font-weight:700;font-size:1.02em;margin:0 0 6px;'>"
855
+ "⚡ Partial Result — Early Stop Due to GPU Time Limit</p>"
856
+ f"<p style='color:#854d0e;margin:0;font-size:0.9em;'>{early_stop_reason}</p>"
857
+ "<p style='color:#854d0e;margin:6px 0 0;font-size:0.88em;'>"
858
+ "💡 <b>Tip:</b> Reduce <b>Max Video Frames</b> slider or use a shorter video "
859
+ "to process all frames within the GPU budget.</p>"
860
+ "</div>"
861
+ )
862
+
863
+ timing_summary = (
864
+ f"Video: {total} total frames, sampled {n_sampled}, "
865
+ f"processed {processed_count} | "
866
+ f"Inference: {phase1_time:.1f}s ({phase1_time / max(processed_count, 1):.1f}s/frame) | "
867
+ f"Drawing: {phase2_draw_time:.1f}s | ffmpeg: {ffmpeg_time:.1f}s | "
868
+ f"Total: {total_time:.1f}s / {GPU_HARD_LIMIT_VIDEO}s budget"
869
+ )
870
+ html = warning_html + generate_dynamic_html(
871
+ token_sequence=[], out_info="", raw_text=timing_summary)
872
+
873
+ try:
874
+ thumb = Image.fromarray(
875
+ cv2.cvtColor(sampled_frames_for_draw[0], cv2.COLOR_BGR2RGB))
876
+ except Exception:
877
+ thumb = None
878
+ _log_to_dataset(
879
+ input_type="video",
880
+ category=", ".join(categories_list),
881
+ model_mode=model_mode,
882
+ raw_prompt=question_override or category_str,
883
+ output_text="\n---\n".join(inference_results_for_draw),
884
+ input_image=thumb,
885
+ extra={
886
+ "video_total_frames": total,
887
+ "video_sampled_frames": n_sampled,
888
+ "video_processed_frames": processed_count,
889
+ },
890
+ )
891
+
892
+ if progress is not None:
893
+ progress(1.0, desc="Done!")
894
+
895
+ return (
896
+ gr.update(value=None, visible=False),
897
+ gr.update(value=out_video_path, visible=True),
898
+ html,
899
+ )
900
+
901
+
902
+ # ============================================================
903
+ # 🛡️ 主入口:按模式分配不同 GPU 时长
904
+ # ============================================================
905
+
906
+ def _build_error_html(e, gpu_limit, input_type):
907
+ """统一的异常→友好 HTML 构建。"""
908
+ import traceback
909
+ traceback.print_exc()
910
+
911
+ error_type = type(e).__name__
912
+ error_msg = str(e)
913
+
914
+ is_timeout = ("timeout" in error_msg.lower()
915
+ or "timelimit" in error_msg.lower()
916
+ or "time limit" in error_msg.lower()
917
+ or "duration" in error_msg.lower())
918
+
919
+ if is_timeout:
920
+ detail = (
921
+ f"The GPU time limit ({gpu_limit}s) was exceeded before the result "
922
+ "could be fully assembled. This typically happens with large videos."
923
+ )
924
+ suggestion = (
925
+ "Please reduce <b>Max Video Frames</b>, use a shorter / smaller video, "
926
+ "or switch to <b>Image</b> mode."
927
+ )
928
+ else:
929
+ detail = f"{error_type}: {error_msg}"
930
+ suggestion = (
931
+ "If the problem persists, try reducing video size or "
932
+ "switching to Image mode."
933
+ )
934
+
935
+ error_html = (
936
+ "<div style='background:#fef2f2;border:1px solid #fca5a5;border-radius:8px;"
937
+ "padding:16px;margin:8px 0;'>"
938
+ "<p style='color:#dc2626;font-weight:700;font-size:1.05em;margin:0 0 8px;'>"
939
+ "⚠️ Processing interrupted</p>"
940
+ f"<p style='color:#7f1d1d;margin:0 0 8px;font-size:0.92em;'>{detail}</p>"
941
+ f"<p style='color:#7f1d1d;margin:0;font-size:0.88em;'>💡 {suggestion}</p>"
942
+ "</div>"
943
+ )
944
+
945
+ return (
946
+ gr.update(value=None, visible=(input_type == "Image")),
947
+ gr.update(value=None, visible=(input_type == "Video")),
948
+ error_html,
949
+ )
950
+
951
+
952
+ @spaces.GPU(duration=GPU_HARD_LIMIT_IMAGE)
953
+ def _run_image_gpu(
954
+ image_in, category, model_mode, temp, top_p, top_k,
955
+ short_size, question_override, progress,
956
+ ):
957
+ try:
958
+ categories_list = [c.strip() for c in category.split(",") if c.strip()]
959
+ category_str = "</c>".join(categories_list)
960
+ return _run_image_inference(
961
+ image_in, categories_list, category_str,
962
+ model_mode, temp, top_p, top_k, short_size, question_override,
963
+ progress=progress,
964
+ )
965
+ except Exception as e:
966
+ return _build_error_html(e, GPU_HARD_LIMIT_IMAGE, "Image")
967
+
968
+
969
+ @spaces.GPU(duration=GPU_HARD_LIMIT_VIDEO)
970
+ def _run_video_gpu(
971
+ video_in, category, model_mode, temp, top_p, top_k,
972
+ short_size, question_override, max_video_frames, progress,
973
+ ):
974
+ try:
975
+ categories_list = [c.strip() for c in category.split(",") if c.strip()]
976
+ category_str = "</c>".join(categories_list)
977
+ return _run_video_inference(
978
+ video_in, categories_list, category_str,
979
+ model_mode, temp, top_p, top_k, short_size, question_override,
980
+ max_video_frames=max_video_frames,
981
+ progress=progress,
982
+ )
983
+ except Exception as e:
984
+ return _build_error_html(e, GPU_HARD_LIMIT_VIDEO, "Video")
985
+
986
+
987
+ def run_inference(
988
+ input_type, image_in, video_in, task_type, category,
989
+ model_mode, temp, top_p, top_k, short_size, question_override,
990
+ max_video_frames,
991
+ progress=gr.Progress(track_tqdm=False),
992
+ ):
993
+ if input_type == "Image":
994
+ return _run_image_gpu(
995
+ image_in, category, model_mode, temp, top_p, top_k,
996
+ short_size, question_override, progress,
997
+ )
998
+ else:
999
+ return _run_video_gpu(
1000
+ video_in, category, model_mode, temp, top_p, top_k,
1001
+ short_size, question_override, max_video_frames, progress,
1002
+ )
1003
+
1004
+
1005
+ # ============================================================
1006
+ # 按钮状态
1007
+ # ============================================================
1008
+ def _disable_run_btn():
1009
+ return gr.update(interactive=False, value="⏳ Running ...")
1010
+
1011
+
1012
+ def _enable_run_btn():
1013
+ return gr.update(interactive=True, value="🧠 Run Inference")
1014
+
1015
+
1016
+ # ============================================================
1017
+ # Examples
1018
+ # ============================================================
1019
+ EXAMPLE_CONFIGS = [
1020
+ {"name": "Book", "input_type": "Image", "image": "./assets/book.jpg", "video": None,
1021
+ "task": "Detection", "category": "book", "mode": "hybrid"},
1022
+ {"name": "Sweet", "input_type": "Image", "image": "./assets/sweet.jpg", "video": None,
1023
+ "task": "Detection", "category": "sweet", "mode": "hybrid"},
1024
+ {"name": "Person", "input_type": "Image", "image": "./assets/person.jpg", "video": None,
1025
+ "task": "Detection", "category": "person", "mode": "hybrid"},
1026
+ {"name": "OCR", "input_type": "Image", "image": "./assets/ocr.jpg", "video": None,
1027
+ "task": "OCR", "category": "text", "mode": "fast"},
1028
+ ]
1029
+
1030
+
1031
+ def prepare_gallery_data():
1032
+ base_dir = os.path.dirname(os.path.abspath(__file__))
1033
+ gallery_images, gallery_captions = [], []
1034
+ for config in EXAMPLE_CONFIGS:
1035
+ img_path = (os.path.normpath(os.path.join(base_dir, config["image"]))
1036
+ if config["image"] else None)
1037
+ if img_path and os.path.exists(img_path):
1038
+ gallery_images.append(img_path)
1039
+ else:
1040
+ gallery_images.append(Image.new("RGB", (200, 200), color="black"))
1041
+ gallery_captions.append(config["name"])
1042
+ return gallery_images, gallery_captions
1043
+
1044
+
1045
+ def update_example_selection(evt: gr.SelectData):
1046
+ config = EXAMPLE_CONFIGS[evt.index]
1047
+ base_dir = os.path.dirname(os.path.abspath(__file__))
1048
+ img_path = (os.path.normpath(os.path.join(base_dir, config["image"]))
1049
+ if config["image"] else None)
1050
+ vid_path = (os.path.normpath(os.path.join(base_dir, config["video"]))
1051
+ if config["video"] else None)
1052
+ return (
1053
+ config["input_type"],
1054
+ gr.update(value=img_path, visible=(config["input_type"] == "Image")),
1055
+ gr.update(value=vid_path, visible=(config["input_type"] == "Video")),
1056
+ config["task"], config["category"], config["mode"],
1057
+ )
1058
+
1059
+
1060
+ # ============================================================
1061
+ # UI
1062
+ # ============================================================
1063
+ def create_demo():
1064
+ nv_green = gr.themes.Color(
1065
+ c50="#f7fbe8", c100="#eef7d1", c200="#ddf0a3",
1066
+ c300="#cce875", c400="#a4d422", c500="#76b900",
1067
+ c600="#649d00", c700="#527f00", c800="#3f6200",
1068
+ c900="#2d4400", c950="#1a2700",
1069
+ )
1070
+ with gr.Blocks(
1071
+ theme=gr.themes.Soft(primary_hue=nv_green, secondary_hue=nv_green),
1072
+ title="LocateAnything",
1073
+ ) as demo:
1074
+ gr.Markdown("# 🚀 LocateAnything")
1075
+ gr.Markdown(
1076
+ "> **Locate any object in images or videos with natural language.** \n"
1077
+ "> Upload an image/video on the left, choose a task type, enter what you want to find, "
1078
+ "then click **Run Inference**. Results with bounding boxes will appear on the right.\n"
1079
+ ">\n"
1080
+ "> **Quick Start:** "
1081
+ "① Select *Image* or *Video* → "
1082
+ "② Pick a *Task Type* (Detection / Grounding / OCR / GUI / Pointing) → "
1083
+ "③ Type your *Categories* (comma-separated) → "
1084
+ "④ Click **🧠 Run Inference**"
1085
+ )
1086
+
1087
+ with gr.Row():
1088
+ # ===== COL 1: Settings =====
1089
+ with gr.Column(scale=1):
1090
+ gr.Markdown("### ⚙️ Settings")
1091
+ input_type = gr.Radio(
1092
+ ["Image", "Video"], label="1. Input Media Type", value="Image",
1093
+ info="Select whether to process a single image or a video clip.",
1094
+ )
1095
+ task_dropdown = gr.Dropdown(
1096
+ choices=["Detection", "Grounding", "OCR", "GUI", "Pointing"],
1097
+ value="Detection", label="2. Task Type",
1098
+ info="Detection: find all instances | Grounding: match description | "
1099
+ "OCR: extract text | GUI: locate UI element | Pointing: point to target",
1100
+ )
1101
+ category_input = gr.Textbox(
1102
+ label="3. Categories",
1103
+ value="car, bus, person, potted plant",
1104
+ placeholder="e.g. car, person, dog (comma-separated, supports Chinese)",
1105
+ info="Enter one or more categories separated by commas. "
1106
+ "Supports both English and Chinese (e.g. 汽车, 行人).",
1107
+ )
1108
+ model_dropdown = gr.Dropdown(
1109
+ choices=["fast", "slow", "hybrid"],
1110
+ value="hybrid", label="4. Inference Mode",
1111
+ info="fast: MTP parallel decoding | slow: standard AR decoding | "
1112
+ "hybrid: auto-switch for best quality-speed balance",
1113
+ )
1114
+ with gr.Accordion("5. Advanced Settings", open=False):
1115
+ gr.Markdown(
1116
+ "*Adjust these only if needed. Default values work well for most cases.*"
1117
+ )
1118
+ temp_slider = gr.Slider(
1119
+ minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature",
1120
+ info="Higher = more diverse results; lower = more deterministic.",
1121
+ )
1122
+ top_p_slider = gr.Slider(
1123
+ minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P",
1124
+ info="Nucleus sampling threshold.",
1125
+ )
1126
+ top_k_slider = gr.Slider(
1127
+ minimum=1, maximum=100, value=20, step=1, label="Top K",
1128
+ info="Top-K sampling: number of highest probability tokens to consider.",
1129
+ )
1130
+ short_size_input = gr.Number(
1131
+ label="Short Side Size (px)", value=None, precision=0,
1132
+ info="Resize the short side of the image to this value before inference. "
1133
+ "Leave empty to keep original size (auto-capped at 1024).",
1134
+ )
1135
+ max_video_frames_slider = gr.Slider(
1136
+ minimum=1, maximum=10, value=4, step=1,
1137
+ label="Max Video Frames",
1138
+ info="Number of frames to sample from the video for inference. "
1139
+ "Each frame takes ~15-20s. Keep ≤ 6 to avoid GPU timeout.",
1140
+ )
1141
+ run_btn = gr.Button("🧠 Run Inference", variant="primary", size="lg")
1142
+
1143
+ # ===== COL 2: Main =====
1144
+ with gr.Column(scale=3):
1145
+ with gr.Row():
1146
+ with gr.Column(scale=1):
1147
+ gr.Markdown("### 📥 Input Media")
1148
+ image_input = gr.Image(
1149
+ label="Input Image", type="pil", visible=True,
1150
+ )
1151
+ video_input = gr.Video(
1152
+ label="Input Video",
1153
+ visible=False,
1154
+ )
1155
+ with gr.Column(scale=1):
1156
+ gr.Markdown("### 📤 Output Result")
1157
+ output_image = gr.Image(
1158
+ label="Detection Result", type="pil", visible=True,
1159
+ )
1160
+ output_video = gr.Video(
1161
+ label="Video Result", visible=False,
1162
+ )
1163
+
1164
+ gr.Markdown("### 📝 Raw Input Prompt")
1165
+ raw_prompt_box = gr.Textbox(
1166
+ value=generate_raw_prompt("Detection", "car, bus, person, potted plant"),
1167
+ interactive=False, lines=2,
1168
+ info="This is the prompt sent to the model (auto-generated from your settings above).",
1169
+ )
1170
+ gr.Markdown("### 🔍 Decoding Visualization")
1171
+ raw_output_box = gr.HTML(label="Decoding Steps")
1172
+
1173
+ # ===== EXAMPLES =====
1174
+ gr.Markdown("---")
1175
+ gr.Markdown(
1176
+ "## 🖼️ Examples\n"
1177
+ "Click any example below to auto-fill the settings and input image."
1178
+ )
1179
+ gallery_images, gallery_captions = prepare_gallery_data()
1180
+ example_gallery = gr.Gallery(
1181
+ value=list(zip(gallery_images, gallery_captions)),
1182
+ show_label=True, columns=4, rows=1, height="auto", allow_preview=False,
1183
+ )
1184
+
1185
+ # ===== EVENTS =====
1186
+ input_type.change(
1187
+ fn=lambda c: (gr.update(visible=(c == "Image")), gr.update(visible=(c == "Video"))),
1188
+ inputs=input_type, outputs=[image_input, video_input],
1189
+ )
1190
+
1191
+ for comp in [task_dropdown, category_input]:
1192
+ comp.change(
1193
+ fn=generate_raw_prompt,
1194
+ inputs=[task_dropdown, category_input],
1195
+ outputs=raw_prompt_box,
1196
+ )
1197
+
1198
+ run_btn.click(
1199
+ fn=_disable_run_btn,
1200
+ inputs=None,
1201
+ outputs=[run_btn],
1202
+ ).then(
1203
+ fn=run_inference,
1204
+ inputs=[
1205
+ input_type, image_input, video_input,
1206
+ task_dropdown, category_input, model_dropdown,
1207
+ temp_slider, top_p_slider, top_k_slider,
1208
+ short_size_input, raw_prompt_box,
1209
+ max_video_frames_slider,
1210
+ ],
1211
+ outputs=[output_image, output_video, raw_output_box],
1212
+ ).then(
1213
+ fn=_enable_run_btn,
1214
+ inputs=None,
1215
+ outputs=[run_btn],
1216
+ )
1217
+
1218
+ example_gallery.select(
1219
+ fn=update_example_selection,
1220
+ outputs=[input_type, image_input, video_input,
1221
+ task_dropdown, category_input, model_dropdown],
1222
+ ).then(
1223
+ fn=generate_raw_prompt,
1224
+ inputs=[task_dropdown, category_input],
1225
+ outputs=raw_prompt_box,
1226
+ )
1227
+
1228
+ return demo
1229
+
1230
+
1231
+ if __name__ == "__main__":
1232
+ demo = create_demo()
1233
+ demo.launch(debug=True)
assets/LXGWWenKai-Bold.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a356605eb36c92e29cc64090e2856e4675694572a2ae4da84adbadffecae907
3
+ size 18546748
assets/book.jpg ADDED

Git LFS Details

  • SHA256: fc0a3d0fde90c19697ea7901d92213ecd6de3dce4e2024af8ce579dd4cee99f3
  • Pointer size: 130 Bytes
  • Size of remote file: 47.9 kB
assets/ocr.jpg ADDED

Git LFS Details

  • SHA256: 9688a9ce343d6352e4ce1d7e5e7111bb7e500dac130f57889a4eb47c6cf056cc
  • Pointer size: 130 Bytes
  • Size of remote file: 24.5 kB
assets/person.jpg ADDED

Git LFS Details

  • SHA256: 1b500616480e629cb8418d3a542ec260e75813d4343869f50f294fc4f73f7e9f
  • Pointer size: 131 Bytes
  • Size of remote file: 703 kB
assets/sweet.jpg ADDED

Git LFS Details

  • SHA256: 0cbd03dc94f12129919b4edb8e8415f6ba649ec13ce4db2924ce42cb83ad96d2
  • Pointer size: 130 Bytes
  • Size of remote file: 37 kB
gitattributes.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python==4.11.0.86
2
+ opencv-python-headless==4.11.0.86
3
+ transformers==4.51.0
4
+ torch==2.5.0
5
+ numpy==1.25.0
6
+ Pillow==11.1.0
7
+ peft
8
+ torchvision
9
+ decord==0.6.0
10
+ lmdb==1.7.5