Update modeling_moss_vl.py
#2
by CCCCyx - opened
- modeling_moss_vl.py +1006 -1
modeling_moss_vl.py
CHANGED
|
@@ -14,8 +14,11 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
"""PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention"""
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
from dataclasses import dataclass
|
| 18 |
-
from typing import Any, Callable, Optional, Union, Tuple, List
|
| 19 |
|
| 20 |
import torch
|
| 21 |
import torch.nn as nn
|
|
@@ -26,6 +29,8 @@ from transformers import initialization as init
|
|
| 26 |
from transformers.activations import ACT2FN
|
| 27 |
from transformers.cache_utils import Cache, DynamicCache
|
| 28 |
from transformers.generation import GenerationMixin
|
|
|
|
|
|
|
| 29 |
from transformers.integrations import use_kernel_forward_from_hub
|
| 30 |
from transformers.masking_utils import create_causal_mask
|
| 31 |
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
|
@@ -46,6 +51,59 @@ from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionC
|
|
| 46 |
logger = logging.get_logger(__name__)
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
@dataclass
|
| 50 |
class MossVLModelOutputWithPast(ModelOutput):
|
| 51 |
"""
|
|
@@ -2098,10 +2156,18 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2098 |
config: MossVLConfig
|
| 2099 |
_checkpoint_conversion_mapping = {}
|
| 2100 |
accepts_loss_kwargs = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2101 |
def __init__(self, config):
|
| 2102 |
super().__init__(config)
|
| 2103 |
self.model = MossVLModel(config)
|
| 2104 |
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
|
|
|
| 2105 |
|
| 2106 |
self.post_init()
|
| 2107 |
|
|
@@ -2333,6 +2399,945 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2333 |
|
| 2334 |
return model_kwargs
|
| 2335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2336 |
|
| 2337 |
__all__ = [
|
| 2338 |
"MossVLVisionModel",
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
"""PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention"""
|
| 16 |
|
| 17 |
+
import copy
|
| 18 |
+
import queue
|
| 19 |
+
import threading
|
| 20 |
from dataclasses import dataclass
|
| 21 |
+
from typing import Any, Callable, Dict, Optional, Union, Tuple, List
|
| 22 |
|
| 23 |
import torch
|
| 24 |
import torch.nn as nn
|
|
|
|
| 29 |
from transformers.activations import ACT2FN
|
| 30 |
from transformers.cache_utils import Cache, DynamicCache
|
| 31 |
from transformers.generation import GenerationMixin
|
| 32 |
+
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
|
| 33 |
+
from transformers.generation.streamers import TextIteratorStreamer
|
| 34 |
from transformers.integrations import use_kernel_forward_from_hub
|
| 35 |
from transformers.masking_utils import create_causal_mask
|
| 36 |
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
|
|
|
| 51 |
logger = logging.get_logger(__name__)
|
| 52 |
|
| 53 |
|
| 54 |
+
# Default system prompts for the legacy ``offline_*`` API (see
# ``_offline_resolve_system_prompt``). Outer key: normalized thinking mode;
# inner key: normalized media type of the request.
_OFFLINE_SYSTEM_PROMPTS = {
    "no_thinking": {
        "text_image": "You are a helpful AI assistant. Respond to the user's request based on the provided text and/or images.",
        "video": "You are a helpful AI assistant specializing in video analysis. Respond to the user's request based on the provided video content.",
    },
    "deep_thinking": {
        "text_image": "A conversation between User and Assistant. The user makes a request, and the assistant responds to it based on the provided text and/or images. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking></thinking> and <answer></answer> tags, respectively, i.e., <thinking>reasoning process here</thinking><answer>answer here</answer>.",
        "video": "A conversation between User and Assistant specializing in video analysis. The user makes a request, and the assistant responds to it based on the provided video content. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking></thinking> and <answer></answer> tags, respectively, i.e., <thinking>reasoning process here</thinking><answer>answer here</answer>.",
    },
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class _OfflineCancelStoppingCriteria(StoppingCriteria):
    """Stopping criterion that aborts generation once a cancel event is set."""

    def __init__(self, cancel_event: threading.Event):
        self.cancel_event = cancel_event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Polled by ``generate`` after each decoding step; True stops decoding.
        return self.cancel_event.is_set()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class _OfflineQueueStreamer(TextIteratorStreamer):
    """Streamer that mirrors finalized text into a queue and keeps a transcript."""

    def __init__(self, tokenizer, output_text_queue: "queue.Queue[str]"):
        super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
        self.output_text_queue = output_text_queue
        # Ordered text chunks; joining them yields the full decoded reply.
        self.collected_chunks: List[str] = []

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Empty chunks are neither recorded nor queued.
        if text:
            self.collected_chunks.append(text)
            self.output_text_queue.put(text)
        super().on_finalized_text(text, stream_end=stream_end)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Accepted user spellings of ``thinking_mode`` mapped to the canonical keys of
# ``_OFFLINE_SYSTEM_PROMPTS``.
_OFFLINE_THINKING_MODE_ALIASES = {
    "no_thinking": "no_thinking",
    "default": "no_thinking",
    "standard": "no_thinking",
    "deep_thinking": "deep_thinking",
    "thinking": "deep_thinking",
    "reasoning": "deep_thinking",
}
|
| 95 |
+
|
| 96 |
+
# Accepted user spellings of ``system_prompt_type`` mapped to the canonical
# inner keys of ``_OFFLINE_SYSTEM_PROMPTS``.
_OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES = {
    "text_image": "text_image",
    "text-image": "text_image",
    "image_text": "text_image",
    "image-text": "text_image",
    "text": "text_image",
    "image": "text_image",
    "video": "video",
}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
@dataclass
|
| 108 |
class MossVLModelOutputWithPast(ModelOutput):
|
| 109 |
"""
|
|
|
|
| 2156 |
config: MossVLConfig
|
| 2157 |
_checkpoint_conversion_mapping = {}
|
| 2158 |
accepts_loss_kwargs = False
|
| 2159 |
+
|
| 2160 |
+
@classmethod
def build_offline_prepare_helper(cls):
    """Create a lightweight instance usable only for offline query preparation.

    Bypasses ``__init__`` (so no model weights are allocated) and attaches just
    the processor lock that the ``_offline_*`` preparation helpers require.
    """
    prep_only = cls.__new__(cls)
    prep_only._offline_processor_lock = threading.RLock()
    return prep_only
|
| 2165 |
+
|
| 2166 |
def __init__(self, config):
|
| 2167 |
super().__init__(config)
|
| 2168 |
self.model = MossVLModel(config)
|
| 2169 |
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
| 2170 |
+
self._offline_processor_lock = threading.RLock()
|
| 2171 |
|
| 2172 |
self.post_init()
|
| 2173 |
|
|
|
|
| 2399 |
|
| 2400 |
return model_kwargs
|
| 2401 |
|
| 2402 |
+
# ==================== Offline generate orchestration ====================
|
| 2403 |
+
# The following helpers replicate the ``offline_*`` API exposed by the
|
| 2404 |
+
# previous release of ``modeling_moss_vl.py`` (transformers==4.57.1) so
|
| 2405 |
+
# external runners (e.g. ``inference/run_inference.py``) keep working
|
| 2406 |
+
# against this newer implementation without code changes. They ultimately
|
| 2407 |
+
# dispatch through ``self.generate(...)`` and the checkpoint's processor,
|
| 2408 |
+
# which keeps the decoding logic identical across versions.
|
| 2409 |
+
|
| 2410 |
+
@staticmethod
def _offline_flatten_content_with_vision_tokens(content) -> str:
    """Collapse structured chat content into one string with vision markers."""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        # Anything else: stringify truthy values, drop falsy ones.
        return str(content) if content else ""

    fragments: List[str] = []
    for entry in content:
        if isinstance(entry, str):
            fragments.append(entry)
            continue
        if not isinstance(entry, dict):
            continue
        if entry.get("type") == "image" or "image" in entry:
            fragments.append("<|image|>")
        elif entry.get("type") == "video" or "video" in entry:
            fragments.append("<|video|>")
        # A dict entry may carry both a media marker and accompanying text.
        if "text" in entry:
            fragments.append(str(entry["text"]))
    return "".join(fragments)
|
| 2429 |
+
|
| 2430 |
+
@staticmethod
def _offline_sanitize_prompt_text(processor, text: Any) -> str:
    """Strip the processor's media placeholder/token strings from prompt text."""
    if text is None:
        return ""

    cleaned = str(text)
    # Placeholders are optional processor attributes; missing/empty ones skip.
    for attr_name in ("image_placeholder", "video_placeholder", "image_token", "video_token"):
        marker = getattr(processor, attr_name, None)
        if marker:
            cleaned = cleaned.replace(marker, "")
    return cleaned.lstrip("\n")
|
| 2446 |
+
|
| 2447 |
+
def _offline_sanitize_message_content(self, processor, content: Any) -> Any:
    """Apply prompt-text sanitization to every text fragment of a message."""
    if isinstance(content, str):
        return self._offline_sanitize_prompt_text(processor, content)
    if not isinstance(content, list):
        return content

    cleaned: List[Any] = []
    for entry in content:
        if isinstance(entry, str):
            cleaned.append(self._offline_sanitize_prompt_text(processor, entry))
        elif isinstance(entry, dict):
            # Copy before mutating so the caller's message stays untouched.
            entry_copy = dict(entry)
            if "text" in entry_copy:
                entry_copy["text"] = self._offline_sanitize_prompt_text(processor, entry_copy.get("text"))
            cleaned.append(entry_copy)
        else:
            cleaned.append(entry)
    return cleaned
|
| 2465 |
+
|
| 2466 |
+
def _offline_prepare_messages(self, processor, query: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn an offline ``query`` dict into a chat ``messages`` list.

    Prefers an explicit ``query["messages"]`` (sanitized, non-dict entries
    dropped); otherwise builds a single user turn from ``prompt`` /
    ``images`` / ``videos``.
    """
    messages = query.get("messages")
    if messages:
        prepared_messages = []
        for message in messages:
            # Malformed (non-dict) entries are silently skipped.
            if not isinstance(message, dict):
                continue
            message_copy = dict(message)
            message_copy["content"] = self._offline_sanitize_message_content(
                processor,
                message_copy.get("content", ""),
            )
            prepared_messages.append(message_copy)
        if prepared_messages:
            return prepared_messages

    # Fallback path: assemble one user message from the flat query fields.
    prompt = self._offline_sanitize_prompt_text(processor, query.get("prompt", ""))
    images = list(query.get("images") or [])
    videos = list(query.get("videos") or [])

    content = []
    for image in images:
        content.append({"type": "image", "image": image})
    for video in videos:
        content.append({"type": "video", "video": video})
    if prompt:
        # lstrip is redundant after sanitization but kept defensively.
        content.append({"type": "text", "text": prompt.lstrip("\n")})

    if not content:
        content = [{"type": "text", "text": ""}]

    return [{"role": "user", "content": content}]
|
| 2498 |
+
|
| 2499 |
+
def _offline_prepare_input_text(self, processor, messages: List[Dict[str, Any]]) -> str:
    """Render messages into a prompt string via the processor's chat template."""
    # Flatten structured content to plain strings with <|image|>/<|video|>
    # markers before templating.
    flattened = [
        {**message, "content": self._offline_flatten_content_with_vision_tokens(message.get("content", ""))}
        for message in messages
    ]
    return processor.apply_chat_template(
        flattened,
        tokenize=False,
        add_generation_prompt=True,
    )
|
| 2512 |
+
|
| 2513 |
+
@staticmethod
def _offline_collect_media(messages: List[Dict[str, Any]]) -> tuple[List[Any], List[Any]]:
    """Gather image and video payloads from structured message content, in order."""
    images: List[Any] = []
    videos: List[Any] = []

    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue
        for entry in content:
            if not isinstance(entry, dict):
                continue
            if entry.get("type") == "image" or "image" in entry:
                payload = entry.get("image") or entry.get("image_url")
                if payload is not None:
                    images.append(payload)
            elif entry.get("type") == "video" or "video" in entry:
                payload = entry.get("video")
                if payload is not None:
                    videos.append(payload)

    return images, videos
|
| 2534 |
+
|
| 2535 |
+
def _offline_build_processor_kwargs(
    self,
    input_text: Union[str, List[str]],
    all_images: List[Any],
    all_videos: List[Any],
    media_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """Assemble the keyword arguments for the checkpoint processor call."""
    kwargs: Dict[str, Any] = {
        "text": input_text,
        "images": all_images or None,
        "videos": all_videos or None,
        "return_tensors": "pt",
        "padding": False,
    }

    # Direct pass-through knobs, forwarded only when explicitly provided.
    for option in ("min_pixels", "max_pixels", "video_fps"):
        if media_kwargs.get(option) is not None:
            kwargs[option] = media_kwargs[option]

    # Frame bounds accept legacy aliases (video_minlen / video_maxlen).
    frame_floor = media_kwargs.get("min_frames", media_kwargs.get("video_minlen"))
    frame_ceiling = media_kwargs.get("max_frames", media_kwargs.get("video_maxlen"))
    if frame_floor is not None:
        kwargs["min_frames"] = frame_floor
    if frame_ceiling is not None:
        kwargs["max_frames"] = frame_ceiling

    return kwargs
|
| 2565 |
+
|
| 2566 |
+
def _offline_run_processor(self, processor, processor_kwargs: Dict[str, Any], media_kwargs: Dict[str, Any]):
    """Invoke the processor, temporarily overriding pixel budgets.

    The image/video sub-processors are shared mutable objects, so overrides
    happen under ``self._offline_processor_lock`` and are always reverted in
    ``finally`` — even when the processor call raises.
    """
    image_proc = getattr(processor, "image_processor", None)
    video_proc = getattr(processor, "video_processor", None)
    modified_multi_image = False
    modified_video = False

    with self._offline_processor_lock:
        try:
            multi_image_max_pixels = media_kwargs.get("multi_image_max_pixels")
            if multi_image_max_pixels is not None and image_proc is not None:
                orig_multi_image_max_pixels = getattr(image_proc, "multi_image_max_pixels", None)
                image_proc.multi_image_max_pixels = multi_image_max_pixels
                modified_multi_image = True

            video_max_pixels = media_kwargs.get("video_max_pixels")
            if video_max_pixels is not None and video_proc is not None:
                orig_video_max_pixels = getattr(video_proc, "video_max_pixels", None)
                video_proc.video_max_pixels = video_max_pixels
                modified_video = True

            inputs = processor(**processor_kwargs)
        finally:
            # NOTE(review): if the attribute did not exist before the
            # override, this restores it to an explicit None rather than
            # deleting it — confirm this is acceptable for the processors.
            if modified_multi_image and image_proc is not None:
                image_proc.multi_image_max_pixels = orig_multi_image_max_pixels
            if modified_video and video_proc is not None:
                video_proc.video_max_pixels = orig_video_max_pixels

    return inputs
|
| 2594 |
+
|
| 2595 |
+
def _offline_move_inputs_to_devices(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Place processor tensors on the right devices for a (possibly sharded) model.

    Vision tensors (``pixel_values``/``grid_thw``) follow the visual tower's
    patch-embedding device; everything else follows the text embedding device.
    float32 tensors are downcast to bfloat16 — assumes the model itself runs
    in bf16; confirm against the loading code.
    """
    moved_inputs = dict(inputs)
    text_device = self.get_input_embeddings().weight.device
    vision_device = self.visual.patch_embed.proj.weight.device
    vision_input_keys = {"pixel_values", "grid_thw"}

    for key, value in list(moved_inputs.items()):
        # Non-tensor entries (lists, scalars) are left untouched.
        if not isinstance(value, torch.Tensor):
            continue

        target_device = vision_device if key in vision_input_keys else text_device
        moved_value = value.to(target_device)
        if moved_value.dtype == torch.float32:
            moved_value = moved_value.to(torch.bfloat16)
        moved_inputs[key] = moved_value

    return moved_inputs
|
| 2612 |
+
|
| 2613 |
+
@staticmethod
def _offline_build_call_kwargs(generate_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Normalize legacy generate kwargs into arguments for ``self.generate``."""
    remaining = dict(generate_kwargs or {})
    max_new_tokens = remaining.pop("max_new_tokens", 1024)
    temperature = remaining.pop("temperature", 1.0)
    top_k = remaining.pop("top_k", 50)
    top_p = remaining.pop("top_p", 1.0)
    repetition_penalty = remaining.pop("repetition_penalty", 1.0)
    do_sample = remaining.pop("do_sample", False)
    # ``vision_chunked_length`` sharded the visual tower at prefill time in
    # the previous modeling file. The current forward path processes the whole
    # vision input at once, so the flag is accepted-and-ignored here for
    # backward compatibility.
    remaining.pop("vision_chunked_length", None)

    if temperature is None:
        temperature = 1.0
    if temperature <= 0:
        # Non-positive temperature is interpreted as greedy decoding.
        temperature = 1.0
        do_sample = False

    return {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        **remaining,
    }
|
| 2643 |
+
|
| 2644 |
+
def offline_prepare_query_cpu(
    self,
    processor,
    query: Dict[str, Any],
    session_messages: Optional[List[Dict[str, Any]]] = None,
    *,
    padding: bool = False,
) -> Dict[str, Any]:
    """Run the CPU-side half of offline generation: messages → processor tensors.

    Returns a dict with ``inputs_cpu`` (processor output, still on CPU),
    ``input_text`` (rendered prompt), ``working_messages`` (full message list
    used for this turn) and ``call_kwargs`` (normalized generate arguments).
    """
    current_session = session_messages or []
    # Either flag discards the multi-turn history before this turn.
    if query.get("reset_session") or query.get("clear_history"):
        current_session = []

    working_messages = self._offline_build_session_messages(
        processor,
        query,
        current_session,
    )
    input_text = self._offline_prepare_input_text(processor, working_messages)
    all_images, all_videos = self._offline_collect_media(working_messages)
    media_kwargs = dict(query.get("media_kwargs") or {})
    processor_kwargs = self._offline_build_processor_kwargs(
        input_text,
        all_images,
        all_videos,
        media_kwargs,
    )
    # The caller's ``padding`` choice overrides the builder's default (False).
    processor_kwargs["padding"] = padding
    inputs_cpu = self._offline_run_processor(processor, processor_kwargs, media_kwargs)

    return {
        "inputs_cpu": inputs_cpu,
        "input_text": input_text,
        "working_messages": working_messages,
        "call_kwargs": self._offline_build_call_kwargs(query.get("generate_kwargs")),
    }
|
| 2679 |
+
|
| 2680 |
+
def _offline_prepare_inputs(self, processor, query: Dict[str, Any]):
    """Prepare a query and move its tensors onto the model's devices."""
    prepared = self.offline_prepare_query_cpu(processor, query)
    device_inputs = self._offline_move_inputs_to_devices(prepared["inputs_cpu"])
    return device_inputs, prepared["input_text"]
|
| 2684 |
+
|
| 2685 |
+
def offline_generate_from_prepared(self, processor, prepared: Dict[str, Any]) -> Dict[str, Any]:
    """Run generation for a query prepared by ``offline_prepare_query_cpu``."""
    inputs = self._offline_move_inputs_to_devices(prepared["inputs_cpu"])
    # Prompt length in tokens, used below to slice off the echoed input.
    input_seq_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = self.generate(
            **inputs,
            **prepared["call_kwargs"],
        )

    generated_tokens = outputs[:, input_seq_len:]
    decoded_texts = processor.batch_decode(generated_tokens, skip_special_tokens=True)
    # Single-query path: only the first decoded sequence is returned.
    text = decoded_texts[0] if decoded_texts else ""

    return {
        "text": text,
        "input_text": prepared["input_text"],
        "messages": prepared["working_messages"],
    }
|
| 2704 |
+
|
| 2705 |
+
def _offline_build_session_messages(
    self,
    processor,
    query: Dict[str, Any],
    session_messages: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Combine session history, an optional system prompt and this turn's messages.

    Explicit ``query["messages"]`` replace the session history unless
    ``append_messages_to_session`` is set. A default system prompt is injected
    only at the start of a fresh session, when none exists yet, and only if
    the caller opted in via one of the system-prompt query fields.
    """
    has_explicit_messages = bool(query.get("messages"))
    if has_explicit_messages and not query.get("append_messages_to_session", False):
        base_messages: List[Dict[str, Any]] = []
    else:
        # Shallow-copy history so the caller's session list is not mutated.
        base_messages = [dict(message) for message in session_messages]

    turn_messages = self._offline_prepare_messages(processor, query)
    has_system_message = any(
        isinstance(message, dict) and message.get("role") == "system"
        for message in (base_messages + turn_messages)
    )

    # Any of these query fields signals that a system prompt is wanted.
    should_add_system_prompt = (
        query.get("use_default_system_prompt", False)
        or query.get("system_prompt") is not None
        or query.get("system_prompt_type") is not None
        or query.get("thinking_mode") is not None
    )

    if not base_messages and not has_system_message and should_add_system_prompt:
        system_prompt = self._offline_resolve_system_prompt(query, turn_messages)
        if system_prompt is not None:
            base_messages.append({"role": "system", "content": system_prompt})

    return base_messages + turn_messages
|
| 2736 |
+
|
| 2737 |
+
@staticmethod
def _offline_query_contains_video(query: Dict[str, Any], messages: List[Dict[str, Any]]) -> bool:
    """Return True when the query or any message carries video content."""
    if query.get("videos"):
        return True

    for message in messages:
        if not isinstance(message, dict):
            continue
        content = message.get("content")
        if not isinstance(content, list):
            continue
        for entry in content:
            if isinstance(entry, dict) and (entry.get("type") == "video" or "video" in entry):
                return True
    return False
|
| 2750 |
+
|
| 2751 |
+
@staticmethod
def _offline_normalize_thinking_mode(value: Optional[str]) -> str:
    """Map a user-supplied thinking mode to its canonical name."""
    if value is None:
        return "no_thinking"

    canonical = _OFFLINE_THINKING_MODE_ALIASES.get(str(value).strip().lower())
    if canonical is None:
        allowed = ", ".join(sorted(set(_OFFLINE_THINKING_MODE_ALIASES.values())))
        raise ValueError(f"Unsupported thinking_mode: {value!r}. Supported values: {allowed}")
    return canonical
|
| 2761 |
+
|
| 2762 |
+
@staticmethod
def _offline_normalize_system_prompt_type(value: Optional[str], has_video: bool) -> str:
    """Map a user-supplied system prompt type to its canonical name."""
    if value is None:
        # No explicit choice: infer from the media present in the query.
        return "video" if has_video else "text_image"

    normalized_key = str(value).strip().lower().replace("/", "_").replace(" ", "_")
    # Collapse runs of underscores produced by the replacements above.
    while "__" in normalized_key:
        normalized_key = normalized_key.replace("__", "_")

    canonical = _OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES.get(normalized_key)
    if canonical is None:
        allowed = ", ".join(sorted(set(_OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES.values())))
        raise ValueError(f"Unsupported system_prompt_type: {value!r}. Supported values: {allowed}")
    return canonical
|
| 2776 |
+
|
| 2777 |
+
def _offline_resolve_system_prompt(
    self,
    query: Dict[str, Any],
    turn_messages: List[Dict[str, Any]],
) -> Optional[str]:
    """Pick the system prompt for this turn; an explicit override always wins."""
    override = query.get("system_prompt")
    if override is not None:
        return str(override)

    has_video = self._offline_query_contains_video(query, turn_messages)
    mode = self._offline_normalize_thinking_mode(query.get("thinking_mode"))
    prompt_type = self._offline_normalize_system_prompt_type(
        query.get("system_prompt_type"),
        has_video=has_video,
    )
    return _OFFLINE_SYSTEM_PROMPTS[mode][prompt_type]
|
| 2793 |
+
|
| 2794 |
+
@staticmethod
|
| 2795 |
+
def _offline_finalize_session_messages(
|
| 2796 |
+
working_messages: List[Dict[str, Any]],
|
| 2797 |
+
assistant_text: str,
|
| 2798 |
+
) -> List[Dict[str, Any]]:
|
| 2799 |
+
next_messages = [dict(message) for message in working_messages]
|
| 2800 |
+
next_messages.append({"role": "assistant", "content": assistant_text})
|
| 2801 |
+
return next_messages
|
| 2802 |
+
|
| 2803 |
+
def _offline_prepare_generation(self, processor, query: Dict[str, Any]):
    """Prepare a query for generation: device-placed inputs, prompt text, kwargs."""
    prepared = self.offline_prepare_query_cpu(processor, query)
    device_inputs = self._offline_move_inputs_to_devices(prepared["inputs_cpu"])
    return device_inputs, prepared["input_text"], prepared["call_kwargs"]
|
| 2807 |
+
|
| 2808 |
+
@staticmethod
def _offline_normalize_shared_mapping(
    values: List[Dict[str, Any]],
    mapping_name: str,
) -> Dict[str, Any]:
    """Merge per-query mappings, requiring every key to agree across queries."""
    snapshots = [dict(value or {}) for value in values]
    if not snapshots:
        return {}

    key_union = set()
    for snapshot in snapshots:
        key_union.update(snapshot.keys())

    merged: Dict[str, Any] = {}
    conflicts: List[str] = []
    for key in sorted(key_union):
        # Compare by repr so unhashable values (lists, dicts) are supported.
        observed = {repr(snapshot.get(key)) for snapshot in snapshots}
        if len(observed) > 1:
            conflicts.append(key)
        else:
            merged[key] = snapshots[0].get(key)

    if conflicts:
        mismatch_text = ", ".join(conflicts)
        raise ValueError(
            f"All batch queries must share the same {mapping_name}. "
            f"Mismatched keys: {mismatch_text}"
        )
    return merged
|
| 2837 |
+
|
| 2838 |
+
    def _offline_prepare_batch_generation(
        self,
        processor,
        queries: List[Dict[str, Any]],
        session_states: Optional[List[List[Dict[str, Any]]]] = None,
    ):
        """Build one padded batch of model inputs for `offline_batch_generate`.

        Rejects queue-style control flags (streaming / cancel / stop), renders
        each query's chat transcript and prompt text, collects all media, then
        runs the processor once over the whole batch with left padding.

        Args:
            processor: Multimodal processor (chat templating + tokenization).
            queries: One query dict per sample; must all be dicts.
            session_states: Optional prior history per sample; defaults to
                empty histories and must match `queries` in length.

        Returns:
            Tuple of (inputs, input_texts, working_messages_list, call_kwargs)
            where `inputs` are already moved to the model's devices and
            `call_kwargs` come from the merged, batch-shared `generate_kwargs`.

        Raises:
            ValueError: empty batch, length mismatch, unsupported control
                flags, or mismatched shared kwargs across queries.
            TypeError: a query is not a dict.
        """
        if not queries:
            raise ValueError("`queries` must contain at least one query.")

        if session_states is None:
            session_states = [[] for _ in queries]
        elif len(session_states) != len(queries):
            raise ValueError("`session_states` must have the same length as `queries`.")

        working_messages_list: List[List[Dict[str, Any]]] = []
        input_texts: List[str] = []
        all_images_per_query: List[List[Any]] = []
        all_videos_per_query: List[List[Any]] = []

        for query, session_state in zip(queries, session_states):
            if not isinstance(query, dict):
                raise TypeError("Each batch query must be a dict.")
            # Queue-style controls only make sense for the interactive
            # offline_generate loop; fail loudly rather than ignore them.
            if query.get("stop_offline_generate"):
                raise ValueError("`stop_offline_generate` is not supported in offline_batch_generate.")
            if query.get("stream_output", query.get("stream", False)):
                raise ValueError("Streaming is not supported in offline_batch_generate.")
            if query.get("cancel_current_generate") or query.get("stop_generation"):
                raise ValueError("Cancel / stop controls are not supported in offline_batch_generate.")

            current_session = [] if query.get("reset_session") or query.get("clear_history") else session_state
            working_messages = self._offline_build_session_messages(
                processor,
                query,
                current_session,
            )
            working_messages_list.append(working_messages)
            input_texts.append(self._offline_prepare_input_text(processor, working_messages))

            all_images, all_videos = self._offline_collect_media(working_messages)
            all_images_per_query.append(all_images)
            all_videos_per_query.append(all_videos)

        # media_kwargs must agree across the batch because the processor is
        # invoked once for all samples.
        media_kwargs = self._offline_normalize_shared_mapping(
            [query.get("media_kwargs") or {} for query in queries],
            mapping_name="media_kwargs",
        )
        processor_kwargs = self._offline_build_processor_kwargs(
            input_text=input_texts,
            # Flatten per-query media into the single batched processor call.
            all_images=[image for images in all_images_per_query for image in images],
            all_videos=[video for videos in all_videos_per_query for video in videos],
            media_kwargs=media_kwargs,
        )
        processor_kwargs["padding"] = True

        tokenizer = getattr(processor, "tokenizer", None)
        orig_padding_side = None

        # Left padding keeps every prompt's last token aligned at the end of
        # the sequence, so slicing `[:, input_len:]` after generate() yields
        # only new tokens for every row.
        if tokenizer is not None and hasattr(tokenizer, "padding_side"):
            orig_padding_side = tokenizer.padding_side
            tokenizer.padding_side = "left"
        try:
            inputs = self._offline_run_processor(processor, processor_kwargs, media_kwargs)
        finally:
            # Always restore the shared tokenizer's original padding side.
            if tokenizer is not None and orig_padding_side is not None:
                tokenizer.padding_side = orig_padding_side

        inputs = self._offline_move_inputs_to_devices(inputs)

        generate_kwargs = self._offline_normalize_shared_mapping(
            [query.get("generate_kwargs") or {} for query in queries],
            mapping_name="generate_kwargs",
        )
        call_kwargs = self._offline_build_call_kwargs(generate_kwargs)
        return inputs, input_texts, working_messages_list, call_kwargs
|
| 2912 |
+
|
| 2913 |
+
def offline_batch_generate(
|
| 2914 |
+
self,
|
| 2915 |
+
processor,
|
| 2916 |
+
queries: List[Dict[str, Any]],
|
| 2917 |
+
session_states: Optional[List[List[Dict[str, Any]]]] = None,
|
| 2918 |
+
vision_chunked_length: int = 64,
|
| 2919 |
+
) -> Dict[str, Any]:
|
| 2920 |
+
"""
|
| 2921 |
+
Batch offline generation for multiple independent samples.
|
| 2922 |
+
|
| 2923 |
+
This method supports:
|
| 2924 |
+
- batched single-turn generation
|
| 2925 |
+
- batched multi-turn continuation through `session_states`
|
| 2926 |
+
|
| 2927 |
+
It intentionally does not support queue-style controls such as:
|
| 2928 |
+
- `stream_output`
|
| 2929 |
+
- `cancel_current_generate`
|
| 2930 |
+
- `stop_generation`
|
| 2931 |
+
- `stop_offline_generate`
|
| 2932 |
+
"""
|
| 2933 |
+
if not queries:
|
| 2934 |
+
return {"results": [], "session_states": []}
|
| 2935 |
+
|
| 2936 |
+
prepared_queries = [dict(query) for query in queries]
|
| 2937 |
+
for query in prepared_queries:
|
| 2938 |
+
generate_kwargs = query.setdefault("generate_kwargs", {})
|
| 2939 |
+
generate_kwargs.setdefault("vision_chunked_length", vision_chunked_length)
|
| 2940 |
+
if session_states is None:
|
| 2941 |
+
session_states = [[] for _ in prepared_queries]
|
| 2942 |
+
elif len(session_states) != len(prepared_queries):
|
| 2943 |
+
raise ValueError("`session_states` must have the same length as `queries`.")
|
| 2944 |
+
|
| 2945 |
+
tokenizer = getattr(processor, "tokenizer", None)
|
| 2946 |
+
bucketed_indices: Dict[Any, List[int]] = {}
|
| 2947 |
+
for index, (query, session_state) in enumerate(zip(prepared_queries, session_states)):
|
| 2948 |
+
current_session = [] if query.get("reset_session") or query.get("clear_history") else session_state
|
| 2949 |
+
working_messages = self._offline_build_session_messages(processor, query, current_session)
|
| 2950 |
+
input_text = self._offline_prepare_input_text(processor, working_messages)
|
| 2951 |
+
|
| 2952 |
+
if tokenizer is not None:
|
| 2953 |
+
token_ids = tokenizer(input_text, add_special_tokens=False)["input_ids"]
|
| 2954 |
+
bucket_key = len(token_ids)
|
| 2955 |
+
else:
|
| 2956 |
+
bucket_key = len(input_text)
|
| 2957 |
+
bucketed_indices.setdefault(bucket_key, []).append(index)
|
| 2958 |
+
|
| 2959 |
+
results: List[Optional[Dict[str, Any]]] = [None] * len(prepared_queries)
|
| 2960 |
+
next_session_states: List[Optional[List[Dict[str, Any]]]] = [None] * len(prepared_queries)
|
| 2961 |
+
|
| 2962 |
+
for bucket_indices in bucketed_indices.values():
|
| 2963 |
+
bucket_queries = [prepared_queries[index] for index in bucket_indices]
|
| 2964 |
+
bucket_session_states = [session_states[index] for index in bucket_indices]
|
| 2965 |
+
inputs, input_texts, working_messages_list, call_kwargs = self._offline_prepare_batch_generation(
|
| 2966 |
+
processor,
|
| 2967 |
+
bucket_queries,
|
| 2968 |
+
session_states=bucket_session_states,
|
| 2969 |
+
)
|
| 2970 |
+
|
| 2971 |
+
with torch.no_grad():
|
| 2972 |
+
outputs = self.generate(
|
| 2973 |
+
**inputs,
|
| 2974 |
+
**call_kwargs,
|
| 2975 |
+
)
|
| 2976 |
+
|
| 2977 |
+
input_seq_len = inputs["input_ids"].shape[1]
|
| 2978 |
+
generated_tokens = outputs[:, input_seq_len:]
|
| 2979 |
+
decoded_texts = processor.batch_decode(generated_tokens, skip_special_tokens=True)
|
| 2980 |
+
|
| 2981 |
+
for local_index, (query, input_text, working_messages, text) in enumerate(
|
| 2982 |
+
zip(bucket_queries, input_texts, working_messages_list, decoded_texts)
|
| 2983 |
+
):
|
| 2984 |
+
original_index = bucket_indices[local_index]
|
| 2985 |
+
if query.get("persist_session", True):
|
| 2986 |
+
next_session_state = self._offline_finalize_session_messages(working_messages, text)
|
| 2987 |
+
else:
|
| 2988 |
+
next_session_state = working_messages
|
| 2989 |
+
next_session_states[original_index] = next_session_state
|
| 2990 |
+
results[original_index] = {
|
| 2991 |
+
"index": original_index,
|
| 2992 |
+
"text": text,
|
| 2993 |
+
"input_text": input_text,
|
| 2994 |
+
"messages": working_messages,
|
| 2995 |
+
}
|
| 2996 |
+
|
| 2997 |
+
return {
|
| 2998 |
+
"results": [item for item in results if item is not None],
|
| 2999 |
+
"session_states": [item for item in next_session_states if item is not None],
|
| 3000 |
+
}
|
| 3001 |
+
|
| 3002 |
+
def _offline_generate_one(self, processor, query: Dict[str, Any]) -> str:
|
| 3003 |
+
working_messages = self._offline_build_session_messages(processor, query, [])
|
| 3004 |
+
generation_query = dict(query)
|
| 3005 |
+
generation_query["messages"] = working_messages
|
| 3006 |
+
inputs, _, call_kwargs = self._offline_prepare_generation(processor, generation_query)
|
| 3007 |
+
|
| 3008 |
+
with torch.no_grad():
|
| 3009 |
+
outputs = self.generate(
|
| 3010 |
+
**inputs,
|
| 3011 |
+
**call_kwargs,
|
| 3012 |
+
)
|
| 3013 |
+
|
| 3014 |
+
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
| 3015 |
+
return processor.decode(new_tokens, skip_special_tokens=True)
|
| 3016 |
+
|
| 3017 |
+
@staticmethod
|
| 3018 |
+
def _offline_capture_processor_attrs(target, overrides: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
| 3019 |
+
if target is None or not overrides:
|
| 3020 |
+
return None
|
| 3021 |
+
return {name: copy.deepcopy(getattr(target, name)) for name in overrides}
|
| 3022 |
+
|
| 3023 |
+
@staticmethod
|
| 3024 |
+
def _offline_apply_processor_attrs(target, overrides: Optional[Dict[str, Any]]) -> None:
|
| 3025 |
+
if target is None or not overrides:
|
| 3026 |
+
return
|
| 3027 |
+
for name, value in overrides.items():
|
| 3028 |
+
setattr(target, name, copy.deepcopy(value))
|
| 3029 |
+
|
| 3030 |
+
@staticmethod
|
| 3031 |
+
def _offline_restore_processor_attrs(target, snapshot: Optional[Dict[str, Any]]) -> None:
|
| 3032 |
+
if target is None or snapshot is None:
|
| 3033 |
+
return
|
| 3034 |
+
for name, value in snapshot.items():
|
| 3035 |
+
setattr(target, name, copy.deepcopy(value))
|
| 3036 |
+
|
| 3037 |
+
def _offline_generate_one_with_processor_overrides(
|
| 3038 |
+
self,
|
| 3039 |
+
processor,
|
| 3040 |
+
query: Dict[str, Any],
|
| 3041 |
+
image_processor_overrides: Optional[Dict[str, Any]] = None,
|
| 3042 |
+
video_processor_overrides: Optional[Dict[str, Any]] = None,
|
| 3043 |
+
) -> str:
|
| 3044 |
+
image_proc = getattr(processor, "image_processor", None)
|
| 3045 |
+
video_proc = getattr(processor, "video_processor", None)
|
| 3046 |
+
image_snapshot = self._offline_capture_processor_attrs(image_proc, image_processor_overrides)
|
| 3047 |
+
video_snapshot = self._offline_capture_processor_attrs(video_proc, video_processor_overrides)
|
| 3048 |
+
|
| 3049 |
+
with self._offline_processor_lock:
|
| 3050 |
+
try:
|
| 3051 |
+
self._offline_apply_processor_attrs(image_proc, image_processor_overrides)
|
| 3052 |
+
self._offline_apply_processor_attrs(video_proc, video_processor_overrides)
|
| 3053 |
+
return self._offline_generate_one(processor, query)
|
| 3054 |
+
finally:
|
| 3055 |
+
self._offline_restore_processor_attrs(image_proc, image_snapshot)
|
| 3056 |
+
self._offline_restore_processor_attrs(video_proc, video_snapshot)
|
| 3057 |
+
|
| 3058 |
+
    def offline_image_generate(
        self,
        processor,
        prompt: str,
        image: Any,
        *,
        shortest_edge: int = 4096,
        longest_edge: int = 16777216,
        multi_image_max_pixels: int = 201326592,
        patch_size: int = 16,
        temporal_patch_size: int = 1,
        merge_size: int = 2,
        image_mean: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
        image_std: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
        max_new_tokens: int = 1024,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        do_sample: bool = False,
        vision_chunked_length: int = 64,
        thinking_mode: Optional[str] = None,
        system_prompt_type: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Single-image offline generation with explicit image preprocessor defaults.

        The default values mirror `preprocessor_config.json` so README examples can
        surface the full image preprocessing setup without requiring a batch wrapper.

        Builds a one-image query dict, then runs `_offline_generate_one` with the
        image processor attributes temporarily overridden and restored afterwards.
        Returns the decoded generation text.
        """
        query: Dict[str, Any] = {
            "prompt": prompt,
            "images": [image],
            "videos": [],
            # NOTE(review): shortest_edge/longest_edge are forwarded here as
            # min_pixels/max_pixels pixel budgets — confirm this mapping
            # matches the processor's media_kwargs contract.
            "media_kwargs": {
                "min_pixels": shortest_edge,
                "max_pixels": longest_edge,
                "multi_image_max_pixels": multi_image_max_pixels,
            },
            "generate_kwargs": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "repetition_penalty": repetition_penalty,
                "do_sample": do_sample,
                "vision_chunked_length": vision_chunked_length,
            },
        }
        # Only forward prompt-selection keys when explicitly provided, so the
        # downstream resolver applies its own defaults otherwise.
        if thinking_mode is not None:
            query["thinking_mode"] = thinking_mode
        if system_prompt_type is not None:
            query["system_prompt_type"] = system_prompt_type
        if system_prompt is not None:
            query["system_prompt"] = system_prompt

        # Attributes temporarily applied to `processor.image_processor`.
        image_processor_overrides = {
            "size": {"shortest_edge": shortest_edge, "longest_edge": longest_edge},
            "multi_image_max_pixels": multi_image_max_pixels,
            "patch_size": patch_size,
            "temporal_patch_size": temporal_patch_size,
            "merge_size": merge_size,
            # Normalize tuples to lists; None disables the override value.
            "image_mean": list(image_mean) if image_mean is not None else None,
            "image_std": list(image_std) if image_std is not None else None,
        }
        return self._offline_generate_one_with_processor_overrides(
            processor,
            query,
            image_processor_overrides=image_processor_overrides,
        )
|
| 3129 |
+
|
| 3130 |
+
    def offline_video_generate(
        self,
        processor,
        prompt: str,
        video: Any,
        *,
        shortest_edge: int = 4096,
        longest_edge: int = 16777216,
        video_max_pixels: int = 201326592,
        patch_size: int = 16,
        temporal_patch_size: int = 1,
        merge_size: int = 2,
        video_fps: float = 1.0,
        min_frames: int = 1,
        max_frames: int = 256,
        num_extract_threads: int = 4,
        image_mean: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
        image_std: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
        max_new_tokens: int = 1024,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        repetition_penalty: float = 1.0,
        do_sample: bool = False,
        vision_chunked_length: int = 64,
        thinking_mode: Optional[str] = None,
        system_prompt_type: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Single-video offline generation with explicit video preprocessor defaults.

        The default values mirror `video_preprocessor_config.json` so README examples
        can show a standalone video entry point with the effective preprocessing knobs.

        Builds a one-video query dict, then runs `_offline_generate_one` with the
        video processor attributes temporarily overridden and restored afterwards.
        Returns the decoded generation text.
        """
        query: Dict[str, Any] = {
            "prompt": prompt,
            "images": [],
            "videos": [video],
            # NOTE(review): shortest_edge/longest_edge are forwarded as
            # min_pixels/max_pixels pixel budgets — confirm against the
            # processor's media_kwargs contract.
            "media_kwargs": {
                "min_pixels": shortest_edge,
                "max_pixels": longest_edge,
                "video_max_pixels": video_max_pixels,
                "video_fps": video_fps,
                "min_frames": min_frames,
                "max_frames": max_frames,
            },
            "generate_kwargs": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "repetition_penalty": repetition_penalty,
                "do_sample": do_sample,
                "vision_chunked_length": vision_chunked_length,
            },
        }
        # Only forward prompt-selection keys when explicitly provided, so the
        # downstream resolver applies its own defaults otherwise.
        if thinking_mode is not None:
            query["thinking_mode"] = thinking_mode
        if system_prompt_type is not None:
            query["system_prompt_type"] = system_prompt_type
        if system_prompt is not None:
            query["system_prompt"] = system_prompt

        # Attributes temporarily applied to `processor.video_processor`.
        video_processor_overrides = {
            "size": {"shortest_edge": shortest_edge, "longest_edge": longest_edge},
            "video_max_pixels": video_max_pixels,
            "patch_size": patch_size,
            "temporal_patch_size": temporal_patch_size,
            "merge_size": merge_size,
            "video_fps": video_fps,
            "min_frames": min_frames,
            "max_frames": max_frames,
            "num_extract_threads": num_extract_threads,
            # Normalize tuples to lists; None disables the override value.
            "image_mean": list(image_mean) if image_mean is not None else None,
            "image_std": list(image_std) if image_std is not None else None,
        }
        return self._offline_generate_one_with_processor_overrides(
            processor,
            query,
            video_processor_overrides=video_processor_overrides,
        )
|
| 3212 |
+
|
| 3213 |
+
    def offline_generate(
        self,
        processor,
        new_queries: "queue.Queue[dict]",
        output_text_queue: "queue.Queue[str]",
        vision_chunked_length: int = 64,
    ) -> None:
        """
        HF-style offline inference wrapper aligned with the previous backend output path.

        This method intentionally reuses the checkpoint's existing processor and
        `generate()` flow so that outputs stay consistent with the old external
        backend inference implementation.

        Supported query keys include:
        - `prompt` / `messages`
        - `images` / `videos`
        - `media_kwargs` / `generate_kwargs`
        - `thinking_mode` (`no_thinking` or `deep_thinking`, plus compatible aliases)
        - `system_prompt_type` (`text_image` or `video`, plus compatible aliases)
        - `system_prompt` for an explicit override
        - `stream_output` / `stream`
        - `reset_session` / `clear_history`
        - `cancel_current_generate` / `stop_generation` / `stop_offline_generate`

        Runs forever, consuming queries from `new_queries` and emitting text to
        `output_text_queue`, until a `stop_offline_generate` query arrives.
        While a generation runs on a worker thread, the input queue is polled
        for control queries (cancel/stop); any other query received mid-turn is
        buffered and served on the next iteration.
        """
        # Queries that arrived while a generation was in flight.
        buffered_queries: List[Dict[str, Any]] = []
        # Multi-turn chat history carried across iterations.
        session_messages: List[Dict[str, Any]] = []

        while True:
            # Serve buffered queries first (FIFO), then block on the queue.
            if buffered_queries:
                query = buffered_queries.pop(0)
            else:
                query = new_queries.get()
            if not isinstance(query, dict):
                continue

            if query.get("stop_offline_generate"):
                break

            if query.get("reset_session") or query.get("clear_history"):
                session_messages = []

            try:
                generate_kwargs = query.setdefault("generate_kwargs", {})
                generate_kwargs.setdefault("vision_chunked_length", vision_chunked_length)
                working_messages = self._offline_build_session_messages(
                    processor,
                    query,
                    session_messages,
                )

                generation_query = dict(query)
                generation_query["messages"] = working_messages
                # input_text is unused here; kept for parity with the batch path.
                inputs, input_text, call_kwargs = self._offline_prepare_generation(processor, generation_query)

                stream_output = bool(query.get("stream_output", query.get("stream", False)))
                # cancel_event flips the stopping criteria so generate() halts early.
                cancel_event = threading.Event()
                stopping_criteria = StoppingCriteriaList([_OfflineCancelStoppingCriteria(cancel_event)])
                # Shared dict to carry outputs/exception back from the worker thread.
                generation_state: Dict[str, Any] = {}

                if stream_output:
                    # Round markers are only emitted at the start in stream
                    # mode; "<|round_end|>" is emitted in both modes below.
                    output_text_queue.put("<|round_start|>")
                    streamer = _OfflineQueueStreamer(getattr(processor, "tokenizer", processor), output_text_queue)
                else:
                    streamer = None

                def _run_generation():
                    # Runs on the worker thread; never raises across the
                    # thread boundary — exceptions are stashed and re-raised
                    # on the main loop after join().
                    try:
                        with torch.no_grad():
                            generation_state["outputs"] = self.generate(
                                **inputs,
                                stopping_criteria=stopping_criteria,
                                streamer=streamer,
                                **call_kwargs,
                            )
                    except Exception as exc:
                        generation_state["exception"] = exc

                worker = threading.Thread(target=_run_generation, daemon=True)
                worker.start()

                stop_conversation_after_turn = False
                # Poll for control queries while the generation runs.
                while worker.is_alive():
                    try:
                        control_query = new_queries.get(timeout=0.1)
                    except queue.Empty:
                        continue

                    if not isinstance(control_query, dict):
                        continue

                    if control_query.get("cancel_current_generate") or control_query.get("stop_generation"):
                        cancel_event.set()
                        # A combined cancel+stop query also ends the whole loop.
                        stop_conversation_after_turn = stop_conversation_after_turn or control_query.get("stop_offline_generate", False)
                        continue

                    if control_query.get("stop_offline_generate"):
                        # Cancel the in-flight turn, then exit after emitting it.
                        cancel_event.set()
                        stop_conversation_after_turn = True
                        continue

                    # Anything else is a normal query; serve it next iteration.
                    buffered_queries.append(control_query)

                worker.join()
                was_cancelled = cancel_event.is_set()

                if "exception" in generation_state:
                    raise generation_state["exception"]

                if stream_output and streamer is not None:
                    # Streamed chunks were already emitted; join them only to
                    # build the session-history text.
                    text = "".join(streamer.collected_chunks)
                else:
                    outputs = generation_state["outputs"]
                    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
                    text = processor.decode(new_tokens, skip_special_tokens=True)
                    output_text_queue.put(text)

                # Persist the turn unless cancelled (or explicitly asked to
                # keep cancelled turns).
                if query.get("persist_session", True) and (not was_cancelled or query.get("persist_cancelled_turn", False)):
                    session_messages = self._offline_finalize_session_messages(working_messages, text)

                output_text_queue.put("<|round_end|>")

                if stop_conversation_after_turn:
                    break
            except Exception as exc:
                # Surface the failure to the consumer and keep serving queries.
                output_text_queue.put(f"[ERROR] {exc}")
                output_text_queue.put("<|round_end|>")
|
| 3340 |
+
|
| 3341 |
|
| 3342 |
__all__ = [
|
| 3343 |
"MossVLVisionModel",
|