File size: 9,085 Bytes
4412065
 
380204e
 
 
4412065
 
 
 
 
 
 
 
 
380204e
 
 
 
 
 
 
 
 
4412065
 
 
 
 
 
 
380204e
4412065
380204e
4412065
380204e
4412065
 
 
 
 
 
380204e
 
 
 
 
 
 
 
 
 
 
 
 
4412065
380204e
4412065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380204e
4412065
 
 
380204e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4412065
 
 
380204e
4412065
 
 
380204e
4412065
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
"""Model inference — streaming and synchronous generation.

Supports two inference paths:
- Text-only (MiniCPM5-1B): uses TextIteratorStreamer for real-time streaming
- VLM (MiniCPM-V-4.6): uses processor.apply_chat_template() with image support
"""

from __future__ import annotations

import logging
import threading
from collections.abc import Iterator
from typing import Any

from code.config.constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, MODEL_CONFIGS
from code.model.loader import (
    get_model,
    get_tokenizer_or_processor,
    get_model_status,
    is_model_loaded,
    get_current_model_key,
    get_current_model_type,
)

logger = logging.getLogger(__name__)


def call_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream a model response, dispatching on the loaded model's type.

    When no model is loaded, the only value yielded is the "message" entry of
    the loader's status dict. Otherwise the request is forwarded to the VLM
    path (which also receives ``image_url``) when the current model type is
    "vlm", and to the text-only path for every other type.

    Yields:
        Whatever the selected backend yields — each item is the full response
        text produced so far, not a delta.
    """
    if is_model_loaded():
        if get_current_model_type() == "vlm":
            stream = _call_vlm_model(messages, max_new_tokens, image_url)
        else:
            # image_url is not used by the text-only path.
            stream = _call_text_model(messages, max_new_tokens)
        yield from stream
    else:
        yield get_model_status()["message"]


def _call_text_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
) -> Iterator[str]:
    """Stream text from a text-only model using TextIteratorStreamer.

    The conversation is flattened into a plain transcript of
    "System:/User:/Assistant:" turns separated by blank lines and ending with
    an open "Assistant:" turn. The transcript is tokenized (truncated to 4096
    tokens) and generated on a background thread while this generator drains
    the streamer.

    Args:
        messages: Chat messages read via their "role" (default "user") and
            "content" (default "") keys. Messages with any other role are
            left out of the prompt.
        max_new_tokens: Upper bound on generated tokens, passed to
            ``model.generate``.

    Yields:
        The accumulated response text after each streamed chunk. If anything
        fails — including a failure inside the generation thread — the error
        is logged and a single "_Error during generation: ..._" string is
        yielded.
    """
    model = get_model()
    tokenizer = get_tokenizer_or_processor()

    try:
        from transformers import TextIteratorStreamer
        import torch

        # Build the prompt from messages; roles without a label are skipped.
        role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
        prompt_parts: list[str] = []
        for msg in messages:
            label = role_labels.get(msg.get("role", "user"))
            if label is not None:
                prompt_parts.append(f"{label}: {msg.get('content', '')}")
        prompt_parts.append("Assistant:")
        full_prompt = "\n\n".join(prompt_parts)

        # Tokenize
        inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=4096)
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Stream generation
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "temperature": DEFAULT_TEMPERATURE,
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "pad_token_id": tokenizer.eos_token_id,
        }

        worker_errors: list[Exception] = []

        def _generate() -> None:
            """Run generation; on failure record the error and close the stream."""
            try:
                model.generate(**generation_kwargs)
            except Exception as err:
                worker_errors.append(err)
                # generate() did not finish, so the streamer never received
                # its end-of-stream marker; without this the consumer loop
                # below would block forever.
                streamer.end()

        # Run generation in a separate thread so chunks can be yielded as
        # they are produced.
        thread = threading.Thread(target=_generate)
        thread.start()

        output = ""
        for new_text in streamer:
            output += new_text
            yield output

        thread.join()

        # Surface a background failure through the normal error path below.
        if worker_errors:
            raise worker_errors[0]

    except Exception as exc:
        logger.exception("Error during text model inference")
        yield f"_Error during generation: {exc}_"


def _call_vlm_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream text from a VLM model with optional image input.

    Messages are converted with ``_build_vlm_messages`` and tokenized through
    ``processor.apply_chat_template()``, first with the image-slicing keywords
    (``downsample_mode``, ``max_slice_nums``); if the processor rejects them
    with a ``TypeError`` the template is re-applied without them. Generation
    then runs on a background thread and is streamed back.

    Args:
        messages: Chat messages with "role" and "content" keys.
        max_new_tokens: Upper bound on generated tokens, passed to
            ``model.generate``.
        image_url: Optional image reference forwarded to
            ``_build_vlm_messages``.

    Yields:
        The accumulated response text after each streamed chunk. If streaming
        fails (including a failure inside the generation thread), one blocking
        ``generate`` call is attempted and the whole decoded response is
        yielded once. If that also fails, or preparation fails, the error is
        logged and a single "_Error during generation: ..._" string is yielded.
    """
    model = get_model()
    processor = get_tokenizer_or_processor()

    try:
        import torch

        # Build VLM-style messages with image support
        vlm_messages = _build_vlm_messages(messages, image_url)

        # Apply chat template
        template_kwargs = {
            "tokenize": True,
            "add_generation_prompt": True,
            "return_dict": True,
            "return_tensors": "pt",
        }
        try:
            inputs = processor.apply_chat_template(
                vlm_messages,
                downsample_mode="16x",
                max_slice_nums=9,
                **template_kwargs,
            )
        except TypeError:
            # Fallback for older transformers without downsample_mode
            inputs = processor.apply_chat_template(vlm_messages, **template_kwargs)

        inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Text decoding goes through the wrapped tokenizer when the processor
        # exposes one, otherwise through the processor itself.
        tok = processor.tokenizer if hasattr(processor, "tokenizer") else processor

        # Generate with streaming
        try:
            from transformers import TextIteratorStreamer

            streamer = TextIteratorStreamer(
                tok,
                skip_prompt=True,
                skip_special_tokens=True,
            )

            gen_kwargs = {
                **inputs,
                "streamer": streamer,
                "max_new_tokens": max_new_tokens,
                "temperature": DEFAULT_TEMPERATURE,
                "do_sample": True,
                "top_p": 0.9,
                "repetition_penalty": 1.1,
                # Passed unconditionally; a generate() that rejects it raises
                # and is handled by the except blocks below.
                "downsample_mode": "16x",
            }

            # Ensure pad_token_id
            if hasattr(processor, "tokenizer") and hasattr(processor.tokenizer, "eos_token_id"):
                gen_kwargs["pad_token_id"] = processor.tokenizer.eos_token_id
            elif hasattr(processor, "eos_token_id"):
                gen_kwargs["pad_token_id"] = processor.eos_token_id

            worker_errors: list[Exception] = []

            def _generate() -> None:
                """Run generation; on failure record the error and close the stream."""
                try:
                    model.generate(**gen_kwargs)
                except Exception as err:
                    worker_errors.append(err)
                    # generate() did not finish, so the streamer never got its
                    # end-of-stream marker; without this the consumer loop
                    # below would block forever.
                    streamer.end()

            thread = threading.Thread(target=_generate)
            thread.start()

            output = ""
            for new_text in streamer:
                output += new_text
                yield output

            thread.join()

            # Re-raise a background failure here so the sync fallback runs.
            if worker_errors:
                raise worker_errors[0]

        except Exception as stream_err:
            # Fallback: non-streaming generation
            logger.warning("Streaming failed for VLM, falling back to sync: %s", stream_err)
            gen_kwargs = {
                **inputs,
                "max_new_tokens": max_new_tokens,
                "temperature": DEFAULT_TEMPERATURE,
                "do_sample": True,
                "top_p": 0.9,
                "downsample_mode": "16x",
            }

            generated_ids = model.generate(**gen_kwargs)
            # Trim input tokens from output: each generated row starts with
            # its own prompt, so drop that many leading tokens per row.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            output_text = tok.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            yield output_text[0] if output_text else ""

    except Exception as exc:
        logger.exception("Error during VLM model inference")
        yield f"_Error during generation: {exc}_"


def _build_vlm_messages(
    messages: list[dict[str, Any]],
    image_url: str | None = None,
) -> list[dict[str, Any]]:
    """Build VLM-style messages with image content blocks.

    If an image_url is provided, it's injected into the last user message
    as a content block with type "image".
    """
    vlm_messages = []

    for i, msg in enumerate(messages):
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if role == "system":
            # System messages stay as-is for VLM
            vlm_messages.append({"role": "system", "content": content})
            continue

        # For the last user message with an image, use structured content
        is_last_user = (i == len(messages) - 1) and role == "user"

        if is_last_user and image_url:
            # Build content list with image + text
            content_list = [{"type": "image", "url": image_url}]
            if content.strip():
                content_list.append({"type": "text", "text": content})
            vlm_messages.append({"role": "user", "content": content_list})
        else:
            vlm_messages.append({"role": role, "content": content})

    return vlm_messages


def call_model_sync(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> str:
    """Run ``call_model`` to completion and return its final yielded string.

    Each item from ``call_model`` supersedes the previous one, so only the
    last is kept. Returns "" when the stream yields nothing.
    """
    latest = ""
    # The loop variable itself holds the most recent chunk once the stream ends.
    for latest in call_model(messages, max_new_tokens, image_url):
        pass
    return latest