#!/usr/bin/env python3
"""
FunctionGemma SFT fine-tuning script.
Runs TRL SFTTrainer for FunctionGemma with two modes:
1) LoRA (recommended): faster, lower memory, less overfit
2) Full-parameter: higher cost, maximal capacity
Usage:
# LoRA (default)
python -m src.train \
--model_path /path/to/model \
--dataset_path ./data/training_data.json \
--bf16
# Full-parameter
python -m src.train \
--model_path /path/to/model \
--dataset_path ./data/training_data.json \
--no-use-lora \
--bf16
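# QLoRA (4-bit LoRA; a sketch that simply combines the flags defined below)
python -m src.train \
--model_path /path/to/model \
--dataset_path ./data/training_data.json \
--use_4bit \
--bf16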
"""
import os
import json
import argparse
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional
import torch
from datasets import Dataset, load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
# Paths and logging
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DATA_PATH = PROJECT_ROOT / "data" / "training_data.json"
DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "runs"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def parse_args():
"""Parse CLI arguments."""
parser = argparse.ArgumentParser(description="FunctionGemma SFT fine-tuning (LoRA / full)")
# Model
parser.add_argument(
"--model_path",
type=str,
default="google/functiongemma-270m-it",
help="Model path or HF model id"
)
parser.add_argument(
"--tokenizer_path",
type=str,
default=None,
help="Tokenizer path (defaults to model_path)"
)
# Dataset
parser.add_argument(
"--dataset_path",
type=str,
default=str(DEFAULT_DATA_PATH),
help="Training dataset path"
)
parser.add_argument(
"--val_split",
type=float,
default=0.1,
help="Validation split ratio"
)
# Output
parser.add_argument(
"--output_dir",
type=str,
default=str(DEFAULT_OUTPUT_DIR),
help="Root output directory"
)
parser.add_argument(
"--run_name",
type=str,
default=None,
help="Run name for logging and saving"
)
# Fine-tuning mode
parser.add_argument(
"--use_lora",
action="store_true",
default=True,
help="Enable LoRA (recommended). Add --no-use-lora for full-parameter finetune"
)
parser.add_argument("--no-use-lora", dest="use_lora", action="store_false", help="Disable LoRA, run full-parameter finetune")
# LoRA (only when use_lora=True)
parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout")
parser.add_argument(
"--target_modules",
type=str,
nargs="+",
default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
help="Target modules for LoRA"
)
# Training (aligned with FunctionGemma guidance)
parser.add_argument("--num_train_epochs", type=int, default=6, help="Training epochs (official rec: 8)")
parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 to use epochs)")
parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Train batch size per device")
parser.add_argument("--per_device_eval_batch_size", type=int, default=2, help="Eval batch size")
parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Grad accumulation steps")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="Warmup ratio (constant scheduler usually skips warmup)")
parser.add_argument("--max_seq_length", type=int, default=2048, help="Max sequence length (model supports up to 32768)")
parser.add_argument("--lr_scheduler_type", type=str, default="constant", help="LR scheduler type (default constant)")
# Precision & optimization
parser.add_argument("--bf16", action="store_true", help="Use BF16")
parser.add_argument("--fp16", action="store_true", help="Use FP16")
parser.add_argument("--use_4bit", action="store_true", help="Enable 4-bit quant (QLoRA)")
parser.add_argument("--use_8bit", action="store_true", help="Enable 8-bit quant")
parser.add_argument("--use_flash_attention", action="store_true", help="Enable Flash Attention 2")
parser.add_argument("--gradient_checkpointing", action="store_true", help="Enable gradient checkpointing")
# Logging & saving
parser.add_argument("--logging_steps", type=int, default=10, help="Log every N steps")
parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every N steps")
parser.add_argument("--eval_steps", type=int, default=100, help="Eval every N steps")
parser.add_argument("--save_total_limit", type=int, default=3, help="Max checkpoints to keep")
# Misc
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint")
parser.add_argument("--push_to_hub", action="store_true", help="Push to Hugging Face Hub")
parser.add_argument("--hub_model_id", type=str, default=None, help="Hub model id")
return parser.parse_args()
def load_and_prepare_dataset(dataset_path: str, val_split: float = 0.1):
"""Load and normalize dataset structure for SFT."""
logger.info(f"Loading dataset: {dataset_path}")
# Load JSON dataset
with open(dataset_path, 'r', encoding='utf-8') as f:
data = json.load(f)
logger.info(f"Dataset size: {len(data)} samples")
# Normalize nested structures:
# if an item has input.messages/tools, lift them to top-level
processed_data = []
for idx, item in enumerate(data):
if 'input' in item and 'messages' in item['input']:
# Deep copy messages to avoid mutating original
messages = json.loads(json.dumps(item['input']['messages']))
# Fix tool_calls formatting if present
for msg in messages:
if 'tool_calls' in msg and msg['tool_calls']:
for tc in msg['tool_calls']:
if 'function' in tc and 'arguments' in tc['function']:
args = tc['function']['arguments']
# ensure arguments is a string
if not isinstance(args, str):
tc['function']['arguments'] = json.dumps(args)
# Convert expected field into assistant response if present
if 'expected' in item and item['expected']:
expected = item['expected']
# If last message is not assistant, append one
if messages[-1]['role'] != 'assistant':
# Decide between function call or refusal
function_name = expected.get('function_name')
arguments = expected.get('arguments')
response = expected.get('response', '')
if function_name is not None and arguments is not None:
# Case 1: function call -> add tool_calls
arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
assistant_msg = {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": f"call_{hash(function_name + arguments_str) % 1000000}", # generate unique id
"type": "function",
"function": {
"name": function_name,
"arguments": arguments_str
}
}]
}
messages.append(assistant_msg)
logger.debug(f"Added assistant tool_calls: {function_name}")
elif function_name is None and arguments is None and response:
# Case 2: refusal -> plain text response
assistant_msg = {
"role": "assistant",
"content": response
}
messages.append(assistant_msg)
logger.debug(f"Added assistant refusal response: {response[:50]}")
else:
logger.warning(f"Unknown expected format: {expected}")
processed_item = {
'messages': messages
}
# include tools if present
if 'tools' in item['input']:
processed_item['tools'] = item['input']['tools']
# preserve id
if 'id' in item:
processed_item['id'] = item['id']
# Final check: tool_calls arguments must be strings
for msg in processed_item['messages']:
if 'tool_calls' in msg and msg['tool_calls']:
for tc in msg['tool_calls']:
if 'function' in tc and 'arguments' in tc['function']:
if not isinstance(tc['function']['arguments'], str):
logger.error(f"Sample {idx} arguments not string: {type(tc['function']['arguments'])}")
tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
processed_data.append(processed_item)
elif 'messages' in item:
# Already proper format, just normalize tool_calls
messages = json.loads(json.dumps(item['messages']))
for msg in messages:
if 'tool_calls' in msg and msg['tool_calls']:
for tc in msg['tool_calls']:
if 'function' in tc and 'arguments' in tc['function']:
if not isinstance(tc['function']['arguments'], str):
tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
item_copy = dict(item)
item_copy['messages'] = messages
processed_data.append(item_copy)
else:
logger.warning(f"Skip malformed item: {item.get('id', 'unknown')}")
logger.info(f"Processed dataset size: {len(processed_data)}")
# Validate format
tool_calls_count = 0
for item in processed_data:
for msg in item['messages']:
if 'tool_calls' in msg and msg['tool_calls']:
tool_calls_count += 1
for tc in msg['tool_calls']:
if 'function' in tc and 'arguments' in tc['function']:
if not isinstance(tc['function']['arguments'], str):
logger.error(f"Found non-string arguments: {type(tc['function']['arguments'])}")
logger.info(f"Messages containing tool_calls: {tool_calls_count}")
# Convert to Hugging Face Dataset
dataset = Dataset.from_list(processed_data)
# Split train/val
if val_split > 0:
dataset = dataset.train_test_split(test_size=val_split, seed=42)
train_dataset = dataset['train']
eval_dataset = dataset['test']
logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
else:
train_dataset = dataset
eval_dataset = None
logger.info(f"Train: {len(train_dataset)}, no eval split")
return train_dataset, eval_dataset
def get_quantization_config(use_4bit: bool, use_8bit: bool):
"""Build quantization config if requested."""
if use_4bit:
logger.info("Using 4-bit quantization (QLoRA)")
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
elif use_8bit:
logger.info("Using 8-bit quantization")
return BitsAndBytesConfig(
load_in_8bit=True,
)
return None
def load_model_and_tokenizer(args):
"""Load model and tokenizer."""
logger.info(f"Loading model: {args.model_path}")
tokenizer_path = args.tokenizer_path or args.model_path
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
trust_remote_code=True,
padding_side="right",
)
# Ensure pad token exists
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
# Quantization config
quantization_config = get_quantization_config(args.use_4bit, args.use_8bit)
# Model kwargs
model_kwargs = {
"trust_remote_code": True,
"device_map": "auto",
}
if quantization_config:
model_kwargs["quantization_config"] = quantization_config
# Precision
if args.bf16 and not (args.use_4bit or args.use_8bit):
model_kwargs["torch_dtype"] = torch.bfloat16
elif args.fp16 and not (args.use_4bit or args.use_8bit):
model_kwargs["torch_dtype"] = torch.float16
# Flash Attention
if args.use_flash_attention:
model_kwargs["attn_implementation"] = "flash_attention_2"
logger.info("Using Flash Attention 2")
# Load model
model = AutoModelForCausalLM.from_pretrained(
args.model_path,
**model_kwargs
)
# Prepare for k-bit training when quantized
if args.use_4bit or args.use_8bit:
model = prepare_model_for_kbit_training(model)
# Gradient checkpointing
if args.gradient_checkpointing:
model.gradient_checkpointing_enable()
logger.info("Enabled gradient checkpointing")
logger.info(f"Model parameters: {model.num_parameters():,}")
return model, tokenizer
def get_lora_config(args):
"""Build LoRA config."""
logger.info(f"LoRA config: r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}")
logger.info(f"Target modules: {args.target_modules}")
return LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=args.target_modules,
bias="none",
task_type=TaskType.CAUSAL_LM,
)
def formatting_func(example):
"""
    Identity formatting function, kept for reference (it is not passed to SFTTrainer
    below; the trainer applies the chat template to conversational datasets directly).
Dataset format:
{
"messages": [
{"role": "developer", "content": "..."},
{"role": "user", "content": "..."},
{"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
],
"tools": [...]
}
"""
# Return as-is; SFTTrainer applies chat template
return example
def main():
args = parse_args()
# Set run name
    if args.run_name is None:
        mode_tag = "lora" if args.use_lora else "full"
        args.run_name = f"functiongemma-{mode_tag}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# Create output directory
output_dir = os.path.join(args.output_dir, args.run_name)
os.makedirs(output_dir, exist_ok=True)
logger.info("=" * 60)
logger.info("FunctionGemma SFT LoRA training")
logger.info("=" * 60)
logger.info(f"Output dir: {output_dir}")
# Save config
config_path = os.path.join(output_dir, "training_config.json")
with open(config_path, 'w') as f:
json.dump(vars(args), f, indent=2)
logger.info(f"Config saved to: {config_path}")
# Load dataset
train_dataset, eval_dataset = load_and_prepare_dataset(
args.dataset_path,
args.val_split
)
# Load model + tokenizer
model, tokenizer = load_model_and_tokenizer(args)
# Build LoRA config if enabled
if args.use_lora:
logger.info("=" * 60)
logger.info("LoRA fine-tuning mode")
logger.info("=" * 60)
lora_config = get_lora_config(args)
else:
logger.info("=" * 60)
logger.info("Full-parameter fine-tuning mode")
logger.info("Warning: full fine-tuning needs more memory and time!")
logger.info("=" * 60)
lora_config = None
# SFTTrainer config
training_args = SFTConfig(
output_dir=output_dir,
run_name=args.run_name,
# Sequence length / packing
max_length=args.max_seq_length,
packing=False,
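        # packing=False keeps each conversation as its own sequence instead of concatenating samples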
# Training
num_train_epochs=args.num_train_epochs,
max_steps=args.max_steps,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
# Optimizer
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
lr_scheduler_type=args.lr_scheduler_type,
optim="adamw_torch_fused",
# Precision
bf16=args.bf16,
fp16=args.fp16,
# Logging / saving
logging_steps=args.logging_steps,
save_steps=args.save_steps,
eval_steps=args.eval_steps if eval_dataset else None,
eval_strategy="steps" if eval_dataset else "no",
save_total_limit=args.save_total_limit,
load_best_model_at_end=True if eval_dataset else False,
# Misc
seed=args.seed,
report_to=["tensorboard"],
# Hub
push_to_hub=args.push_to_hub,
hub_model_id=args.hub_model_id,
# Gradient checkpointing
gradient_checkpointing=args.gradient_checkpointing,
gradient_checkpointing_kwargs={"use_reentrant": False} if args.gradient_checkpointing else None,
)
# Create SFTTrainer
# Dataset should include 'messages' and 'tools'; SFTTrainer applies chat template automatically
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer, # newer TRL uses processing_class instead of tokenizer
peft_config=lora_config,
)
# Parameter stats
    # Count parameters on trainer.model (the PEFT-wrapped model when LoRA is enabled)
    trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in trainer.model.parameters())
trainable_percentage = 100 * trainable_params / total_params if total_params > 0 else 0
logger.info("=" * 60)
logger.info("Model parameter stats:")
logger.info(f" Total params: {total_params:,}")
logger.info(f" Trainable params: {trainable_params:,}")
logger.info(f" Trainable ratio: {trainable_percentage:.2f}%")
logger.info(f" Mode: {'LoRA' if args.use_lora else 'Full fine-tune'}")
logger.info("=" * 60)
# Train
logger.info("Start training...")
if args.resume_from_checkpoint:
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
else:
trainer.train()
# Save final model
logger.info("Saving final model...")
final_model_path = os.path.join(output_dir, "final_model")
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path)
logger.info("=" * 60)
logger.info("Training done.")
logger.info(f"Model saved at: {final_model_path}")
if args.use_lora:
# LoRA: also save adapter
lora_path = os.path.join(output_dir, "lora_adapter")
        trainer.model.save_pretrained(lora_path)  # trainer.model is the PEFT-wrapped model, so only adapter weights are written
tokenizer.save_pretrained(lora_path)
logger.info(f"LoRA adapter saved to: {lora_path}")
logger.info("")
logger.info("Usage:")
logger.info(f" 1. LoRA adapter: {lora_path}")
logger.info(f" 2. Merge adapters with your base model before inference")
else:
# Full fine-tune: final_model is ready to use
logger.info("")
logger.info("Usage:")
logger.info(f" Use model directly from: {final_model_path}")
logger.info("=" * 60)
if __name__ == "__main__":
main()