# Hugging Face Hub page residue (not code): uploaded by R-Kentaren using
# huggingface_hub, commit 380204e (verified), file size 9.09 kB.
"""Model inference — streaming and synchronous generation.
Supports two inference paths:
- Text-only (MiniCPM5-1B): uses TextIteratorStreamer for real-time streaming
- VLM (MiniCPM-V-4.6): uses processor.apply_chat_template() with image support
"""
from __future__ import annotations
import logging
import threading
from collections.abc import Iterator
from typing import Any
from code.config.constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, MODEL_CONFIGS
from code.model.loader import (
get_model,
get_tokenizer_or_processor,
get_model_status,
is_model_loaded,
get_current_model_key,
get_current_model_type,
)
logger = logging.getLogger(__name__)


def call_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream a response to ``messages`` from the currently loaded model.

    Every yielded value is the complete text produced so far (each one is a
    longer version of the previous), so a consumer can simply display the
    latest value.

    Args:
        messages: Chat history as ``{"role", "content"}`` dicts.
        max_new_tokens: Maximum number of tokens to generate.
        image_url: Optional image reference; only used when the loaded model
            is a VLM, where it is attached to the last user message.

    Yields:
        Progressively longer response strings. When no model is loaded, the
        loader's status message is yielded once instead.
    """
    if is_model_loaded():
        if get_current_model_type() == "vlm":
            stream = _call_vlm_model(messages, max_new_tokens, image_url)
        else:
            stream = _call_text_model(messages, max_new_tokens)
        yield from stream
        return

    # Nothing to run inference with: surface why (e.g. still loading).
    yield get_model_status()["message"]
def _call_text_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
) -> Iterator[str]:
    """Stream text from a text-only model using TextIteratorStreamer.

    The chat history is flattened into a plain-text transcript
    (``System: ...`` / ``User: ...`` / ``Assistant: ...`` paragraphs joined
    by blank lines) ending with an ``Assistant:`` cue. The transcript is
    tokenized and passed to ``model.generate`` on a worker thread while this
    generator drains the streamer.

    Args:
        messages: Chat history. Each dict's ``"role"`` (default ``"user"``)
            and ``"content"`` (default ``""``) are read; roles other than
            system/user/assistant are skipped.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The full response text generated so far. If anything fails
        (including inside the generation thread), a single
        ``"_Error during generation: ..._"`` string is yielded instead of an
        exception propagating.
    """
    # Prompt budget in tokens; longer prompts keep only their most recent part.
    max_prompt_tokens = 4096
    model = get_model()
    tokenizer = get_tokenizer_or_processor()
    try:
        from transformers import TextIteratorStreamer
        import torch

        # Build the prompt from messages
        prompt_parts: list[str] = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                prompt_parts.append(f"System: {content}")
            elif role == "user":
                prompt_parts.append(f"User: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")
        prompt_parts.append("Assistant:")
        full_prompt = "\n\n".join(prompt_parts)

        inputs = tokenizer(full_prompt, return_tensors="pt")
        if inputs["input_ids"].shape[-1] > max_prompt_tokens:
            # Truncate from the left. Right-side truncation would drop the
            # newest user turn and the trailing "Assistant:" cue, which are
            # exactly what the reply depends on. If the tokenizer prepended
            # a special token (e.g. BOS), it is dropped along with the
            # oldest context.
            inputs = {k: v[:, -max_prompt_tokens:] for k, v in inputs.items()}

        # Place inputs on the model's own device instead of assuming CUDA:
        # a model kept on CPU while a GPU is present would otherwise hit a
        # device mismatch inside generate().
        device = getattr(model, "device", None)
        if device is None and torch.cuda.is_available():
            device = "cuda"
        if device is not None:
            inputs = {k: v.to(device) for k, v in inputs.items()}

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "temperature": DEFAULT_TEMPERATURE,
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "pad_token_id": tokenizer.eos_token_id,
        }

        generation_errors: list[Exception] = []

        def _generate() -> None:
            # An exception raised in this worker thread never reaches the
            # consumer loop below, and the streamer would never receive its
            # end signal. Record the error and close the streamer so the
            # loop terminates; the error is re-raised in the caller's thread.
            try:
                model.generate(**generation_kwargs)
            except Exception as err:
                generation_errors.append(err)
                streamer.end()

        thread = threading.Thread(target=_generate)
        thread.start()
        output = ""
        for new_text in streamer:
            output += new_text
            yield output
        thread.join()
        if generation_errors:
            raise generation_errors[0]
    except Exception as exc:
        logger.exception("Error during text model inference")
        yield f"_Error during generation: {exc}_"
def _call_vlm_model(
    messages: list[dict[str, Any]],
    max_new_tokens: int,
    image_url: str | None = None,
) -> Iterator[str]:
    """Stream text from a VLM model with optional image input.

    Messages are converted with ``_build_vlm_messages`` (attaching
    ``image_url`` to the last user turn). They are then encoded with
    ``processor.apply_chat_template()`` and generated with a
    ``TextIteratorStreamer`` on a worker thread. If streaming fails before
    any text has been produced, generation is retried once without
    streaming, using the same sampling settings.

    Args:
        messages: Chat history as ``{"role", "content"}`` dicts.
        max_new_tokens: Maximum number of tokens to generate.
        image_url: Optional image reference for the last user message.

    Yields:
        The full response text generated so far. On failure, an
        ``"_Error during generation: ..._"`` string is yielded instead of an
        exception propagating.
    """
    model = get_model()
    processor = get_tokenizer_or_processor()
    try:
        import torch

        # Build VLM-style messages with image support
        vlm_messages = _build_vlm_messages(messages, image_url)

        template_kwargs: dict[str, Any] = {
            "tokenize": True,
            "add_generation_prompt": True,
            "return_dict": True,
            "return_tensors": "pt",
        }
        try:
            inputs = processor.apply_chat_template(
                vlm_messages,
                downsample_mode="16x",
                max_slice_nums=9,
                **template_kwargs,
            )
            supports_downsample = True
        except TypeError:
            # Fallback for older transformers without downsample_mode
            inputs = processor.apply_chat_template(vlm_messages, **template_kwargs)
            supports_downsample = False

        # Place inputs on the model's own device instead of assuming CUDA.
        device = getattr(model, "device", None)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = inputs.to(device)

        # Processors wrap their text tokenizer in `.tokenizer`; otherwise the
        # object itself is used as the tokenizer.
        tokenizer = getattr(processor, "tokenizer", processor)
        pad_token_id = getattr(tokenizer, "eos_token_id", None)
        if pad_token_id is None:
            pad_token_id = getattr(processor, "eos_token_id", None)

        # Shared by the streaming and synchronous paths so both sample alike.
        sampling_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "temperature": DEFAULT_TEMPERATURE,
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
        }
        if supports_downsample:
            # Only forward downsample_mode when the processor accepted it;
            # generate() may reject kwargs the model does not use.
            sampling_kwargs["downsample_mode"] = "16x"
        if pad_token_id is not None:
            sampling_kwargs["pad_token_id"] = pad_token_id

        output = ""
        try:
            from transformers import TextIteratorStreamer

            streamer = TextIteratorStreamer(
                tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            stream_kwargs = {**inputs, **sampling_kwargs, "streamer": streamer}
            generation_errors: list[Exception] = []

            def _generate() -> None:
                # Exceptions in this worker thread would otherwise leave the
                # streamer without an end signal and block the loop below
                # forever. Record the error, close the streamer, and re-raise
                # in the caller's thread.
                try:
                    model.generate(**stream_kwargs)
                except Exception as err:
                    generation_errors.append(err)
                    streamer.end()

            thread = threading.Thread(target=_generate)
            thread.start()
            for new_text in streamer:
                output += new_text
                yield output
            thread.join()
            if generation_errors:
                raise generation_errors[0]
        except Exception as stream_err:
            if output:
                # Partial text was already streamed; silently restarting would
                # replace it with an unrelated sample. Report the error instead.
                raise
            # Fallback: non-streaming generation
            logger.warning("Streaming failed for VLM, falling back to sync: %s", stream_err)
            generated_ids = model.generate(**{**inputs, **sampling_kwargs})
            # generate() returns prompt + continuation; keep only the new tokens.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            output_text = tokenizer.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            yield output_text[0] if output_text else ""
    except Exception as exc:
        logger.exception("Error during VLM model inference")
        yield f"_Error during generation: {exc}_"
def _build_vlm_messages(
    messages: list[dict[str, Any]],
    image_url: str | None = None,
) -> list[dict[str, Any]]:
    """Build VLM-style messages with image content blocks.

    Every message is copied as a ``{"role": ..., "content": ...}`` dict
    (role defaults to ``"user"``, content to ``""``). When ``image_url`` is
    truthy, the last message whose role is ``"user"`` becomes a list of
    content blocks instead. The first block is ``{"type": "image", "url":
    image_url}``. It is followed by a ``{"type": "text", ...}`` block when
    the text is non-blank, or by the original blocks when the content is
    already a list. If there is no user message, the image is not attached.

    Args:
        messages: Chat history as ``{"role", "content"}`` dicts. They are not
            mutated.
        image_url: Optional image reference to attach.

    Returns:
        A new list of message dicts.
    """
    target_index: int | None = None
    if image_url:
        # Attach to the last *user* turn, even if other turns follow it,
        # rather than only when the final message happens to be a user turn.
        user_indices = [
            i for i, msg in enumerate(messages) if msg.get("role", "user") == "user"
        ]
        target_index = user_indices[-1] if user_indices else None

    vlm_messages: list[dict[str, Any]] = []
    for i, msg in enumerate(messages):
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if i != target_index:
            vlm_messages.append({"role": role, "content": content})
            continue

        content_list: list[dict[str, Any]] = [{"type": "image", "url": image_url}]
        if isinstance(content, list):
            # Already structured content blocks: keep them after the image.
            content_list.extend(content)
        elif content is not None and str(content).strip():
            # None or blank text yields an image-only message.
            content_list.append({"type": "text", "text": str(content)})
        vlm_messages.append({"role": "user", "content": content_list})
    return vlm_messages
def call_model_sync(
    messages: list[dict[str, Any]],
    max_new_tokens: int = DEFAULT_MAX_TOKENS,
    image_url: str | None = None,
) -> str:
    """Run ``call_model`` to completion and return the final response text.

    ``call_model`` yields the full text-so-far at each step, so the last
    yielded value is the complete response.

    Args:
        messages: Chat history as ``{"role", "content"}`` dicts.
        max_new_tokens: Maximum number of tokens to generate.
        image_url: Optional image reference (used only by VLM models).

    Returns:
        The last value yielded by ``call_model``, or ``""`` if it yielded
        nothing.
    """
    final_text = ""
    stream = call_model(messages, max_new_tokens, image_url)
    # Draining the stream leaves the loop variable bound to the last chunk.
    for final_text in stream:
        pass
    return final_text