Update modeling_moss_vl.py

#2
by CCCCyx - opened
Files changed (1) hide show
  1. modeling_moss_vl.py +1006 -1
modeling_moss_vl.py CHANGED
@@ -14,8 +14,11 @@
14
  # limitations under the License.
15
  """PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention"""
16
 
 
 
 
17
  from dataclasses import dataclass
18
- from typing import Any, Callable, Optional, Union, Tuple, List
19
 
20
  import torch
21
  import torch.nn as nn
@@ -26,6 +29,8 @@ from transformers import initialization as init
26
  from transformers.activations import ACT2FN
27
  from transformers.cache_utils import Cache, DynamicCache
28
  from transformers.generation import GenerationMixin
 
 
29
  from transformers.integrations import use_kernel_forward_from_hub
30
  from transformers.masking_utils import create_causal_mask
31
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
@@ -46,6 +51,59 @@ from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionC
46
  logger = logging.get_logger(__name__)
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  @dataclass
50
  class MossVLModelOutputWithPast(ModelOutput):
51
  """
@@ -2098,10 +2156,18 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2098
  config: MossVLConfig
2099
  _checkpoint_conversion_mapping = {}
2100
  accepts_loss_kwargs = False
 
 
 
 
 
 
 
2101
  def __init__(self, config):
2102
  super().__init__(config)
2103
  self.model = MossVLModel(config)
2104
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
 
2105
 
2106
  self.post_init()
2107
 
@@ -2333,6 +2399,945 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2333
 
2334
  return model_kwargs
2335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2336
 
2337
  __all__ = [
2338
  "MossVLVisionModel",
 
14
  # limitations under the License.
15
  """PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention"""
16
 
17
+ import copy
18
+ import queue
19
+ import threading
20
  from dataclasses import dataclass
21
+ from typing import Any, Callable, Dict, Optional, Union, Tuple, List
22
 
23
  import torch
24
  import torch.nn as nn
 
29
  from transformers.activations import ACT2FN
30
  from transformers.cache_utils import Cache, DynamicCache
31
  from transformers.generation import GenerationMixin
32
+ from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
33
+ from transformers.generation.streamers import TextIteratorStreamer
34
  from transformers.integrations import use_kernel_forward_from_hub
35
  from transformers.masking_utils import create_causal_mask
36
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 
51
  logger = logging.get_logger(__name__)
52
 
53
 
54
+ _OFFLINE_SYSTEM_PROMPTS = {
55
+ "no_thinking": {
56
+ "text_image": "You are a helpful AI assistant. Respond to the user's request based on the provided text and/or images.",
57
+ "video": "You are a helpful AI assistant specializing in video analysis. Respond to the user's request based on the provided video content.",
58
+ },
59
+ "deep_thinking": {
60
+ "text_image": "A conversation between User and Assistant. The user makes a request, and the assistant responds to it based on the provided text and/or images. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking></thinking> and <answer></answer> tags, respectively, i.e., <thinking>reasoning process here</thinking><answer>answer here</answer>.",
61
+ "video": "A conversation between User and Assistant specializing in video analysis. The user makes a request, and the assistant responds to it based on the provided video content. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking></thinking> and <answer></answer> tags, respectively, i.e., <thinking>reasoning process here</thinking><answer>answer here</answer>.",
62
+ },
63
+ }
64
+
65
+
66
class _OfflineCancelStoppingCriteria(StoppingCriteria):
    """Stopping criterion that aborts generation once a cancel event fires."""

    def __init__(self, cancel_event: threading.Event):
        self.cancel_event = cancel_event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Returning True tells ``generate`` to stop before the next step.
        return self.cancel_event.is_set()
72
+
73
+
74
class _OfflineQueueStreamer(TextIteratorStreamer):
    """Streamer that mirrors each finalized text chunk into an external queue.

    Chunks are also accumulated in ``collected_chunks`` so the full completion
    can be reconstructed after streaming finishes.
    """

    def __init__(self, tokenizer, output_text_queue: "queue.Queue[str]"):
        super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
        self.output_text_queue = output_text_queue
        self.collected_chunks: List[str] = []

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Empty chunks (e.g. the terminal flush) are not forwarded.
        if text:
            self.collected_chunks.append(text)
            self.output_text_queue.put(text)
        super().on_finalized_text(text, stream_end=stream_end)
85
+
86
+
87
+ _OFFLINE_THINKING_MODE_ALIASES = {
88
+ "no_thinking": "no_thinking",
89
+ "default": "no_thinking",
90
+ "standard": "no_thinking",
91
+ "deep_thinking": "deep_thinking",
92
+ "thinking": "deep_thinking",
93
+ "reasoning": "deep_thinking",
94
+ }
95
+
96
+ _OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES = {
97
+ "text_image": "text_image",
98
+ "text-image": "text_image",
99
+ "image_text": "text_image",
100
+ "image-text": "text_image",
101
+ "text": "text_image",
102
+ "image": "text_image",
103
+ "video": "video",
104
+ }
105
+
106
+
107
  @dataclass
108
  class MossVLModelOutputWithPast(ModelOutput):
109
  """
 
2156
  config: MossVLConfig
2157
  _checkpoint_conversion_mapping = {}
2158
  accepts_loss_kwargs = False
2159
+
2160
+ @classmethod
2161
+ def build_offline_prepare_helper(cls):
2162
+ helper = cls.__new__(cls)
2163
+ helper._offline_processor_lock = threading.RLock()
2164
+ return helper
2165
+
2166
  def __init__(self, config):
2167
  super().__init__(config)
2168
  self.model = MossVLModel(config)
2169
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
2170
+ self._offline_processor_lock = threading.RLock()
2171
 
2172
  self.post_init()
2173
 
 
2399
 
2400
  return model_kwargs
2401
 
2402
+ # ==================== Offline generate orchestration ====================
2403
+ # The following helpers replicate the ``offline_*`` API exposed by the
2404
+ # previous release of ``modeling_moss_vl.py`` (transformers==4.57.1) so
2405
+ # external runners (e.g. ``inference/run_inference.py``) keep working
2406
+ # against this newer implementation without code changes. They ultimately
2407
+ # dispatch through ``self.generate(...)`` and the checkpoint's processor,
2408
+ # which keeps the decoding logic identical across versions.
2409
+
2410
+ @staticmethod
2411
+ def _offline_flatten_content_with_vision_tokens(content) -> str:
2412
+ if isinstance(content, str):
2413
+ return content
2414
+ if not isinstance(content, list):
2415
+ return str(content) if content else ""
2416
+
2417
+ parts = []
2418
+ for item in content:
2419
+ if isinstance(item, dict):
2420
+ if item.get("type") == "image" or "image" in item:
2421
+ parts.append("<|image|>")
2422
+ elif item.get("type") == "video" or "video" in item:
2423
+ parts.append("<|video|>")
2424
+ if "text" in item:
2425
+ parts.append(str(item["text"]))
2426
+ elif isinstance(item, str):
2427
+ parts.append(item)
2428
+ return "".join(parts)
2429
+
2430
+ @staticmethod
2431
+ def _offline_sanitize_prompt_text(processor, text: Any) -> str:
2432
+ if text is None:
2433
+ return ""
2434
+
2435
+ sanitized = str(text)
2436
+ replacements = [
2437
+ (getattr(processor, "image_placeholder", None), ""),
2438
+ (getattr(processor, "video_placeholder", None), ""),
2439
+ (getattr(processor, "image_token", None), ""),
2440
+ (getattr(processor, "video_token", None), ""),
2441
+ ]
2442
+ for needle, replacement in replacements:
2443
+ if needle:
2444
+ sanitized = sanitized.replace(needle, replacement)
2445
+ return sanitized.lstrip("\n")
2446
+
2447
+ def _offline_sanitize_message_content(self, processor, content: Any) -> Any:
2448
+ if isinstance(content, str):
2449
+ return self._offline_sanitize_prompt_text(processor, content)
2450
+ if not isinstance(content, list):
2451
+ return content
2452
+
2453
+ sanitized_items = []
2454
+ for item in content:
2455
+ if isinstance(item, dict):
2456
+ item_copy = dict(item)
2457
+ if "text" in item_copy:
2458
+ item_copy["text"] = self._offline_sanitize_prompt_text(processor, item_copy.get("text"))
2459
+ sanitized_items.append(item_copy)
2460
+ elif isinstance(item, str):
2461
+ sanitized_items.append(self._offline_sanitize_prompt_text(processor, item))
2462
+ else:
2463
+ sanitized_items.append(item)
2464
+ return sanitized_items
2465
+
2466
+ def _offline_prepare_messages(self, processor, query: Dict[str, Any]) -> List[Dict[str, Any]]:
2467
+ messages = query.get("messages")
2468
+ if messages:
2469
+ prepared_messages = []
2470
+ for message in messages:
2471
+ if not isinstance(message, dict):
2472
+ continue
2473
+ message_copy = dict(message)
2474
+ message_copy["content"] = self._offline_sanitize_message_content(
2475
+ processor,
2476
+ message_copy.get("content", ""),
2477
+ )
2478
+ prepared_messages.append(message_copy)
2479
+ if prepared_messages:
2480
+ return prepared_messages
2481
+
2482
+ prompt = self._offline_sanitize_prompt_text(processor, query.get("prompt", ""))
2483
+ images = list(query.get("images") or [])
2484
+ videos = list(query.get("videos") or [])
2485
+
2486
+ content = []
2487
+ for image in images:
2488
+ content.append({"type": "image", "image": image})
2489
+ for video in videos:
2490
+ content.append({"type": "video", "video": video})
2491
+ if prompt:
2492
+ content.append({"type": "text", "text": prompt.lstrip("\n")})
2493
+
2494
+ if not content:
2495
+ content = [{"type": "text", "text": ""}]
2496
+
2497
+ return [{"role": "user", "content": content}]
2498
+
2499
+ def _offline_prepare_input_text(self, processor, messages: List[Dict[str, Any]]) -> str:
2500
+ processed_messages = []
2501
+ for message in messages:
2502
+ message_copy = dict(message)
2503
+ message_copy["content"] = self._offline_flatten_content_with_vision_tokens(
2504
+ message_copy.get("content", "")
2505
+ )
2506
+ processed_messages.append(message_copy)
2507
+ return processor.apply_chat_template(
2508
+ processed_messages,
2509
+ tokenize=False,
2510
+ add_generation_prompt=True,
2511
+ )
2512
+
2513
+ @staticmethod
2514
+ def _offline_collect_media(messages: List[Dict[str, Any]]) -> tuple[List[Any], List[Any]]:
2515
+ all_images: List[Any] = []
2516
+ all_videos: List[Any] = []
2517
+
2518
+ for message in messages:
2519
+ content = message.get("content")
2520
+ if isinstance(content, list):
2521
+ for item in content:
2522
+ if not isinstance(item, dict):
2523
+ continue
2524
+ if item.get("type") == "image" or "image" in item:
2525
+ image = item.get("image") or item.get("image_url")
2526
+ if image is not None:
2527
+ all_images.append(image)
2528
+ elif item.get("type") == "video" or "video" in item:
2529
+ video = item.get("video")
2530
+ if video is not None:
2531
+ all_videos.append(video)
2532
+
2533
+ return all_images, all_videos
2534
+
2535
+ def _offline_build_processor_kwargs(
2536
+ self,
2537
+ input_text: Union[str, List[str]],
2538
+ all_images: List[Any],
2539
+ all_videos: List[Any],
2540
+ media_kwargs: Dict[str, Any],
2541
+ ) -> Dict[str, Any]:
2542
+ processor_kwargs: Dict[str, Any] = {
2543
+ "text": input_text,
2544
+ "images": all_images or None,
2545
+ "videos": all_videos or None,
2546
+ "return_tensors": "pt",
2547
+ "padding": False,
2548
+ }
2549
+
2550
+ if media_kwargs.get("min_pixels") is not None:
2551
+ processor_kwargs["min_pixels"] = media_kwargs["min_pixels"]
2552
+ if media_kwargs.get("max_pixels") is not None:
2553
+ processor_kwargs["max_pixels"] = media_kwargs["max_pixels"]
2554
+ if media_kwargs.get("video_fps") is not None:
2555
+ processor_kwargs["video_fps"] = media_kwargs["video_fps"]
2556
+
2557
+ min_frames = media_kwargs.get("min_frames", media_kwargs.get("video_minlen"))
2558
+ max_frames = media_kwargs.get("max_frames", media_kwargs.get("video_maxlen"))
2559
+ if min_frames is not None:
2560
+ processor_kwargs["min_frames"] = min_frames
2561
+ if max_frames is not None:
2562
+ processor_kwargs["max_frames"] = max_frames
2563
+
2564
+ return processor_kwargs
2565
+
2566
+ def _offline_run_processor(self, processor, processor_kwargs: Dict[str, Any], media_kwargs: Dict[str, Any]):
2567
+ image_proc = getattr(processor, "image_processor", None)
2568
+ video_proc = getattr(processor, "video_processor", None)
2569
+ modified_multi_image = False
2570
+ modified_video = False
2571
+
2572
+ with self._offline_processor_lock:
2573
+ try:
2574
+ multi_image_max_pixels = media_kwargs.get("multi_image_max_pixels")
2575
+ if multi_image_max_pixels is not None and image_proc is not None:
2576
+ orig_multi_image_max_pixels = getattr(image_proc, "multi_image_max_pixels", None)
2577
+ image_proc.multi_image_max_pixels = multi_image_max_pixels
2578
+ modified_multi_image = True
2579
+
2580
+ video_max_pixels = media_kwargs.get("video_max_pixels")
2581
+ if video_max_pixels is not None and video_proc is not None:
2582
+ orig_video_max_pixels = getattr(video_proc, "video_max_pixels", None)
2583
+ video_proc.video_max_pixels = video_max_pixels
2584
+ modified_video = True
2585
+
2586
+ inputs = processor(**processor_kwargs)
2587
+ finally:
2588
+ if modified_multi_image and image_proc is not None:
2589
+ image_proc.multi_image_max_pixels = orig_multi_image_max_pixels
2590
+ if modified_video and video_proc is not None:
2591
+ video_proc.video_max_pixels = orig_video_max_pixels
2592
+
2593
+ return inputs
2594
+
2595
+ def _offline_move_inputs_to_devices(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
2596
+ moved_inputs = dict(inputs)
2597
+ text_device = self.get_input_embeddings().weight.device
2598
+ vision_device = self.visual.patch_embed.proj.weight.device
2599
+ vision_input_keys = {"pixel_values", "grid_thw"}
2600
+
2601
+ for key, value in list(moved_inputs.items()):
2602
+ if not isinstance(value, torch.Tensor):
2603
+ continue
2604
+
2605
+ target_device = vision_device if key in vision_input_keys else text_device
2606
+ moved_value = value.to(target_device)
2607
+ if moved_value.dtype == torch.float32:
2608
+ moved_value = moved_value.to(torch.bfloat16)
2609
+ moved_inputs[key] = moved_value
2610
+
2611
+ return moved_inputs
2612
+
2613
+ @staticmethod
2614
+ def _offline_build_call_kwargs(generate_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
2615
+ normalized_generate_kwargs = dict(generate_kwargs or {})
2616
+ max_new_tokens = normalized_generate_kwargs.pop("max_new_tokens", 1024)
2617
+ temperature = normalized_generate_kwargs.pop("temperature", 1.0)
2618
+ top_k = normalized_generate_kwargs.pop("top_k", 50)
2619
+ top_p = normalized_generate_kwargs.pop("top_p", 1.0)
2620
+ repetition_penalty = normalized_generate_kwargs.pop("repetition_penalty", 1.0)
2621
+ do_sample = normalized_generate_kwargs.pop("do_sample", False)
2622
+ # ``vision_chunked_length`` was used by the previous modeling file to
2623
+ # shard the visual tower at prefill time. The current forward path
2624
+ # processes the entire vision input in one go, so this flag is
2625
+ # intentionally accepted-and-ignored here for backward compatibility.
2626
+ normalized_generate_kwargs.pop("vision_chunked_length", None)
2627
+
2628
+ if temperature is None:
2629
+ temperature = 1.0
2630
+ if temperature <= 0:
2631
+ temperature = 1.0
2632
+ do_sample = False
2633
+
2634
+ return dict(
2635
+ max_new_tokens=max_new_tokens,
2636
+ temperature=temperature,
2637
+ top_k=top_k,
2638
+ top_p=top_p,
2639
+ repetition_penalty=repetition_penalty,
2640
+ do_sample=do_sample,
2641
+ **normalized_generate_kwargs,
2642
+ )
2643
+
2644
+ def offline_prepare_query_cpu(
2645
+ self,
2646
+ processor,
2647
+ query: Dict[str, Any],
2648
+ session_messages: Optional[List[Dict[str, Any]]] = None,
2649
+ *,
2650
+ padding: bool = False,
2651
+ ) -> Dict[str, Any]:
2652
+ current_session = session_messages or []
2653
+ if query.get("reset_session") or query.get("clear_history"):
2654
+ current_session = []
2655
+
2656
+ working_messages = self._offline_build_session_messages(
2657
+ processor,
2658
+ query,
2659
+ current_session,
2660
+ )
2661
+ input_text = self._offline_prepare_input_text(processor, working_messages)
2662
+ all_images, all_videos = self._offline_collect_media(working_messages)
2663
+ media_kwargs = dict(query.get("media_kwargs") or {})
2664
+ processor_kwargs = self._offline_build_processor_kwargs(
2665
+ input_text,
2666
+ all_images,
2667
+ all_videos,
2668
+ media_kwargs,
2669
+ )
2670
+ processor_kwargs["padding"] = padding
2671
+ inputs_cpu = self._offline_run_processor(processor, processor_kwargs, media_kwargs)
2672
+
2673
+ return {
2674
+ "inputs_cpu": inputs_cpu,
2675
+ "input_text": input_text,
2676
+ "working_messages": working_messages,
2677
+ "call_kwargs": self._offline_build_call_kwargs(query.get("generate_kwargs")),
2678
+ }
2679
+
2680
+ def _offline_prepare_inputs(self, processor, query: Dict[str, Any]):
2681
+ prepared = self.offline_prepare_query_cpu(processor, query)
2682
+ inputs = self._offline_move_inputs_to_devices(prepared["inputs_cpu"])
2683
+ return inputs, prepared["input_text"]
2684
+
2685
+ def offline_generate_from_prepared(self, processor, prepared: Dict[str, Any]) -> Dict[str, Any]:
2686
+ inputs = self._offline_move_inputs_to_devices(prepared["inputs_cpu"])
2687
+ input_seq_len = inputs["input_ids"].shape[1]
2688
+
2689
+ with torch.no_grad():
2690
+ outputs = self.generate(
2691
+ **inputs,
2692
+ **prepared["call_kwargs"],
2693
+ )
2694
+
2695
+ generated_tokens = outputs[:, input_seq_len:]
2696
+ decoded_texts = processor.batch_decode(generated_tokens, skip_special_tokens=True)
2697
+ text = decoded_texts[0] if decoded_texts else ""
2698
+
2699
+ return {
2700
+ "text": text,
2701
+ "input_text": prepared["input_text"],
2702
+ "messages": prepared["working_messages"],
2703
+ }
2704
+
2705
+ def _offline_build_session_messages(
2706
+ self,
2707
+ processor,
2708
+ query: Dict[str, Any],
2709
+ session_messages: List[Dict[str, Any]],
2710
+ ) -> List[Dict[str, Any]]:
2711
+ has_explicit_messages = bool(query.get("messages"))
2712
+ if has_explicit_messages and not query.get("append_messages_to_session", False):
2713
+ base_messages: List[Dict[str, Any]] = []
2714
+ else:
2715
+ base_messages = [dict(message) for message in session_messages]
2716
+
2717
+ turn_messages = self._offline_prepare_messages(processor, query)
2718
+ has_system_message = any(
2719
+ isinstance(message, dict) and message.get("role") == "system"
2720
+ for message in (base_messages + turn_messages)
2721
+ )
2722
+
2723
+ should_add_system_prompt = (
2724
+ query.get("use_default_system_prompt", False)
2725
+ or query.get("system_prompt") is not None
2726
+ or query.get("system_prompt_type") is not None
2727
+ or query.get("thinking_mode") is not None
2728
+ )
2729
+
2730
+ if not base_messages and not has_system_message and should_add_system_prompt:
2731
+ system_prompt = self._offline_resolve_system_prompt(query, turn_messages)
2732
+ if system_prompt is not None:
2733
+ base_messages.append({"role": "system", "content": system_prompt})
2734
+
2735
+ return base_messages + turn_messages
2736
+
2737
+ @staticmethod
2738
+ def _offline_query_contains_video(query: Dict[str, Any], messages: List[Dict[str, Any]]) -> bool:
2739
+ if query.get("videos"):
2740
+ return True
2741
+
2742
+ for message in messages:
2743
+ content = message.get("content") if isinstance(message, dict) else None
2744
+ if isinstance(content, list) and any(
2745
+ isinstance(item, dict) and (item.get("type") == "video" or "video" in item)
2746
+ for item in content
2747
+ ):
2748
+ return True
2749
+ return False
2750
+
2751
+ @staticmethod
2752
+ def _offline_normalize_thinking_mode(value: Optional[str]) -> str:
2753
+ if value is None:
2754
+ return "no_thinking"
2755
+
2756
+ normalized = _OFFLINE_THINKING_MODE_ALIASES.get(str(value).strip().lower())
2757
+ if normalized is None:
2758
+ allowed = ", ".join(sorted(set(_OFFLINE_THINKING_MODE_ALIASES.values())))
2759
+ raise ValueError(f"Unsupported thinking_mode: {value!r}. Supported values: {allowed}")
2760
+ return normalized
2761
+
2762
+ @staticmethod
2763
+ def _offline_normalize_system_prompt_type(value: Optional[str], has_video: bool) -> str:
2764
+ if value is None:
2765
+ return "video" if has_video else "text_image"
2766
+
2767
+ normalized_key = str(value).strip().lower().replace("/", "_").replace(" ", "_")
2768
+ while "__" in normalized_key:
2769
+ normalized_key = normalized_key.replace("__", "_")
2770
+
2771
+ normalized = _OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES.get(normalized_key)
2772
+ if normalized is None:
2773
+ allowed = ", ".join(sorted(set(_OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES.values())))
2774
+ raise ValueError(f"Unsupported system_prompt_type: {value!r}. Supported values: {allowed}")
2775
+ return normalized
2776
+
2777
+ def _offline_resolve_system_prompt(
2778
+ self,
2779
+ query: Dict[str, Any],
2780
+ turn_messages: List[Dict[str, Any]],
2781
+ ) -> Optional[str]:
2782
+ explicit_system_prompt = query.get("system_prompt")
2783
+ if explicit_system_prompt is not None:
2784
+ return str(explicit_system_prompt)
2785
+
2786
+ has_video = self._offline_query_contains_video(query, turn_messages)
2787
+ thinking_mode = self._offline_normalize_thinking_mode(query.get("thinking_mode"))
2788
+ system_prompt_type = self._offline_normalize_system_prompt_type(
2789
+ query.get("system_prompt_type"),
2790
+ has_video=has_video,
2791
+ )
2792
+ return _OFFLINE_SYSTEM_PROMPTS[thinking_mode][system_prompt_type]
2793
+
2794
+ @staticmethod
2795
+ def _offline_finalize_session_messages(
2796
+ working_messages: List[Dict[str, Any]],
2797
+ assistant_text: str,
2798
+ ) -> List[Dict[str, Any]]:
2799
+ next_messages = [dict(message) for message in working_messages]
2800
+ next_messages.append({"role": "assistant", "content": assistant_text})
2801
+ return next_messages
2802
+
2803
+ def _offline_prepare_generation(self, processor, query: Dict[str, Any]):
2804
+ prepared = self.offline_prepare_query_cpu(processor, query)
2805
+ inputs = self._offline_move_inputs_to_devices(prepared["inputs_cpu"])
2806
+ return inputs, prepared["input_text"], prepared["call_kwargs"]
2807
+
2808
+ @staticmethod
2809
+ def _offline_normalize_shared_mapping(
2810
+ values: List[Dict[str, Any]],
2811
+ mapping_name: str,
2812
+ ) -> Dict[str, Any]:
2813
+ normalized_values = [dict(value or {}) for value in values]
2814
+ if not normalized_values:
2815
+ return {}
2816
+
2817
+ all_keys = set()
2818
+ for value in normalized_values:
2819
+ all_keys.update(value.keys())
2820
+
2821
+ merged: Dict[str, Any] = {}
2822
+ mismatched_keys: List[str] = []
2823
+ for key in sorted(all_keys):
2824
+ unique_values = {repr(value.get(key)) for value in normalized_values}
2825
+ if len(unique_values) > 1:
2826
+ mismatched_keys.append(key)
2827
+ else:
2828
+ merged[key] = normalized_values[0].get(key)
2829
+
2830
+ if mismatched_keys:
2831
+ mismatch_text = ", ".join(mismatched_keys)
2832
+ raise ValueError(
2833
+ f"All batch queries must share the same {mapping_name}. "
2834
+ f"Mismatched keys: {mismatch_text}"
2835
+ )
2836
+ return merged
2837
+
2838
+ def _offline_prepare_batch_generation(
2839
+ self,
2840
+ processor,
2841
+ queries: List[Dict[str, Any]],
2842
+ session_states: Optional[List[List[Dict[str, Any]]]] = None,
2843
+ ):
2844
+ if not queries:
2845
+ raise ValueError("`queries` must contain at least one query.")
2846
+
2847
+ if session_states is None:
2848
+ session_states = [[] for _ in queries]
2849
+ elif len(session_states) != len(queries):
2850
+ raise ValueError("`session_states` must have the same length as `queries`.")
2851
+
2852
+ working_messages_list: List[List[Dict[str, Any]]] = []
2853
+ input_texts: List[str] = []
2854
+ all_images_per_query: List[List[Any]] = []
2855
+ all_videos_per_query: List[List[Any]] = []
2856
+
2857
+ for query, session_state in zip(queries, session_states):
2858
+ if not isinstance(query, dict):
2859
+ raise TypeError("Each batch query must be a dict.")
2860
+ if query.get("stop_offline_generate"):
2861
+ raise ValueError("`stop_offline_generate` is not supported in offline_batch_generate.")
2862
+ if query.get("stream_output", query.get("stream", False)):
2863
+ raise ValueError("Streaming is not supported in offline_batch_generate.")
2864
+ if query.get("cancel_current_generate") or query.get("stop_generation"):
2865
+ raise ValueError("Cancel / stop controls are not supported in offline_batch_generate.")
2866
+
2867
+ current_session = [] if query.get("reset_session") or query.get("clear_history") else session_state
2868
+ working_messages = self._offline_build_session_messages(
2869
+ processor,
2870
+ query,
2871
+ current_session,
2872
+ )
2873
+ working_messages_list.append(working_messages)
2874
+ input_texts.append(self._offline_prepare_input_text(processor, working_messages))
2875
+
2876
+ all_images, all_videos = self._offline_collect_media(working_messages)
2877
+ all_images_per_query.append(all_images)
2878
+ all_videos_per_query.append(all_videos)
2879
+
2880
+ media_kwargs = self._offline_normalize_shared_mapping(
2881
+ [query.get("media_kwargs") or {} for query in queries],
2882
+ mapping_name="media_kwargs",
2883
+ )
2884
+ processor_kwargs = self._offline_build_processor_kwargs(
2885
+ input_text=input_texts,
2886
+ all_images=[image for images in all_images_per_query for image in images],
2887
+ all_videos=[video for videos in all_videos_per_query for video in videos],
2888
+ media_kwargs=media_kwargs,
2889
+ )
2890
+ processor_kwargs["padding"] = True
2891
+
2892
+ tokenizer = getattr(processor, "tokenizer", None)
2893
+ orig_padding_side = None
2894
+
2895
+ if tokenizer is not None and hasattr(tokenizer, "padding_side"):
2896
+ orig_padding_side = tokenizer.padding_side
2897
+ tokenizer.padding_side = "left"
2898
+ try:
2899
+ inputs = self._offline_run_processor(processor, processor_kwargs, media_kwargs)
2900
+ finally:
2901
+ if tokenizer is not None and orig_padding_side is not None:
2902
+ tokenizer.padding_side = orig_padding_side
2903
+
2904
+ inputs = self._offline_move_inputs_to_devices(inputs)
2905
+
2906
+ generate_kwargs = self._offline_normalize_shared_mapping(
2907
+ [query.get("generate_kwargs") or {} for query in queries],
2908
+ mapping_name="generate_kwargs",
2909
+ )
2910
+ call_kwargs = self._offline_build_call_kwargs(generate_kwargs)
2911
+ return inputs, input_texts, working_messages_list, call_kwargs
2912
+
2913
+ def offline_batch_generate(
2914
+ self,
2915
+ processor,
2916
+ queries: List[Dict[str, Any]],
2917
+ session_states: Optional[List[List[Dict[str, Any]]]] = None,
2918
+ vision_chunked_length: int = 64,
2919
+ ) -> Dict[str, Any]:
2920
+ """
2921
+ Batch offline generation for multiple independent samples.
2922
+
2923
+ This method supports:
2924
+ - batched single-turn generation
2925
+ - batched multi-turn continuation through `session_states`
2926
+
2927
+ It intentionally does not support queue-style controls such as:
2928
+ - `stream_output`
2929
+ - `cancel_current_generate`
2930
+ - `stop_generation`
2931
+ - `stop_offline_generate`
2932
+ """
2933
+ if not queries:
2934
+ return {"results": [], "session_states": []}
2935
+
2936
+ prepared_queries = [dict(query) for query in queries]
2937
+ for query in prepared_queries:
2938
+ generate_kwargs = query.setdefault("generate_kwargs", {})
2939
+ generate_kwargs.setdefault("vision_chunked_length", vision_chunked_length)
2940
+ if session_states is None:
2941
+ session_states = [[] for _ in prepared_queries]
2942
+ elif len(session_states) != len(prepared_queries):
2943
+ raise ValueError("`session_states` must have the same length as `queries`.")
2944
+
2945
+ tokenizer = getattr(processor, "tokenizer", None)
2946
+ bucketed_indices: Dict[Any, List[int]] = {}
2947
+ for index, (query, session_state) in enumerate(zip(prepared_queries, session_states)):
2948
+ current_session = [] if query.get("reset_session") or query.get("clear_history") else session_state
2949
+ working_messages = self._offline_build_session_messages(processor, query, current_session)
2950
+ input_text = self._offline_prepare_input_text(processor, working_messages)
2951
+
2952
+ if tokenizer is not None:
2953
+ token_ids = tokenizer(input_text, add_special_tokens=False)["input_ids"]
2954
+ bucket_key = len(token_ids)
2955
+ else:
2956
+ bucket_key = len(input_text)
2957
+ bucketed_indices.setdefault(bucket_key, []).append(index)
2958
+
2959
+ results: List[Optional[Dict[str, Any]]] = [None] * len(prepared_queries)
2960
+ next_session_states: List[Optional[List[Dict[str, Any]]]] = [None] * len(prepared_queries)
2961
+
2962
+ for bucket_indices in bucketed_indices.values():
2963
+ bucket_queries = [prepared_queries[index] for index in bucket_indices]
2964
+ bucket_session_states = [session_states[index] for index in bucket_indices]
2965
+ inputs, input_texts, working_messages_list, call_kwargs = self._offline_prepare_batch_generation(
2966
+ processor,
2967
+ bucket_queries,
2968
+ session_states=bucket_session_states,
2969
+ )
2970
+
2971
+ with torch.no_grad():
2972
+ outputs = self.generate(
2973
+ **inputs,
2974
+ **call_kwargs,
2975
+ )
2976
+
2977
+ input_seq_len = inputs["input_ids"].shape[1]
2978
+ generated_tokens = outputs[:, input_seq_len:]
2979
+ decoded_texts = processor.batch_decode(generated_tokens, skip_special_tokens=True)
2980
+
2981
+ for local_index, (query, input_text, working_messages, text) in enumerate(
2982
+ zip(bucket_queries, input_texts, working_messages_list, decoded_texts)
2983
+ ):
2984
+ original_index = bucket_indices[local_index]
2985
+ if query.get("persist_session", True):
2986
+ next_session_state = self._offline_finalize_session_messages(working_messages, text)
2987
+ else:
2988
+ next_session_state = working_messages
2989
+ next_session_states[original_index] = next_session_state
2990
+ results[original_index] = {
2991
+ "index": original_index,
2992
+ "text": text,
2993
+ "input_text": input_text,
2994
+ "messages": working_messages,
2995
+ }
2996
+
2997
+ return {
2998
+ "results": [item for item in results if item is not None],
2999
+ "session_states": [item for item in next_session_states if item is not None],
3000
+ }
3001
+
3002
+ def _offline_generate_one(self, processor, query: Dict[str, Any]) -> str:
3003
+ working_messages = self._offline_build_session_messages(processor, query, [])
3004
+ generation_query = dict(query)
3005
+ generation_query["messages"] = working_messages
3006
+ inputs, _, call_kwargs = self._offline_prepare_generation(processor, generation_query)
3007
+
3008
+ with torch.no_grad():
3009
+ outputs = self.generate(
3010
+ **inputs,
3011
+ **call_kwargs,
3012
+ )
3013
+
3014
+ new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
3015
+ return processor.decode(new_tokens, skip_special_tokens=True)
3016
+
3017
+ @staticmethod
3018
+ def _offline_capture_processor_attrs(target, overrides: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
3019
+ if target is None or not overrides:
3020
+ return None
3021
+ return {name: copy.deepcopy(getattr(target, name)) for name in overrides}
3022
+
3023
+ @staticmethod
3024
+ def _offline_apply_processor_attrs(target, overrides: Optional[Dict[str, Any]]) -> None:
3025
+ if target is None or not overrides:
3026
+ return
3027
+ for name, value in overrides.items():
3028
+ setattr(target, name, copy.deepcopy(value))
3029
+
3030
+ @staticmethod
3031
+ def _offline_restore_processor_attrs(target, snapshot: Optional[Dict[str, Any]]) -> None:
3032
+ if target is None or snapshot is None:
3033
+ return
3034
+ for name, value in snapshot.items():
3035
+ setattr(target, name, copy.deepcopy(value))
3036
+
3037
+ def _offline_generate_one_with_processor_overrides(
3038
+ self,
3039
+ processor,
3040
+ query: Dict[str, Any],
3041
+ image_processor_overrides: Optional[Dict[str, Any]] = None,
3042
+ video_processor_overrides: Optional[Dict[str, Any]] = None,
3043
+ ) -> str:
3044
+ image_proc = getattr(processor, "image_processor", None)
3045
+ video_proc = getattr(processor, "video_processor", None)
3046
+ image_snapshot = self._offline_capture_processor_attrs(image_proc, image_processor_overrides)
3047
+ video_snapshot = self._offline_capture_processor_attrs(video_proc, video_processor_overrides)
3048
+
3049
+ with self._offline_processor_lock:
3050
+ try:
3051
+ self._offline_apply_processor_attrs(image_proc, image_processor_overrides)
3052
+ self._offline_apply_processor_attrs(video_proc, video_processor_overrides)
3053
+ return self._offline_generate_one(processor, query)
3054
+ finally:
3055
+ self._offline_restore_processor_attrs(image_proc, image_snapshot)
3056
+ self._offline_restore_processor_attrs(video_proc, video_snapshot)
3057
+
3058
    def offline_image_generate(
        self,
        processor,
        prompt: str,
        image: Any,
        *,
        shortest_edge: int = 4096,
        longest_edge: int = 16777216,
        multi_image_max_pixels: int = 201326592,
        patch_size: int = 16,
        temporal_patch_size: int = 1,
        merge_size: int = 2,
        image_mean: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
        image_std: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
        max_new_tokens: int = 1024,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        do_sample: bool = False,
        vision_chunked_length: int = 64,
        thinking_mode: Optional[str] = None,
        system_prompt_type: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Single-image offline generation with explicit image preprocessor defaults.

        The default values mirror `preprocessor_config.json` so README examples can
        surface the full image preprocessing setup without requiring a batch wrapper.

        Returns the decoded response text for this single (prompt, image) turn.
        """
        # Assemble a one-turn offline query: `media_kwargs` drives preprocessing,
        # `generate_kwargs` is forwarded to `self.generate()` by the offline helpers.
        query: Dict[str, Any] = {
            "prompt": prompt,
            "images": [image],
            "videos": [],
            "media_kwargs": {
                "min_pixels": shortest_edge,
                "max_pixels": longest_edge,
                "multi_image_max_pixels": multi_image_max_pixels,
            },
            "generate_kwargs": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "repetition_penalty": repetition_penalty,
                "do_sample": do_sample,
                "vision_chunked_length": vision_chunked_length,
            },
        }
        # Optional keys are only set when provided, so downstream helpers can
        # apply their own defaults otherwise.
        if thinking_mode is not None:
            query["thinking_mode"] = thinking_mode
        if system_prompt_type is not None:
            query["system_prompt_type"] = system_prompt_type
        if system_prompt is not None:
            query["system_prompt"] = system_prompt

        # Applied temporarily to `processor.image_processor` for this call only
        # and restored afterwards by the overrides wrapper.
        image_processor_overrides = {
            "size": {"shortest_edge": shortest_edge, "longest_edge": longest_edge},
            "multi_image_max_pixels": multi_image_max_pixels,
            "patch_size": patch_size,
            "temporal_patch_size": temporal_patch_size,
            "merge_size": merge_size,
            "image_mean": list(image_mean) if image_mean is not None else None,
            "image_std": list(image_std) if image_std is not None else None,
        }
        return self._offline_generate_one_with_processor_overrides(
            processor,
            query,
            image_processor_overrides=image_processor_overrides,
        )
3129
+
3130
    def offline_video_generate(
        self,
        processor,
        prompt: str,
        video: Any,
        *,
        shortest_edge: int = 4096,
        longest_edge: int = 16777216,
        video_max_pixels: int = 201326592,
        patch_size: int = 16,
        temporal_patch_size: int = 1,
        merge_size: int = 2,
        video_fps: float = 1.0,
        min_frames: int = 1,
        max_frames: int = 256,
        num_extract_threads: int = 4,
        image_mean: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
        image_std: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
        max_new_tokens: int = 1024,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        do_sample: bool = False,
        vision_chunked_length: int = 64,
        thinking_mode: Optional[str] = None,
        system_prompt_type: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Single-video offline generation with explicit video preprocessor defaults.

        The default values mirror `video_preprocessor_config.json` so README examples
        can show a standalone video entry point with the effective preprocessing knobs.

        Returns the decoded response text for this single (prompt, video) turn.
        """
        # Assemble a one-turn offline query: `media_kwargs` drives frame
        # extraction/preprocessing, `generate_kwargs` is forwarded to `self.generate()`.
        query: Dict[str, Any] = {
            "prompt": prompt,
            "images": [],
            "videos": [video],
            "media_kwargs": {
                "min_pixels": shortest_edge,
                "max_pixels": longest_edge,
                "video_max_pixels": video_max_pixels,
                "video_fps": video_fps,
                "min_frames": min_frames,
                "max_frames": max_frames,
            },
            "generate_kwargs": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "repetition_penalty": repetition_penalty,
                "do_sample": do_sample,
                "vision_chunked_length": vision_chunked_length,
            },
        }
        # Optional keys are only set when provided, so downstream helpers can
        # apply their own defaults otherwise.
        if thinking_mode is not None:
            query["thinking_mode"] = thinking_mode
        if system_prompt_type is not None:
            query["system_prompt_type"] = system_prompt_type
        if system_prompt is not None:
            query["system_prompt"] = system_prompt

        # Applied temporarily to `processor.video_processor` for this call only
        # and restored afterwards by the overrides wrapper.
        video_processor_overrides = {
            "size": {"shortest_edge": shortest_edge, "longest_edge": longest_edge},
            "video_max_pixels": video_max_pixels,
            "patch_size": patch_size,
            "temporal_patch_size": temporal_patch_size,
            "merge_size": merge_size,
            "video_fps": video_fps,
            "min_frames": min_frames,
            "max_frames": max_frames,
            "num_extract_threads": num_extract_threads,
            "image_mean": list(image_mean) if image_mean is not None else None,
            "image_std": list(image_std) if image_std is not None else None,
        }
        return self._offline_generate_one_with_processor_overrides(
            processor,
            query,
            video_processor_overrides=video_processor_overrides,
        )
3212
+
3213
    def offline_generate(
        self,
        processor,
        new_queries: "queue.Queue[dict]",
        output_text_queue: "queue.Queue[str]",
        vision_chunked_length: int = 64,
    ) -> None:
        """
        HF-style offline inference wrapper aligned with the previous backend output path.

        This method intentionally reuses the checkpoint's existing processor and
        `generate()` flow so that outputs stay consistent with the old external
        backend inference implementation.

        Runs a blocking serve loop: pulls query dicts from `new_queries`, generates
        in a background thread (so cancel/control messages can be handled mid-turn),
        and pushes text plus `<|round_start|>`/`<|round_end|>` markers to
        `output_text_queue`. Returns only on a `stop_offline_generate` request.

        Supported query keys include:
        - `prompt` / `messages`
        - `images` / `videos`
        - `media_kwargs` / `generate_kwargs`
        - `thinking_mode` (`no_thinking` or `deep_thinking`, plus compatible aliases)
        - `system_prompt_type` (`text_image` or `video`, plus compatible aliases)
        - `system_prompt` for an explicit override
        - `stream_output` / `stream`
        - `reset_session` / `clear_history`
        - `cancel_current_generate` / `stop_generation` / `stop_offline_generate`
        """
        # Queries received while a generation is in flight are buffered and
        # served first on the next loop iteration.
        buffered_queries: List[Dict[str, Any]] = []
        # Multi-turn chat history carried across iterations of the serve loop.
        session_messages: List[Dict[str, Any]] = []

        while True:
            if buffered_queries:
                query = buffered_queries.pop(0)
            else:
                query = new_queries.get()
            # Non-dict items on the queue are ignored.
            if not isinstance(query, dict):
                continue

            if query.get("stop_offline_generate"):
                break

            if query.get("reset_session") or query.get("clear_history"):
                session_messages = []

            try:
                # Mutates the query in place so the default chunk length sticks.
                generate_kwargs = query.setdefault("generate_kwargs", {})
                generate_kwargs.setdefault("vision_chunked_length", vision_chunked_length)
                working_messages = self._offline_build_session_messages(
                    processor,
                    query,
                    session_messages,
                )

                generation_query = dict(query)
                generation_query["messages"] = working_messages
                inputs, input_text, call_kwargs = self._offline_prepare_generation(processor, generation_query)

                stream_output = bool(query.get("stream_output", query.get("stream", False)))
                # cancel_event is polled by the stopping criteria inside generate().
                cancel_event = threading.Event()
                stopping_criteria = StoppingCriteriaList([_OfflineCancelStoppingCriteria(cancel_event)])
                # Shared mailbox for the worker thread's result or exception.
                generation_state: Dict[str, Any] = {}

                if stream_output:
                    output_text_queue.put("<|round_start|>")
                    streamer = _OfflineQueueStreamer(getattr(processor, "tokenizer", processor), output_text_queue)
                else:
                    streamer = None

                def _run_generation():
                    # Runs on the worker thread; never raises — exceptions are
                    # stashed in generation_state and re-raised on the main loop.
                    try:
                        with torch.no_grad():
                            generation_state["outputs"] = self.generate(
                                **inputs,
                                stopping_criteria=stopping_criteria,
                                streamer=streamer,
                                **call_kwargs,
                            )
                    except Exception as exc:
                        generation_state["exception"] = exc

                worker = threading.Thread(target=_run_generation, daemon=True)
                worker.start()

                stop_conversation_after_turn = False
                # While generating, keep draining the input queue so cancel/stop
                # control messages take effect mid-turn; everything else is buffered.
                while worker.is_alive():
                    try:
                        control_query = new_queries.get(timeout=0.1)
                    except queue.Empty:
                        continue

                    if not isinstance(control_query, dict):
                        continue

                    if control_query.get("cancel_current_generate") or control_query.get("stop_generation"):
                        cancel_event.set()
                        stop_conversation_after_turn = stop_conversation_after_turn or control_query.get("stop_offline_generate", False)
                        continue

                    if control_query.get("stop_offline_generate"):
                        # Cancel the current turn, then exit the serve loop after it.
                        cancel_event.set()
                        stop_conversation_after_turn = True
                        continue

                    buffered_queries.append(control_query)

                worker.join()
                was_cancelled = cancel_event.is_set()

                if "exception" in generation_state:
                    raise generation_state["exception"]

                if stream_output and streamer is not None:
                    # Streaming mode: chunks were already pushed; join them only to
                    # persist the turn into the session below.
                    text = "".join(streamer.collected_chunks)
                else:
                    outputs = generation_state["outputs"]
                    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
                    text = processor.decode(new_tokens, skip_special_tokens=True)
                    output_text_queue.put(text)

                # Cancelled turns are dropped from history unless explicitly kept.
                if query.get("persist_session", True) and (not was_cancelled or query.get("persist_cancelled_turn", False)):
                    session_messages = self._offline_finalize_session_messages(working_messages, text)

                output_text_queue.put("<|round_end|>")

                if stop_conversation_after_turn:
                    break
            except Exception as exc:
                # Report the failure on the output path and keep the loop alive.
                output_text_queue.put(f"[ERROR] {exc}")
                output_text_queue.put("<|round_end|>")
3340
+
3341
 
3342
  __all__ = [
3343
  "MossVLVisionModel",