"""
FunctionGemma SFT fine-tuning script.

Runs TRL SFTTrainer for FunctionGemma with two modes:
1) LoRA (recommended): faster, lower memory, less prone to overfitting
2) Full-parameter: higher cost, maximal capacity

Usage:
    # LoRA (default)
    python -m src.train \
        --model_path /path/to/model \
        --dataset_path ./data/training_data.json \
        --bf16

    # Full-parameter
    python -m src.train \
        --model_path /path/to/model \
        --dataset_path ./data/training_data.json \
        --no-use-lora \
        --bf16
"""
|
import os
import json
import argparse
import logging
from datetime import datetime
from pathlib import Path

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig


PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DATA_PATH = PROJECT_ROOT / "data" / "training_data.json"
DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "runs"

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
def parse_args():
    """Parse CLI arguments."""
    parser = argparse.ArgumentParser(description="FunctionGemma SFT fine-tuning (LoRA / full)")

    # Model
    parser.add_argument(
        "--model_path",
        type=str,
        default="google/functiongemma-270m-it",
        help="Model path or HF model id"
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Tokenizer path (defaults to model_path)"
    )

    # Data
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=str(DEFAULT_DATA_PATH),
        help="Training dataset path"
    )
    parser.add_argument(
        "--val_split",
        type=float,
        default=0.1,
        help="Validation split ratio"
    )

    # Output
    parser.add_argument(
        "--output_dir",
        type=str,
        default=str(DEFAULT_OUTPUT_DIR),
        help="Root output directory"
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="Run name for logging and saving"
    )

    # Training mode
    parser.add_argument(
        "--use_lora",
        action="store_true",
        default=True,
        help="Enable LoRA (recommended). Add --no-use-lora for a full-parameter finetune"
    )
    parser.add_argument("--no-use-lora", dest="use_lora", action="store_false", help="Disable LoRA, run full-parameter finetune")

    # LoRA hyperparameters
    parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout")
    parser.add_argument(
        "--target_modules",
        type=str,
        nargs="+",
        default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        help="Target modules for LoRA"
    )

    # Training hyperparameters
    parser.add_argument("--num_train_epochs", type=int, default=6, help="Training epochs (official rec: 8)")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 to use epochs)")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Train batch size per device")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=2, help="Eval batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Grad accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
    parser.add_argument("--warmup_ratio", type=float, default=0.0, help="Warmup ratio (constant scheduler usually skips warmup)")
    parser.add_argument("--max_seq_length", type=int, default=2048, help="Max sequence length (model supports up to 32768)")
    parser.add_argument("--lr_scheduler_type", type=str, default="constant", help="LR scheduler type (default constant)")

    # Precision / memory
    parser.add_argument("--bf16", action="store_true", help="Use BF16")
    parser.add_argument("--fp16", action="store_true", help="Use FP16")
    parser.add_argument("--use_4bit", action="store_true", help="Enable 4-bit quant (QLoRA)")
    parser.add_argument("--use_8bit", action="store_true", help="Enable 8-bit quant")
    parser.add_argument("--use_flash_attention", action="store_true", help="Enable Flash Attention 2")
    parser.add_argument("--gradient_checkpointing", action="store_true", help="Enable gradient checkpointing")

    # Logging / checkpointing
    parser.add_argument("--logging_steps", type=int, default=10, help="Log every N steps")
    parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every N steps")
    parser.add_argument("--eval_steps", type=int, default=100, help="Eval every N steps")
    parser.add_argument("--save_total_limit", type=int, default=3, help="Max checkpoints to keep")

    # Misc
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--push_to_hub", action="store_true", help="Push to Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, default=None, help="Hub model id")

    return parser.parse_args()
|
|
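# Expected raw item shape -- a sketch inferred from the parsing logic below; the field
# names come from this script and the values are illustrative, not a guaranteed schema:
#
#   {
#       "id": "...",
#       "input": {
#           "messages": [{"role": "user", "content": "..."}, ...],
#           "tools": [...]                      # copied through verbatim
#       },
#       "expected": {
#           "function_name": "...",             # with "arguments": a tool-call label
#           "arguments": {...},
#           "response": "..."                   # alone (others null): a plain-text label
#       }
#   }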
def load_and_prepare_dataset(dataset_path: str, val_split: float = 0.1):
    """Load and normalize dataset structure for SFT."""
    logger.info(f"Loading dataset: {dataset_path}")

    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    logger.info(f"Dataset size: {len(data)} samples")

    processed_data = []
    for idx, item in enumerate(data):
        if 'input' in item and 'messages' in item['input']:
            # Deep copy via JSON round-trip so the raw data is not mutated.
            messages = json.loads(json.dumps(item['input']['messages']))

            # Tool-call arguments must be JSON strings, not dicts.
            for msg in messages:
                if 'tool_calls' in msg and msg['tool_calls']:
                    for tc in msg['tool_calls']:
                        if 'function' in tc and 'arguments' in tc['function']:
                            args = tc['function']['arguments']
                            if not isinstance(args, str):
                                tc['function']['arguments'] = json.dumps(args)

            # Append the expected assistant turn if the conversation lacks one.
            if 'expected' in item and item['expected']:
                expected = item['expected']
                if messages[-1]['role'] != 'assistant':
                    function_name = expected.get('function_name')
                    arguments = expected.get('arguments')
                    response = expected.get('response', '')

                    if function_name is not None and arguments is not None:
                        arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
                        assistant_msg = {
                            "role": "assistant",
                            "content": None,
                            "tool_calls": [{
                                "id": f"call_{hash(function_name + arguments_str) % 1000000}",
                                "type": "function",
                                "function": {
                                    "name": function_name,
                                    "arguments": arguments_str
                                }
                            }]
                        }
                        messages.append(assistant_msg)
                        logger.debug(f"Added assistant tool_calls: {function_name}")
                    elif function_name is None and arguments is None and response:
                        assistant_msg = {
                            "role": "assistant",
                            "content": response
                        }
                        messages.append(assistant_msg)
                        logger.debug(f"Added assistant refusal response: {response[:50]}")
                    else:
                        logger.warning(f"Unknown expected format: {expected}")

            processed_item = {
                'messages': messages
            }

            if 'tools' in item['input']:
                processed_item['tools'] = item['input']['tools']

            if 'id' in item:
                processed_item['id'] = item['id']

            # Final guard: force any remaining non-string arguments to JSON strings.
            for msg in processed_item['messages']:
                if 'tool_calls' in msg and msg['tool_calls']:
                    for tc in msg['tool_calls']:
                        if 'function' in tc and 'arguments' in tc['function']:
                            if not isinstance(tc['function']['arguments'], str):
                                logger.error(f"Sample {idx} arguments not string: {type(tc['function']['arguments'])}")
                                tc['function']['arguments'] = json.dumps(tc['function']['arguments'])

            processed_data.append(processed_item)

        elif 'messages' in item:
            # Already in chat format; only normalize tool-call arguments.
            messages = json.loads(json.dumps(item['messages']))
            for msg in messages:
                if 'tool_calls' in msg and msg['tool_calls']:
                    for tc in msg['tool_calls']:
                        if 'function' in tc and 'arguments' in tc['function']:
                            if not isinstance(tc['function']['arguments'], str):
                                tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
            item_copy = dict(item)
            item_copy['messages'] = messages
            processed_data.append(item_copy)
        else:
            logger.warning(f"Skip malformed item: {item.get('id', 'unknown')}")

    logger.info(f"Processed dataset size: {len(processed_data)}")

    # Sanity check: count messages carrying tool_calls and verify argument types.
    tool_calls_count = 0
    for item in processed_data:
        for msg in item['messages']:
            if 'tool_calls' in msg and msg['tool_calls']:
                tool_calls_count += 1
                for tc in msg['tool_calls']:
                    if 'function' in tc and 'arguments' in tc['function']:
                        if not isinstance(tc['function']['arguments'], str):
                            logger.error(f"Found non-string arguments: {type(tc['function']['arguments'])}")
    logger.info(f"Messages containing tool_calls: {tool_calls_count}")

    dataset = Dataset.from_list(processed_data)

    if val_split > 0:
        dataset = dataset.train_test_split(test_size=val_split, seed=42)
        train_dataset = dataset['train']
        eval_dataset = dataset['test']
        logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
    else:
        train_dataset = dataset
        eval_dataset = None
        logger.info(f"Train: {len(train_dataset)}, no eval split")

    return train_dataset, eval_dataset
|
|
def get_quantization_config(use_4bit: bool, use_8bit: bool):
    """Build quantization config if requested."""
    if use_4bit:
        logger.info("Using 4-bit quantization (QLoRA)")
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    elif use_8bit:
        logger.info("Using 8-bit quantization")
        return BitsAndBytesConfig(
            load_in_8bit=True,
        )
    return None
|
|
def load_model_and_tokenizer(args):
    """Load model and tokenizer."""
    logger.info(f"Loading model: {args.model_path}")

    tokenizer_path = args.tokenizer_path or args.model_path

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        trust_remote_code=True,
        padding_side="right",
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    quantization_config = get_quantization_config(args.use_4bit, args.use_8bit)

    model_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto",
    }

    if quantization_config:
        model_kwargs["quantization_config"] = quantization_config

    # Only set torch_dtype when not quantizing; the 4-bit path already fixes its
    # compute dtype via bnb_4bit_compute_dtype.
    if args.bf16 and not (args.use_4bit or args.use_8bit):
        model_kwargs["torch_dtype"] = torch.bfloat16
    elif args.fp16 and not (args.use_4bit or args.use_8bit):
        model_kwargs["torch_dtype"] = torch.float16

    if args.use_flash_attention:
        model_kwargs["attn_implementation"] = "flash_attention_2"
        logger.info("Using Flash Attention 2")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        **model_kwargs
    )

    if args.use_4bit or args.use_8bit:
        model = prepare_model_for_kbit_training(model)

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        logger.info("Enabled gradient checkpointing")

    logger.info(f"Model parameters: {model.num_parameters():,}")

    return model, tokenizer
|
|
def get_lora_config(args):
    """Build LoRA config."""
    logger.info(f"LoRA config: r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}")
    logger.info(f"Target modules: {args.target_modules}")

    return LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
|
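# Adapter size, as a rough back-of-the-envelope estimate (not specific to this model):
# each targeted Linear layer of shape (d_out, d_in) gains r * (d_in + d_out) trainable
# parameters (lora_A: r x d_in, lora_B: d_out x r), so the adapter grows linearly with
# --lora_r and with the number of --target_modules.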
|
def formatting_func(example):
    """
    Format function: pass data through for SFTTrainer.

    Dataset format:
    {
        "messages": [
            {"role": "developer", "content": "..."},
            {"role": "user", "content": "..."},
            {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
        ],
        "tools": [...]
    }
    """
    return example
|
|
def main():
    args = parse_args()

    # Default run name encodes the tuning mode and a timestamp.
    if args.run_name is None:
        mode = "lora" if args.use_lora else "full"
        args.run_name = f"functiongemma-{mode}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    output_dir = os.path.join(args.output_dir, args.run_name)
    os.makedirs(output_dir, exist_ok=True)

    logger.info("=" * 60)
    logger.info("FunctionGemma SFT training")
    logger.info("=" * 60)
    logger.info(f"Output dir: {output_dir}")

    # Persist the CLI config next to the run outputs for reproducibility.
    config_path = os.path.join(output_dir, "training_config.json")
    with open(config_path, 'w') as f:
        json.dump(vars(args), f, indent=2)
    logger.info(f"Config saved to: {config_path}")

    train_dataset, eval_dataset = load_and_prepare_dataset(
        args.dataset_path,
        args.val_split
    )

    model, tokenizer = load_model_and_tokenizer(args)

    if args.use_lora:
        logger.info("=" * 60)
        logger.info("LoRA fine-tuning mode")
        logger.info("=" * 60)
        lora_config = get_lora_config(args)
    else:
        logger.info("=" * 60)
        logger.info("Full-parameter fine-tuning mode")
        logger.info("Warning: full fine-tuning needs more memory and time!")
        logger.info("=" * 60)
        lora_config = None
|
    training_args = SFTConfig(
        output_dir=output_dir,
        run_name=args.run_name,

        # Sequence handling
        max_length=args.max_seq_length,
        packing=False,

        # Schedule (effective batch size per device = batch * grad accumulation; 4 * 8 = 32 with the defaults)
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,

        # Optimizer
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        optim="adamw_torch_fused",

        # Precision
        bf16=args.bf16,
        fp16=args.fp16,

        # Logging / evaluation / checkpointing
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps if eval_dataset else None,
        eval_strategy="steps" if eval_dataset else "no",
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=True if eval_dataset else False,

        seed=args.seed,
        report_to=["tensorboard"],

        # Hub
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,

        # Memory
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False} if args.gradient_checkpointing else None,
    )
|
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=lora_config,
    )

    # Read parameter stats from trainer.model so LoRA wrapping (if any) is reflected.
    trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in trainer.model.parameters())
    trainable_percentage = 100 * trainable_params / total_params if total_params > 0 else 0

    logger.info("=" * 60)
    logger.info("Model parameter stats:")
    logger.info(f"  Total params: {total_params:,}")
    logger.info(f"  Trainable params: {trainable_params:,}")
    logger.info(f"  Trainable ratio: {trainable_percentage:.2f}%")
    logger.info(f"  Mode: {'LoRA' if args.use_lora else 'Full fine-tune'}")
    logger.info("=" * 60)

    logger.info("Start training...")
|
    if args.resume_from_checkpoint:
        trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    else:
        trainer.train()

    logger.info("Saving final model...")
    final_model_path = os.path.join(output_dir, "final_model")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    logger.info("=" * 60)
    logger.info("Training done.")
    logger.info(f"Model saved at: {final_model_path}")

    if args.use_lora:
        # trainer.model is the PEFT-wrapped model, so save_pretrained writes only the
        # adapter (adapter_config.json + adapter weights), not the full base weights.
        lora_path = os.path.join(output_dir, "lora_adapter")
        trainer.model.save_pretrained(lora_path)
        tokenizer.save_pretrained(lora_path)
        logger.info(f"LoRA adapter saved to: {lora_path}")
        logger.info("")
        logger.info("Usage:")
        logger.info(f"  1. LoRA adapter: {lora_path}")
        logger.info("  2. Merge the adapter with the base model before inference")
    else:
        logger.info("")
        logger.info("Usage:")
        logger.info(f"  Use the model directly from: {final_model_path}")

    logger.info("=" * 60)
|
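# Offline merge sketch (not executed by this script): after LoRA training, the adapter
# can be folded into the base weights for plain `from_pretrained` inference. This assumes
# the standard peft API; the paths below are placeholders for your own run directory.
#
#     from peft import PeftModel
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     base = AutoModelForCausalLM.from_pretrained("google/functiongemma-270m-it")
#     merged = PeftModel.from_pretrained(base, "runs/<run_name>/lora_adapter").merge_and_unload()
#     merged.save_pretrained("runs/<run_name>/merged_model")
#     AutoTokenizer.from_pretrained("runs/<run_name>/lora_adapter").save_pretrained("runs/<run_name>/merged_model")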
|
if __name__ == "__main__":
    main()