| import json |
| import os |
| import unittest |
| from typing import Any, Dict, List, Optional |
|
|
| from transformers import AutoTokenizer |
|
|
| from specforge.data.preprocessing import preprocess_conversations |
| from specforge.data.template import TEMPLATE_REGISTRY |
|
|
|
|
| class TestTemplatePreprocessing(unittest.TestCase): |
| |
| SAVE_REFERENCE = False |
| REF_DIR = os.path.join(os.path.dirname(__file__), "test_references") |
|
|
| @classmethod |
| def setUpClass(cls): |
| """Initialize standard test data""" |
| cls.max_length = 65535 |
| if not os.path.exists(cls.REF_DIR): |
| os.makedirs(cls.REF_DIR) |
|
|
| |
| cls.standard_messages = [ |
| [ |
| {"role": "user", "content": "Who are you?"}, |
| {"role": "assistant", "content": "My name is Qwen2."}, |
| {"role": "user", "content": "How old are you?"}, |
| {"role": "assistant", "content": "11 years old."}, |
| ] |
| ] |
|
|
| |
| cls.gpt_oss_messages = [ |
| [ |
| {"role": "user", "content": "Explain Quantum Physics."}, |
| { |
| "role": "assistant_analysis", |
| "content": "The user wants a summary of quantum physics. I should cover wave-particle duality and uncertainty principle.", |
| }, |
| { |
| "role": "assistant_final", |
| "content": "Quantum physics is the study of matter and energy at the most fundamental level...", |
| }, |
| {"role": "user", "content": "Explain Quantum Physics."}, |
| {"role": "assistant_final", "content": "I'm Qwen"}, |
| ] |
| ] |
|
|
| |
| cls.tool_use_messages = [ |
| [ |
| { |
| "role": "user", |
| "content": "What's the weather like in Beijing today?", |
| }, |
| { |
| "role": "assistant", |
| "content": "I'll check the current weather in Beijing for you.", |
| }, |
| { |
| "role": "tool", |
| "content": '{"location": "Beijing", "temperature": 22, "condition": "Sunny"}', |
| }, |
| { |
| "role": "assistant", |
| "content": "The current weather in Beijing is sunny with a temperature of 22°C.", |
| }, |
| { |
| "role": "tool", |
| "content": '{"unit": "Celsius", "forecast": "Clear skies all day."}', |
| }, |
| { |
| "role": "tool", |
| "content": '{"unit": "Celsius", "forecast": "Clear skies all day."}', |
| }, |
| { |
| "role": "user", |
| "content": "Great! Can you also tell me if it will rain tomorrow?", |
| }, |
| { |
| "role": "assistant", |
| "content": "Based on the forecast, there will be no rain tomorrow—expect clear skies all day.", |
| }, |
| ] |
| ] |
|
|
| def _get_ref_path(self, template_key: str, message_label: str = "standard"): |
| return os.path.join(self.REF_DIR, f"{template_key}_{message_label}_ref.json") |
|
|
| def _run_template_test( |
| self, |
| model_name: str, |
| template_key: str, |
| messages: Optional[List[List[Dict[str, str]]]] = None, |
| ): |
| """Encapsulate common test and regression validation logic""" |
|
|
| |
| target_messages = messages if messages is not None else self.standard_messages |
| message_label = None |
| if target_messages == self.standard_messages: |
| message_label = "standard" |
| elif target_messages == self.gpt_oss_messages: |
| message_label = "gpt-oss" |
| elif target_messages == self.tool_use_messages: |
| message_label = "tool-use" |
| else: |
| raise ValueError("Invalid message set") |
| print(f"\n>>> Running: {template_key} ({model_name}) {message_label}") |
|
|
| |
| tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) |
| chat_template = TEMPLATE_REGISTRY.get(template_key) |
|
|
| |
| res = preprocess_conversations( |
| tokenizer, target_messages, chat_template, self.max_length |
| ) |
| |
| current_data = { |
| "input_ids": res["input_ids"][0][0].tolist(), |
| "loss_mask": res["loss_mask"][0][0].tolist(), |
| } |
|
|
| ref_path = self._get_ref_path(template_key, message_label) |
| |
| if self.SAVE_REFERENCE: |
| with open(ref_path, "w", encoding="utf-8") as f: |
| json.dump(current_data, f) |
| print(f" [INFO] Reference saved to {ref_path}") |
| else: |
| if not os.path.exists(ref_path): |
| self.fail( |
| f"Reference file not found for {template_key}. Set SAVE_REFERENCE=True." |
| ) |
|
|
| with open(ref_path, "r", encoding="utf-8") as f: |
| ref_data = json.load(f) |
|
|
| self.assertListEqual(current_data["input_ids"], ref_data["input_ids"]) |
| self.assertListEqual(current_data["loss_mask"], ref_data["loss_mask"]) |
| print(f" [PASS] Regression test passed for {template_key}") |
|
|
| |
| self.debug_show_loss_mask(res, tokenizer) |
|
|
| @staticmethod |
| def debug_show_loss_mask(res: Dict[str, Any], tokenizer: AutoTokenizer): |
| input_ids = res["input_ids"][0][0].tolist() |
| loss_mask = res["loss_mask"][0][0].tolist() |
| RED, RESET = "\033[91m", "\033[0m" |
| print("-" * 30) |
| for tid, m in zip(input_ids, loss_mask): |
| txt = tokenizer.decode([tid]) |
| txt = txt.replace("\n", "\\n") |
| print(f"{RED if m == 1 else ''}{txt}{RESET}", end="") |
| print("\n" + "-" * 30) |
|
|
| |
|
|
| def test_deepseek(self): |
| self._run_template_test("deepseek-ai/DeepSeek-V3", "deepseek-v3") |
|
|
| def test_deepseek_v32(self): |
| self._run_template_test("deepseek-ai/DeepSeek-V3.2", "deepseek-v32") |
|
|
| def test_qwen3_thinking(self): |
| self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-thinking") |
|
|
| def test_qwen3_instruct(self): |
| self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-instruct") |
|
|
| def test_qwen3_next_instruct(self): |
| self._run_template_test("Qwen/Qwen3-Next-80B-A3B-Instruct", "qwen") |
|
|
| def test_kimi_k2_thinking(self): |
| self._run_template_test("moonshotai/Kimi-K2-Thinking", "kimi-k2-thinking") |
|
|
| def test_kimi_k2_instruct(self): |
| self._run_template_test("moonshotai/Kimi-K2-Instruct", "kimi-k2-instruct") |
|
|
| def test_qwen3_next_thinking(self): |
| self._run_template_test( |
| "Qwen/Qwen3-Next-80B-A3B-Thinking", "qwen3-next-thinking" |
| ) |
|
|
| def test_gpt_oss(self): |
| self._run_template_test( |
| "openai/gpt-oss-120b", "gpt-oss", messages=self.gpt_oss_messages |
| ) |
|
|
| def test_ling_flash_2_0(self): |
| self._run_template_test("inclusionAI/Ling-flash-2.0", "ling-flash-2.0") |
|
|
| def test_qwen3_instruct_with_tools(self): |
| self._run_template_test( |
| "Qwen/Qwen3-0.6B", "qwen3-instruct", messages=self.tool_use_messages |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| unittest.main() |
|
|