Hanrui / progress /github /SpecForge /tests /test_data /test_parsers.py
Lekr0's picture
Add files using upload-large-folder tool
62dca4c verified
import json
import os
import unittest
from typing import Any, Dict, List, Optional
from transformers import AutoTokenizer
from specforge.data.preprocessing import preprocess_conversations
from specforge.data.template import TEMPLATE_REGISTRY
class TestTemplatePreprocessing(unittest.TestCase):
    """Regression tests for chat-template preprocessing.

    Each test loads a model's tokenizer, runs a fixed conversation through
    ``preprocess_conversations`` with a registered chat template, and compares
    the resulting ``input_ids`` / ``loss_mask`` against a JSON reference file in
    ``REF_DIR``. Set ``SAVE_REFERENCE = True`` to (re)write the reference files
    instead of comparing against them.
    """

    # Configuration section
    # When True, tests write reference JSON files instead of comparing.
    SAVE_REFERENCE = False
    # Directory holding the "<template_key>_<label>_ref.json" reference files.
    REF_DIR = os.path.join(os.path.dirname(__file__), "test_references")

    @classmethod
    def setUpClass(cls):
        """Initialize standard test data"""
        cls.max_length = 65535
        # exist_ok avoids a check-then-create race (e.g. parallel test runs).
        os.makedirs(cls.REF_DIR, exist_ok=True)

        # 1. General model test data (Qwen, DeepSeek, etc.)
        cls.standard_messages = [
            [
                {"role": "user", "content": "Who are you?"},
                {"role": "assistant", "content": "My name is Qwen2."},
                {"role": "user", "content": "How old are you?"},
                {"role": "assistant", "content": "11 years old."},
            ]
        ]

        # 2. GPT-OSS dedicated test data (including analysis and final channels)
        cls.gpt_oss_messages = [
            [
                {"role": "user", "content": "Explain Quantum Physics."},
                {
                    "role": "assistant_analysis",
                    "content": "The user wants a summary of quantum physics. I should cover wave-particle duality and uncertainty principle.",
                },
                {
                    "role": "assistant_final",
                    "content": "Quantum physics is the study of matter and energy at the most fundamental level...",
                },
                {"role": "user", "content": "Explain Quantum Physics."},
                {"role": "assistant_final", "content": "I'm Qwen"},
            ]
        ]

        # 3. Tool-use test data
        cls.tool_use_messages = [
            [
                {
                    "role": "user",
                    "content": "What's the weather like in Beijing today?",
                },
                {
                    "role": "assistant",
                    "content": "I'll check the current weather in Beijing for you.",
                },
                {
                    "role": "tool",
                    "content": '{"location": "Beijing", "temperature": 22, "condition": "Sunny"}',
                },
                {
                    "role": "assistant",
                    "content": "The current weather in Beijing is sunny with a temperature of 22°C.",
                },
                {
                    "role": "tool",
                    "content": '{"unit": "Celsius", "forecast": "Clear skies all day."}',
                },
                {
                    "role": "tool",
                    "content": '{"unit": "Celsius", "forecast": "Clear skies all day."}',
                },
                {
                    "role": "user",
                    "content": "Great! Can you also tell me if it will rain tomorrow?",
                },
                {
                    "role": "assistant",
                    "content": "Based on the forecast, there will be no rain tomorrow—expect clear skies all day.",
                },
            ]
        ]

    def _get_ref_path(self, template_key: str, message_label: str = "standard"):
        """Return the reference-file path for a template / message-set pair."""
        return os.path.join(self.REF_DIR, f"{template_key}_{message_label}_ref.json")

    def _resolve_message_label(self, messages: List[List[Dict[str, str]]]) -> str:
        """Map one of the built-in message sets to its reference-file label.

        Raises:
            ValueError: if ``messages`` equals none of the built-in sets.
        """
        known_sets = (
            (self.standard_messages, "standard"),
            (self.gpt_oss_messages, "gpt-oss"),
            (self.tool_use_messages, "tool-use"),
        )
        for known, label in known_sets:
            if messages == known:
                return label
        raise ValueError(
            "Invalid message set; pass message_label explicitly for custom messages"
        )

    def _run_template_test(
        self,
        model_name: str,
        template_key: str,
        messages: Optional[List[List[Dict[str, str]]]] = None,
        message_label: Optional[str] = None,
    ):
        """Encapsulate common test and regression validation logic.

        Args:
            model_name: Model id passed to ``AutoTokenizer.from_pretrained``.
            template_key: Key looked up in ``TEMPLATE_REGISTRY``; also part of
                the reference-file name.
            messages: Conversations to preprocess; defaults to
                ``self.standard_messages``.
            message_label: Label used in the reference-file name. When omitted
                it is inferred for the built-in message sets; it must be given
                for any other message set.

        Raises:
            ValueError: if ``message_label`` is omitted and ``messages`` is not
                one of the built-in sets.
        """
        # Use the input message or the default standard message.
        target_messages = messages if messages is not None else self.standard_messages
        if message_label is None:
            message_label = self._resolve_message_label(target_messages)
        print(f"\n>>> Running: {template_key} ({model_name}) {message_label}")

        # 1. Initialize tokenizer and template
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        chat_template = TEMPLATE_REGISTRY.get(template_key)

        # 2. Preprocess conversations
        res = preprocess_conversations(
            tokenizer, target_messages, chat_template, self.max_length
        )
        # Only res[...][0][0] is recorded. Presumably this is the first
        # conversation; TODO confirm the result layout. Every built-in message
        # set contains exactly one conversation.
        current_data = {
            "input_ids": res["input_ids"][0][0].tolist(),
            "loss_mask": res["loss_mask"][0][0].tolist(),
        }

        # 3. Debug output. It is shown before comparing so that it is still
        # visible when the regression assertions below fail.
        self.debug_show_loss_mask(res, tokenizer)

        ref_path = self._get_ref_path(template_key, message_label)

        # 4. Branch logic: update reference or perform comparison
        if self.SAVE_REFERENCE:
            with open(ref_path, "w", encoding="utf-8") as f:
                json.dump(current_data, f)
            print(f" [INFO] Reference saved to {ref_path}")
            return

        if not os.path.exists(ref_path):
            self.fail(
                f"Reference file not found for {template_key} ({message_label}): "
                f"{ref_path}. Set SAVE_REFERENCE=True."
            )
        with open(ref_path, "r", encoding="utf-8") as f:
            ref_data = json.load(f)
        self.assertListEqual(
            current_data["input_ids"],
            ref_data["input_ids"],
            f"input_ids mismatch for {template_key} ({message_label})",
        )
        self.assertListEqual(
            current_data["loss_mask"],
            ref_data["loss_mask"],
            f"loss_mask mismatch for {template_key} ({message_label})",
        )
        print(f" [PASS] Regression test passed for {template_key}")

    @staticmethod
    def debug_show_loss_mask(res: Dict[str, Any], tokenizer: AutoTokenizer):
        """Print the first recorded sequence token by token.

        Tokens whose ``loss_mask`` value is 1 are wrapped in ANSI red. Newlines
        inside tokens are escaped so the sequence prints on a single line
        between two separator rules.
        """
        input_ids = res["input_ids"][0][0].tolist()
        loss_mask = res["loss_mask"][0][0].tolist()
        RED, RESET = "\033[91m", "\033[0m"
        pieces = []
        for tid, m in zip(input_ids, loss_mask):
            txt = tokenizer.decode([tid]).replace("\n", "\\n")
            pieces.append(f"{RED}{txt}{RESET}" if m == 1 else txt)
        print("-" * 30)
        print("".join(pieces))
        print("-" * 30)

    # The following are tests. Each test corresponds to a specific model and
    # template.
    def test_deepseek(self):
        self._run_template_test("deepseek-ai/DeepSeek-V3", "deepseek-v3")

    def test_deepseek_v32(self):
        self._run_template_test("deepseek-ai/DeepSeek-V3.2", "deepseek-v32")

    def test_qwen3_thinking(self):
        self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-thinking")

    def test_qwen3_instruct(self):
        self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-instruct")

    def test_qwen3_next_instruct(self):
        self._run_template_test("Qwen/Qwen3-Next-80B-A3B-Instruct", "qwen")

    def test_kimi_k2_thinking(self):
        self._run_template_test("moonshotai/Kimi-K2-Thinking", "kimi-k2-thinking")

    def test_kimi_k2_instruct(self):
        self._run_template_test("moonshotai/Kimi-K2-Instruct", "kimi-k2-instruct")

    def test_qwen3_next_thinking(self):
        self._run_template_test(
            "Qwen/Qwen3-Next-80B-A3B-Thinking", "qwen3-next-thinking"
        )

    def test_gpt_oss(self):
        self._run_template_test(
            "openai/gpt-oss-120b", "gpt-oss", messages=self.gpt_oss_messages
        )

    def test_ling_flash_2_0(self):
        self._run_template_test("inclusionAI/Ling-flash-2.0", "ling-flash-2.0")

    def test_qwen3_instruct_with_tools(self):
        self._run_template_test(
            "Qwen/Qwen3-0.6B", "qwen3-instruct", messages=self.tool_use_messages
        )
if __name__ == "__main__":
    # Allow running this module directly (e.g. `python test_parsers.py`);
    # `module="__main__"` is unittest.main's default, spelled out explicitly.
    unittest.main(module="__main__")