"""
Data preprocessing script.

Convert the generated dataset into a format that SFTTrainer can consume directly.
Messages are normalized to the system/user/assistant/tool roles used by the
FunctionGemma chat template, and assistant tool calls are flattened to text.

Usage:
    python -m src.prepare_dataset --input ./data/training_data.json --output ./data/prepared_dataset.json

Pass ``--format text`` to emit flattened ``{"text": ...}`` samples instead of
``{"messages": [...]}`` samples.
"""


import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional


# Repository root: the directory one level above the one containing this module.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default CLI paths, used when --input / --output are not given.
DEFAULT_INPUT = PROJECT_ROOT.joinpath("data", "training_data.json")
DEFAULT_OUTPUT = PROJECT_ROOT.joinpath("data", "prepared_dataset.json")


def convert_tool_calls_to_text(tool_calls: List[Dict]) -> str:
    """Render assistant tool calls as plain text, one ``name({json args})`` per line.

    This text is used as the assistant's training target in place of the
    structured ``tool_calls`` field.

    Args:
        tool_calls: Tool-call dicts shaped like
            ``{"function": {"name": ..., "arguments": ...}}``. ``arguments``
            may be a dict or a JSON-encoded string.

    Returns:
        The rendered calls joined by ``"\\n"``, or ``""`` when ``tool_calls``
        is empty or ``None``.
    """
    if not tool_calls:
        return ""

    lines = []
    for tc in tool_calls:
        func = tc.get("function") or {}
        name = func.get("name", "")
        args = func.get("arguments", {})
        # Arguments given as a JSON string would otherwise be double-encoded
        # (e.g. f("{\"a\": 1}")), so decode them before re-serializing.
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                pass  # Not JSON: keep the raw string as the argument value.
        args_str = json.dumps(args, ensure_ascii=False)
        lines.append(f"{name}({args_str})")

    return "\n".join(lines)


def _format_tools_description(tools: Optional[List[Dict]]) -> str:
    """Return the "Available tools" suffix for the system prompt, or "" if there are no function tools."""
    lines = []
    for tool in tools or []:
        if tool.get("type") != "function":
            continue
        func = tool.get("function") or {}
        lines.append(f"- {func.get('name', '')}: {func.get('description', '')}")
    if not lines:
        return ""
    return "\n\nAvailable tools:\n" + "\n".join(lines)


def convert_messages_for_sft(messages: List[Dict], tools: Optional[List[Dict]] = None) -> List[Dict]:
    """
    Convert one conversation into the message format fed to SFTTrainer.

    Role mapping:
        developer / system -> system (the tool list, if any, is appended)
        user               -> user
        assistant          -> assistant; non-empty ``tool_calls`` replace the
                              content with their text form
        tool               -> tool
        anything else      -> dropped

    Args:
        messages: Input messages, each a dict with ``role`` and optionally
            ``content`` / ``tool_calls``. Missing or ``None`` content becomes "".
        tools: Optional tool schemas. Entries with ``type == "function"`` are
            listed as ``- name: description`` under "Available tools:" and
            appended to every developer/system message.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts.
    """
    # NOTE(review): if the conversation has no developer/system message, the
    # tool list is not emitted anywhere — confirm the inputs always have one.
    tools_description = _format_tools_description(tools)

    converted = []
    for msg in messages:
        role = msg.get("role", "")
        # Content may be absent or explicitly None (typical for assistant
        # turns that only carry tool_calls); normalize it to "".
        content = msg.get("content") or ""

        if role in ("developer", "system"):
            converted.append({"role": "system", "content": content + tools_description})
        elif role == "user":
            converted.append({"role": "user", "content": content})
        elif role == "assistant":
            tool_calls = msg.get("tool_calls")
            # Only a non-empty tool_calls list replaces the textual content.
            if tool_calls:
                content = convert_tool_calls_to_text(tool_calls)
            converted.append({"role": "assistant", "content": content})
        elif role == "tool":
            converted.append({"role": "tool", "content": content})
        # Any other role is intentionally skipped.

    return converted


def _render_gemma_text(messages: List[Dict]) -> str:
    """Flatten converted messages into one ``<start_of_turn>``-delimited string."""
    turn_names = {"system": "system", "user": "user", "assistant": "model"}
    turns = []
    for msg in messages:
        turn = turn_names.get(msg["role"])
        if turn is None:
            # NOTE(review): "tool" messages have no turn marker here and are
            # dropped, unlike in the "messages" format — confirm this is intended.
            continue
        turns.append(f"<start_of_turn>{turn}\n{msg['content'] or ''}<end_of_turn>")
    return "\n".join(turns)


def _preview(text: str, limit: int) -> str:
    """Truncate *text* to *limit* characters, appending "..." only if something was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _print_example(prepared_data: List[Dict[str, Any]], format_type: str) -> None:
    """Print the first prepared sample (truncated) for a quick visual check."""
    print("\n" + "=" * 60)
    print("Example:")
    print("=" * 60)

    if not prepared_data:
        print("(no samples)")
        return

    example = prepared_data[0]
    if format_type == "messages":
        for msg in example["messages"]:
            print(f"\n[{msg['role']}]")
            print(_preview(msg["content"] or "", 200))
    else:
        print(_preview(example["text"], 500))


def prepare_dataset(input_path: str, output_path: str, format_type: str = "messages"):
    """
    Convert a raw JSON dataset into SFT-ready samples and write them to disk.

    Args:
        input_path: JSON file containing a list of items, each with optional
            ``"messages"`` and ``"tools"`` keys.
        output_path: Destination JSON file. Missing parent directories are created.
        format_type:
            - "messages": each sample is ``{"messages": [...]}``
            - "text": each sample is ``{"text": "..."}`` (flattened turns)

    Returns:
        The list of prepared samples, which is also the content written to
        ``output_path``.

    Raises:
        ValueError: if ``format_type`` is not "messages" or "text".
    """
    # Validate before touching any files, so a typo cannot produce an empty output file.
    if format_type not in ("messages", "text"):
        raise ValueError(f"Unknown format_type {format_type!r}; expected 'messages' or 'text'")

    print(f"Loading dataset: {input_path}")
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Raw samples: {len(data)}")

    prepared_data = []
    for item in data:
        converted_messages = convert_messages_for_sft(item.get("messages", []), item.get("tools", []))
        if format_type == "messages":
            prepared_data.append({"messages": converted_messages})
        else:
            prepared_data.append({"text": _render_gemma_text(converted_messages)})

    print(f"Processed samples: {len(prepared_data)}")

    # Create the output directory so a fresh checkout without data/ still works.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(prepared_data, f, ensure_ascii=False, indent=2)
    print(f"Saved to: {output_path}")

    _print_example(prepared_data, format_type)
    return prepared_data


def main():
    """Command-line entry point: parse arguments and run prepare_dataset()."""
    cli = argparse.ArgumentParser(description="Dataset preparation")
    cli.add_argument("--input", type=str, default=str(DEFAULT_INPUT), help="Input file path")
    cli.add_argument("--output", type=str, default=str(DEFAULT_OUTPUT), help="Output file path")
    cli.add_argument(
        "--format",
        type=str,
        choices=["messages", "text"],
        default="messages",
        help="Output format",
    )

    ns = cli.parse_args()
    prepare_dataset(input_path=ns.input, output_path=ns.output, format_type=ns.format)


# Run the CLI only when executed as a script or via `python -m`, not on import.
if __name__ == "__main__":
    main()