| """Model inference — streaming and synchronous generation. |
| |
| Supports two inference paths: |
| - Text-only (MiniCPM5-1B): uses TextIteratorStreamer for real-time streaming |
| - VLM (MiniCPM-V-4.6): uses processor.apply_chat_template() with image support |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import threading |
| from collections.abc import Iterator |
| from typing import Any |
|
|
| from code.config.constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, MODEL_CONFIGS |
| from code.model.loader import ( |
| get_model, |
| get_tokenizer_or_processor, |
| get_model_status, |
| is_model_loaded, |
| get_current_model_key, |
| get_current_model_type, |
| ) |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def call_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream a model response for *messages*, dispatching on the loaded model type.

    When no model is loaded, the loader's status message is yielded once and
    the stream ends. Otherwise the call is delegated to the VLM path (which
    receives ``image_url``) or to the text-only path (which ignores it).

    Yields:
        Progressively longer strings: each item is the full text produced so far.
    """
    # Guard clause: surface the loader status instead of attempting inference.
    if not is_model_loaded():
        yield get_model_status()["message"]
        return

    if get_current_model_type() == "vlm":
        stream = _call_vlm_model(messages, max_new_tokens, image_url)
    else:
        stream = _call_text_model(messages, max_new_tokens)

    yield from stream
|
|
|
|
def _call_text_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
) -> Iterator[str]:
    """Stream text from a text-only model using TextIteratorStreamer.

    The conversation is flattened into a plain ``System:/User:/Assistant:``
    transcript that ends with an open ``Assistant:`` turn, tokenized
    (truncated to 4096 tokens) and handed to ``model.generate`` on a worker
    thread while this generator drains the streamer.

    Args:
        messages: Chat messages. Each dict is read for ``"role"`` (default
            ``"user"``) and ``"content"`` (default ``""``); roles other than
            system, user and assistant are skipped.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The accumulated response text after every streamed chunk. Failures
        (including ones raised on the generation thread) are logged and
        reported as a single ``_Error during generation: ..._`` string
        instead of being raised.
    """
    model = get_model()
    tokenizer = get_tokenizer_or_processor()

    try:
        from transformers import TextIteratorStreamer
        import torch

        # NOTE(review): the prompt is a hand-rolled transcript rather than the
        # tokenizer's chat template -- confirm this is what the model expects.
        role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
        prompt_parts: list[str] = []
        for msg in messages:
            label = role_labels.get(msg.get("role", "user"))
            if label is not None:
                prompt_parts.append(f"{label}: {msg.get('content', '')}")
        prompt_parts.append("Assistant:")
        full_prompt = "\n\n".join(prompt_parts)

        inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=4096)

        # Put the inputs where the model actually lives; fall back to the
        # "CUDA if present" heuristic when the model does not expose a usable
        # device.
        device = getattr(model, "device", None)
        if device is None or getattr(device, "type", None) == "meta":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = {k: v.to(device) for k, v in inputs.items()}

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "temperature": DEFAULT_TEMPERATURE,
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "pad_token_id": tokenizer.eos_token_id,
        }

        worker_errors: list[Exception] = []

        def _generate() -> None:
            """Run generation; on failure record the error and unblock the consumer."""
            try:
                model.generate(**generation_kwargs)
            except Exception as err:
                worker_errors.append(err)
                # generate() died before signalling end-of-stream, so do it
                # here; otherwise the loop below would wait forever.
                streamer.end()

        thread = threading.Thread(target=_generate)
        thread.start()

        output = ""
        for new_text in streamer:
            output += new_text
            yield output

        thread.join()

        # Re-raise worker failures on this thread so they are logged and
        # reported through the handler below.
        if worker_errors:
            raise worker_errors[0]

    except Exception as exc:
        logger.exception("Error during text model inference")
        yield f"_Error during generation: {exc}_"
|
|
|
|
def _call_vlm_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream text from a VLM model with optional image input.

    Messages are converted with ``_build_vlm_messages`` and tokenized through
    ``processor.apply_chat_template()``. Generation runs on a worker thread
    feeding a ``TextIteratorStreamer``; if streaming fails for any reason
    (including an error raised on the worker thread), a single synchronous
    ``model.generate`` call is made with the same sampling settings.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        max_new_tokens: Upper bound on the number of generated tokens.
        image_url: Optional image reference forwarded to ``_build_vlm_messages``.

    Yields:
        The accumulated response text after every streamed chunk, or the full
        decoded text once when the synchronous fallback is used. Unrecoverable
        failures are logged and reported as a single
        ``_Error during generation: ..._`` string instead of being raised.
    """
    model = get_model()
    processor = get_tokenizer_or_processor()

    try:
        import torch

        vlm_messages = _build_vlm_messages(messages, image_url)

        template_kwargs = {
            "tokenize": True,
            "add_generation_prompt": True,
            "return_dict": True,
            "return_tensors": "pt",
        }
        try:
            inputs = processor.apply_chat_template(
                vlm_messages,
                downsample_mode="16x",
                max_slice_nums=9,
                **template_kwargs,
            )
        except TypeError:
            # The processor does not accept the image-slicing options; retry
            # with the generic arguments only.
            inputs = processor.apply_chat_template(vlm_messages, **template_kwargs)

        # Put the inputs where the model actually lives; fall back to the
        # "CUDA if present" heuristic when the model does not expose a usable
        # device.
        device = getattr(model, "device", None)
        if device is None or getattr(device, "type", None) == "meta":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = inputs.to(device)

        # A processor wraps its tokenizer; a bare tokenizer is used directly.
        tok = getattr(processor, "tokenizer", processor)

        # Sampling settings shared by the streaming and synchronous paths so
        # the fallback produces comparable output.
        # NOTE(review): "downsample_mode" is always forwarded to generate(),
        # even when the processor rejected it above -- confirm the model
        # accepts it.
        sampling_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "temperature": DEFAULT_TEMPERATURE,
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "downsample_mode": "16x",
        }
        eos_token_id = getattr(tok, "eos_token_id", None)
        if eos_token_id is not None:
            sampling_kwargs["pad_token_id"] = eos_token_id

        try:
            from transformers import TextIteratorStreamer

            streamer = TextIteratorStreamer(
                tok,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            gen_kwargs = {**inputs, **sampling_kwargs, "streamer": streamer}
            worker_errors: list[Exception] = []

            def _generate() -> None:
                """Run generation; on failure record the error and unblock the consumer."""
                try:
                    model.generate(**gen_kwargs)
                except Exception as err:
                    worker_errors.append(err)
                    # generate() died before signalling end-of-stream, so do
                    # it here; otherwise the loop below would wait forever.
                    streamer.end()

            thread = threading.Thread(target=_generate)
            thread.start()

            output = ""
            for new_text in streamer:
                output += new_text
                yield output

            thread.join()

            # Surface worker failures here so the synchronous fallback runs.
            if worker_errors:
                raise worker_errors[0]

        except Exception as stream_err:
            logger.warning("Streaming failed for VLM, falling back to sync: %s", stream_err)
            generated_ids = model.generate(**{**inputs, **sampling_kwargs})

            # Strip the prompt tokens from each returned sequence so only the
            # newly generated text is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            output_text = tok.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            yield output_text[0] if output_text else ""

    except Exception as exc:
        logger.exception("Error during VLM model inference")
        yield f"_Error during generation: {exc}_"
|
|
|
|
| def _build_vlm_messages( |
| messages: list[dict[str, Any]], |
| image_url: str | None = None, |
| ) -> list[dict[str, Any]]: |
| """Build VLM-style messages with image content blocks. |
| |
| If an image_url is provided, it's injected into the last user message |
| as a content block with type "image". |
| """ |
| vlm_messages = [] |
|
|
| for i, msg in enumerate(messages): |
| role = msg.get("role", "user") |
| content = msg.get("content", "") |
|
|
| if role == "system": |
| |
| vlm_messages.append({"role": "system", "content": content}) |
| continue |
|
|
| |
| is_last_user = (i == len(messages) - 1) and role == "user" |
|
|
| if is_last_user and image_url: |
| |
| content_list = [{"type": "image", "url": image_url}] |
| if content.strip(): |
| content_list.append({"type": "text", "text": content}) |
| vlm_messages.append({"role": "user", "content": content_list}) |
| else: |
| vlm_messages.append({"role": role, "content": content}) |
|
|
| return vlm_messages |
|
|
|
|
def call_model_sync(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> str:
    """Run a complete generation and return only the final text.

    Drains the ``call_model`` stream and keeps its last item, which is the
    full response because every item is the text accumulated so far. An empty
    stream yields ``""``.
    """
    final_text = ""
    # The loop variable itself tracks the most recent item.
    for final_text in call_model(messages, max_new_tokens, image_url):
        pass
    return final_text
|
|