import copy
import re
from typing import Any, Dict, List, Optional, Union

from vllm import LLM, SamplingParams
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import is_list_of

# Matches trailing whitespace/newlines at the end of a rendered prompt.
_TAIL_WS_RE = re.compile(r"(?:\r?\n|\s)+$")


def needs_newline(text: str) -> bool:
    """Return True when *text* does NOT already end with whitespace/newline."""
    return _TAIL_WS_RE.search(text[-8:]) is None
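
# Illustrative examples (not from the original source):
#   needs_newline("Assistant:")   -> True   (no trailing whitespace)
#   needs_newline("Assistant:\n") -> False  (template already ended the line)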


def add_prefix(prompt: str, prefix: str, eos_token: str) -> str:
    """Insert *prefix* before the first generated token.

    Keeps EOS token at the very end if the template already appended it.
    """
    if prompt.endswith(eos_token):
        return prompt[:-len(eos_token)] + prefix + eos_token
    return prompt + prefix
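
# Illustrative example (not from the original source), with eos_token "</s>":
#   add_prefix("...prompt</s>", "<think>\n", "</s>") -> "...prompt<think>\n</s>"
# i.e. the prefix is slotted in just before a trailing EOS token; prompts
# without a trailing EOS simply get the prefix appended.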


class PrefixLLM(LLM):
    """vLLM LLM subclass that conditionally prepends *trigger_word*."""

    def route_chat(
        self,
        messages: Union[
            List[ChatCompletionMessageParam],
            List[List[ChatCompletionMessageParam]],
        ],
        sampling_params_route: Optional[Union[SamplingParams,
                                              List[SamplingParams]]] = None,
        sampling_params_force_think: Optional[Union[SamplingParams,
                                                    List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
        *,
        trigger_word: Optional[str] = None,
    ) -> List[RequestOutput]:
| """Drop-in replacement for `LLM.chat` with one extra keyword: |
| |
| Parameters |
| ---------- |
| trigger_word : str | None, default None |
| The prefix to inject. If ``None`` → no prefix injection. |
| """ |
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        eos_token = tokenizer.eos_token

        # Prompts rendered straight from the chat template, and the same
        # prompts with the trigger-word prefix injected.
        orig_prompts: List[Union[TokensPrompt, TextPrompt]] = []
        pref_prompts: List[Union[TokensPrompt, TextPrompt]] = []

        # Normalise the input to a batch of conversations.
        list_of_messages: List[List[ChatCompletionMessageParam]]
        if is_list_of(messages, list):
            # Already a batch of conversations.
            list_of_messages = messages
        else:
            # A single conversation: wrap it in a batch of one.
            list_of_messages = [messages]

        for msgs in list_of_messages:
            # Render the conversation with the model's chat template.
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data: Union[str, List[int]] = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )
                mm_data = None
            else:
                conversation, mm_data = parse_chat_messages(
                    msgs, model_config, tokenizer)
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )

            if is_list_of(prompt_data, int):
                raise NotImplementedError(
                    "Token-id prompts are not supported by route_chat")
            orig_prompt = TextPrompt(prompt=prompt_data)

            if trigger_word is None:
                raise ValueError(
                    "trigger_word must be provided when using force_think logic")

            # Build the prefixed variant: append the trigger word (plus a
            # newline when the rendered prompt does not already end with
            # whitespace) while keeping any trailing EOS token in place.
            need_nl = needs_newline(prompt_data)
            prefix = trigger_word + ("\n" if need_nl else "")
            pref_txt = add_prefix(prompt_data, prefix, eos_token)
            pref_prompt = TextPrompt(prompt=pref_txt)

            if mm_data is not None:
                orig_prompt["multi_modal_data"] = mm_data
                pref_prompt["multi_modal_data"] = copy.deepcopy(mm_data)

            orig_prompts.append(orig_prompt)
            pref_prompts.append(pref_prompt)

        # First pass: generate from the unmodified prompts.
        results = self.generate(
            orig_prompts,
            sampling_params=sampling_params_route,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        # Route: outputs that emit the "<specialLong>" marker near the start
        # are re-generated with the trigger-word prefix injected.
        need_force = [
            i for i, out in enumerate(results)
            if "<specialLong>" in out.outputs[0].text[:100]
        ]

        if len(need_force) == 0:
            return results

        prompts_force = [pref_prompts[i] for i in need_force]

        # Second pass: regenerate only the routed prompts with the
        # force-think sampling parameters.
        results_force = self.generate(
            prompts_force,
            sampling_params=sampling_params_force_think,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        # Splice the forced outputs back into their original positions.
        for idx, new_out in zip(need_force, results_force):
            results[idx] = new_out

        return results
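

# A minimal usage sketch (illustrative only): the model id, trigger word and
# sampling parameters below are assumptions, not part of the original code.
if __name__ == "__main__":
    llm = PrefixLLM(model="example-org/example-model")  # hypothetical model id
    messages = [{"role": "user", "content": "Explain beam search in one line."}]
    outputs = llm.route_chat(
        messages,
        sampling_params_route=SamplingParams(max_tokens=256),
        sampling_params_force_think=SamplingParams(max_tokens=1024),
        trigger_word="<think>",  # assumed trigger word for illustration
    )
    print(outputs[0].outputs[0].text)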