Spaces:
Running
Running
| """Model inference — streaming and synchronous generation. | |
| Supports two inference paths: | |
| - Text-only (MiniCPM5-1B): uses TextIteratorStreamer for real-time streaming | |
| - VLM (MiniCPM-V-4.6): uses processor.apply_chat_template() with image support | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import threading | |
| from collections.abc import Iterator | |
| from typing import Any | |
| from code.config.constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, MODEL_CONFIGS | |
| from code.model.loader import ( | |
| get_model, | |
| get_tokenizer_or_processor, | |
| get_model_status, | |
| is_model_loaded, | |
| get_current_model_key, | |
| get_current_model_type, | |
| ) | |
| logger = logging.getLogger(__name__) | |
def call_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream a model response as progressively longer strings.

    Each yielded value is the full text produced so far, not a delta.

    When no model is loaded, the only value yielded is the ``"message"``
    entry of the current model status. Otherwise the call is delegated to
    the VLM path when the current model type is ``"vlm"`` (forwarding
    ``image_url``), and to the text-only path for every other type (where
    ``image_url`` is not used).
    """
    # Guard clause: without a loaded model, surface the loader's status text.
    if not is_model_loaded():
        yield get_model_status()["message"]
        return

    # Pick the backend generator first, then delegate to it in one place.
    if get_current_model_type() == "vlm":
        stream = _call_vlm_model(messages, max_new_tokens, image_url)
    else:
        stream = _call_text_model(messages, max_new_tokens)
    yield from stream
def _call_text_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
) -> Iterator[str]:
    """Stream text from a text-only model using TextIteratorStreamer.

    The conversation is flattened into a plain ``"<Role>: <content>"``
    transcript (blocks separated by blank lines) that ends with an open
    ``"Assistant:"`` turn. The transcript is tokenized with truncation at
    4096 tokens, and ``model.generate`` runs on a background thread while
    this generator drains the streamer.

    Args:
        messages: Chat messages. Each dict is read for ``"role"`` (default
            ``"user"``) and ``"content"`` (default ``""``). Messages whose
            role is not system/user/assistant are left out of the prompt.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``.

    Yields:
        The full generated text so far, growing with every streamed chunk.
        If anything fails (including a failure inside the generation
        thread), the error is logged and one final
        ``"_Error during generation: ..._"`` string is yielded instead of
        raising.
    """
    model = get_model()
    tokenizer = get_tokenizer_or_processor()
    try:
        from transformers import TextIteratorStreamer
        import torch

        # Build the prompt from messages; unknown roles are skipped.
        role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
        prompt_parts: list[str] = []
        for msg in messages:
            label = role_labels.get(msg.get("role", "user"))
            if label is not None:
                prompt_parts.append(f"{label}: {msg.get('content', '')}")
        prompt_parts.append("Assistant:")
        full_prompt = "\n\n".join(prompt_parts)

        # Tokenize
        inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=4096)
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Stream generation
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "temperature": DEFAULT_TEMPERATURE,
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "pad_token_id": tokenizer.eos_token_id,
        }

        worker_errors: list[Exception] = []

        def _generate() -> None:
            # If generate() raises, the streamer never receives its stop
            # signal and the consumer loop below would block forever.
            # Record the error and end the stream explicitly instead.
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                streamer.end()

        # Run generation in a separate thread so chunks can be yielded live.
        thread = threading.Thread(target=_generate)
        thread.start()

        output = ""
        for new_text in streamer:
            output += new_text
            yield output
        thread.join()

        # Surface a worker failure through the normal error path below.
        if worker_errors:
            raise worker_errors[0]
    except Exception as exc:
        logger.exception("Error during text model inference")
        yield f"_Error during generation: {exc}_"
def _call_vlm_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream text from a VLM model with optional image input.

    Messages are converted with ``_build_vlm_messages`` and encoded through
    ``processor.apply_chat_template()``. Generation then runs on a
    background thread and is streamed with ``TextIteratorStreamer``. If the
    streaming attempt fails, one synchronous ``model.generate`` call is made
    with the same sampling settings and its decoded text is yielded once.

    Args:
        messages: Chat messages (dicts with ``"role"`` / ``"content"``).
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``.
        image_url: Optional image reference handed to
            ``_build_vlm_messages``.

    Yields:
        The full generated text so far (streaming path), or the complete
        text once (fallback path). If both paths fail, the error is logged
        and a single ``"_Error during generation: ..._"`` string is yielded
        instead of raising.
    """
    model = get_model()
    processor = get_tokenizer_or_processor()
    try:
        import torch

        # Build VLM-style messages with image support
        vlm_messages = _build_vlm_messages(messages, image_url)

        # Apply chat template
        template_kwargs = {
            "tokenize": True,
            "add_generation_prompt": True,
            "return_dict": True,
            "return_tensors": "pt",
        }
        try:
            inputs = processor.apply_chat_template(
                vlm_messages,
                downsample_mode="16x",
                max_slice_nums=9,
                **template_kwargs,
            )
        except TypeError:
            # Fallback for processors that reject the extra image options.
            inputs = processor.apply_chat_template(vlm_messages, **template_kwargs)

        # Place the inputs explicitly on GPU when available, otherwise CPU.
        inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Streaming and decoding use the processor's tokenizer when it has
        # one, otherwise the processor itself.
        tok = getattr(processor, "tokenizer", processor)

        # One set of sampling options shared by the streaming and fallback
        # paths so both generate under the same settings.
        # NOTE(review): downsample_mode is always forwarded to generate(),
        # even when the processor rejected it above -- confirm the model's
        # generate() accepts it in that configuration.
        sampling_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "temperature": DEFAULT_TEMPERATURE,
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "downsample_mode": "16x",
        }

        # Ensure pad_token_id: prefer the tokenizer's EOS id, then the
        # processor's own.
        eos_token_id = getattr(tok, "eos_token_id", None)
        if eos_token_id is None:
            eos_token_id = getattr(processor, "eos_token_id", None)
        if eos_token_id is not None:
            sampling_kwargs["pad_token_id"] = eos_token_id

        # Generate with streaming
        try:
            from transformers import TextIteratorStreamer

            streamer = TextIteratorStreamer(
                tok,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            worker_errors: list[Exception] = []

            def _generate() -> None:
                # If generate() raises, the streamer never receives its stop
                # signal and the consumer loop below would block forever.
                # Record the error and end the stream explicitly instead.
                try:
                    model.generate(**inputs, **sampling_kwargs, streamer=streamer)
                except Exception as exc:
                    worker_errors.append(exc)
                    streamer.end()

            thread = threading.Thread(target=_generate)
            thread.start()

            output = ""
            for new_text in streamer:
                output += new_text
                yield output
            thread.join()

            # Re-raise worker failures here so the sync fallback can run.
            if worker_errors:
                raise worker_errors[0]
        except Exception as stream_err:
            # Fallback: non-streaming generation
            logger.warning("Streaming failed for VLM, falling back to sync: %s", stream_err)
            generated_ids = model.generate(**inputs, **sampling_kwargs)

            # Trim input tokens from output: drop each row's prompt prefix.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            output_text = tok.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            yield output_text[0] if output_text else ""
    except Exception as exc:
        logger.exception("Error during VLM model inference")
        yield f"_Error during generation: {exc}_"
| def _build_vlm_messages( | |
| messages: list[dict[str, Any]], | |
| image_url: str | None = None, | |
| ) -> list[dict[str, Any]]: | |
| """Build VLM-style messages with image content blocks. | |
| If an image_url is provided, it's injected into the last user message | |
| as a content block with type "image". | |
| """ | |
| vlm_messages = [] | |
| for i, msg in enumerate(messages): | |
| role = msg.get("role", "user") | |
| content = msg.get("content", "") | |
| if role == "system": | |
| # System messages stay as-is for VLM | |
| vlm_messages.append({"role": "system", "content": content}) | |
| continue | |
| # For the last user message with an image, use structured content | |
| is_last_user = (i == len(messages) - 1) and role == "user" | |
| if is_last_user and image_url: | |
| # Build content list with image + text | |
| content_list = [{"type": "image", "url": image_url}] | |
| if content.strip(): | |
| content_list.append({"type": "text", "text": content}) | |
| vlm_messages.append({"role": "user", "content": content_list}) | |
| else: | |
| vlm_messages.append({"role": role, "content": content}) | |
| return vlm_messages | |
def call_model_sync(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> str:
    """Run ``call_model`` to completion and return its final chunk.

    Every chunk from ``call_model`` is the full text so far, so the last
    one is the complete response. Returns ``""`` if nothing is yielded.
    """
    final_text = ""
    # Drain the stream; the loop variable keeps the most recent chunk.
    for final_text in call_model(messages, max_new_tokens, image_url):
        pass
    return final_text