"""Model inference — streaming and synchronous generation.

Supports two inference paths:
- Text-only (MiniCPM5-1B): uses TextIteratorStreamer for real-time streaming
- VLM (MiniCPM-V-4.6): uses processor.apply_chat_template() with image support
"""

from __future__ import annotations

import logging
import threading
from collections.abc import Iterator
from typing import Any

from code.config.constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, MODEL_CONFIGS
from code.model.loader import (
    get_model,
    get_tokenizer_or_processor,
    get_model_status,
    is_model_loaded,
    get_current_model_key,
    get_current_model_type,
)

logger = logging.getLogger(__name__)

# Upper bound on prompt tokens fed to the text-only model.
_MAX_PROMPT_TOKENS = 4096

# Prompt labels for the plain-text chat format; messages with other roles are dropped.
_TEXT_ROLE_LABELS = {"system": "System", "user": "User", "assistant": "Assistant"}


def call_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream model text.

    Yields progressively longer strings (the full text generated so far).
    If no model is loaded, yields the loader's status message once instead.
    For VLM models, ``image_url`` (if given) is attached to the last user
    message; the text-only path does not use it.
    """
    if not is_model_loaded():
        yield get_model_status()["message"]
        return

    if get_current_model_type() == "vlm":
        yield from _call_vlm_model(messages, max_new_tokens, image_url)
    else:
        yield from _call_text_model(messages, max_new_tokens)


def _sampling_kwargs(max_new_tokens: int) -> dict[str, Any]:
    """Return the sampling settings shared by every ``generate()`` call."""
    return {
        "max_new_tokens": max_new_tokens,
        "temperature": DEFAULT_TEMPERATURE,
        "do_sample": True,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    }


def _input_device(model: Any) -> Any:
    """Return the device that model inputs should be moved to.

    Prefers ``model.device`` so inputs land where the weights are (e.g. a
    CPU-loaded model on a CUDA machine, or a non-default GPU); falls back to
    CUDA-if-available when the model exposes no ``device`` attribute.
    """
    device = getattr(model, "device", None)
    if device is not None:
        return device
    import torch

    return "cuda" if torch.cuda.is_available() else "cpu"


def _stream_generate(
    model: Any,
    streamer: Any,
    gen_kwargs: dict[str, Any],
) -> Iterator[str]:
    """Run ``model.generate(**gen_kwargs)`` on a worker thread and stream its text.

    ``streamer`` must be the TextIteratorStreamer passed as
    ``gen_kwargs["streamer"]``. Yields the accumulated text after each chunk.

    Raises:
        Exception: whatever ``generate`` raised in the worker thread. Before
            re-raising, the streamer is closed so the consumer loop ends
            instead of waiting on it forever.
    """
    errors: list[Exception] = []

    def _worker() -> None:
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # surfaced to the consumer thread below
            errors.append(exc)
            # Push a stop signal so `for new_text in streamer` terminates;
            # otherwise the consumer would block on the streamer indefinitely.
            streamer.end()

    thread = threading.Thread(target=_worker)
    thread.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output
    thread.join()

    if errors:
        raise errors[0]


def _build_text_prompt(messages: list[dict[str, Any]]) -> str:
    """Render messages as ``Role: content`` paragraphs ending with an ``Assistant:`` cue."""
    parts: list[str] = []
    for msg in messages:
        label = _TEXT_ROLE_LABELS.get(msg.get("role", "user"))
        if label is not None:
            parts.append(f"{label}: {msg.get('content', '')}")
    parts.append("Assistant:")
    return "\n\n".join(parts)


def _call_text_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
) -> Iterator[str]:
    """Stream text from a text-only model using TextIteratorStreamer.

    Any error is logged and reported as a single ``_Error during generation: ..._`` string.
    """
    model = get_model()
    tokenizer = get_tokenizer_or_processor()

    try:
        from transformers import TextIteratorStreamer

        encoded = tokenizer(_build_text_prompt(messages), return_tensors="pt")
        device = _input_device(model)
        # Keep the *last* _MAX_PROMPT_TOKENS tokens. Truncating from the right
        # (the tokenizer default) would cut off the newest user turn and the
        # trailing "Assistant:" cue that generation continues from.
        inputs = {
            key: value[..., -_MAX_PROMPT_TOKENS:].to(device)
            for key, value in encoded.items()
        }

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = {
            **inputs,
            **_sampling_kwargs(max_new_tokens),
            "streamer": streamer,
            "pad_token_id": tokenizer.eos_token_id,
        }
        yield from _stream_generate(model, streamer, gen_kwargs)
    except Exception as exc:
        logger.exception("Error during text model inference")
        yield f"_Error during generation: {exc}_"


def _prepare_vlm_inputs(
    processor: Any,
    messages: list[dict[str, Any]],
    image_url: str | None,
) -> tuple[Any, bool]:
    """Apply the processor's chat template to the conversation (plus optional image).

    Returns:
        ``(inputs, downsample_ok)`` — the processor output and whether the
        processor accepted ``downsample_mode="16x"``.
    """
    vlm_messages = _build_vlm_messages(messages, image_url)
    template_kwargs: dict[str, Any] = {
        "tokenize": True,
        "add_generation_prompt": True,
        "return_dict": True,
        "return_tensors": "pt",
    }
    try:
        inputs = processor.apply_chat_template(
            vlm_messages, **template_kwargs, downsample_mode="16x", max_slice_nums=9
        )
    except TypeError:
        # Fallback for older transformers without downsample_mode.
        return processor.apply_chat_template(vlm_messages, **template_kwargs), False
    return inputs, True


def _generate_vlm_sync(
    model: Any,
    tokenizer: Any,
    inputs: Any,
    gen_kwargs: dict[str, Any],
) -> str:
    """Run one blocking ``generate()`` call and decode only the newly generated tokens."""
    generated_ids = model.generate(**{**inputs, **gen_kwargs})
    # The output rows start with the prompt tokens; strip them per row.
    trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    decoded = tokenizer.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0] if decoded else ""


def _call_vlm_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream text from a VLM model with optional image input.

    Uses processor.apply_chat_template() for image+text processing, then
    streams generation from a worker thread. If streaming fails (including a
    failure inside ``generate``), falls back to one synchronous ``generate()``
    call. Any remaining error is logged and reported as a single
    ``_Error during generation: ..._`` string.
    """
    model = get_model()
    processor = get_tokenizer_or_processor()

    try:
        inputs, downsample_ok = _prepare_vlm_inputs(processor, messages, image_url)
        inputs = inputs.to(_input_device(model))

        tokenizer = getattr(processor, "tokenizer", processor)
        gen_kwargs = _sampling_kwargs(max_new_tokens)
        if downsample_ok:
            # Generate with the same downsample mode used for preprocessing.
            # When the processor rejected it, the installed stack is presumably
            # older and generate() would likely reject the kwarg as well.
            gen_kwargs["downsample_mode"] = "16x"
        if hasattr(tokenizer, "eos_token_id"):
            gen_kwargs["pad_token_id"] = tokenizer.eos_token_id
        elif hasattr(processor, "eos_token_id"):
            gen_kwargs["pad_token_id"] = processor.eos_token_id

        try:
            from transformers import TextIteratorStreamer

            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            yield from _stream_generate(
                model, streamer, {**inputs, **gen_kwargs, "streamer": streamer}
            )
        except Exception as stream_err:
            logger.warning("Streaming failed for VLM, falling back to sync: %s", stream_err)
            yield _generate_vlm_sync(model, tokenizer, inputs, gen_kwargs)
    except Exception as exc:
        logger.exception("Error during VLM model inference")
        yield f"_Error during generation: {exc}_"


def _build_vlm_messages(
    messages: list[dict[str, Any]],
    image_url: str | None = None,
) -> list[dict[str, Any]]:
    """Build VLM-style messages with image content blocks.

    Every message is copied as ``{"role": ..., "content": ...}``. If
    ``image_url`` is given, the last message whose role is ``"user"`` instead
    gets a content list: an ``{"type": "image", "url": image_url}`` block,
    followed by a ``{"type": "text", "text": ...}`` block when its text is not
    blank. With no user message, no image is attached.
    """
    image_idx: int | None = None
    if image_url:
        image_idx = next(
            (
                idx
                for idx in range(len(messages) - 1, -1, -1)
                if messages[idx].get("role", "user") == "user"
            ),
            None,
        )

    vlm_messages: list[dict[str, Any]] = []
    for idx, msg in enumerate(messages):
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if idx == image_idx:
            text = content or ""  # tolerate content=None
            content_list: list[dict[str, Any]] = [{"type": "image", "url": image_url}]
            if text.strip():
                content_list.append({"type": "text", "text": text})
            vlm_messages.append({"role": "user", "content": content_list})
        else:
            vlm_messages.append({"role": role, "content": content})
    return vlm_messages


def call_model_sync(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> str:
    """Non-streaming model call — returns the complete response.

    Drains :func:`call_model` and returns its last chunk, or ``""`` if it
    yielded nothing.
    """
    result = ""
    for result in call_model(messages, max_new_tokens, image_url):
        pass
    return result