#!/usr/bin/env python3 """ FunctionGemma SFT fine-tuning script. Runs TRL SFTTrainer for FunctionGemma with two modes: 1) LoRA (recommended): faster, lower memory, less overfit 2) Full-parameter: higher cost, maximal capacity Usage: # LoRA (default) python -m src.train \ --model_path /path/to/model \ --dataset_path ./data/training_data.json \ --bf16 # Full-parameter python -m src.train \ --model_path /path/to/model \ --dataset_path ./data/training_data.json \ --no-use-lora \ --bf16 """ import os import json import argparse import logging from datetime import datetime from pathlib import Path from typing import Optional import torch from datasets import Dataset, load_dataset from transformers import ( AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig, ) from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training from trl import SFTTrainer, SFTConfig # Paths and logging PROJECT_ROOT = Path(__file__).resolve().parent.parent DEFAULT_DATA_PATH = PROJECT_ROOT / "data" / "training_data.json" DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "runs" logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) def parse_args(): """Parse CLI arguments.""" parser = argparse.ArgumentParser(description="FunctionGemma SFT fine-tuning (LoRA / full)") # Model parser.add_argument( "--model_path", type=str, default="google/functiongemma-270m-it", help="Model path or HF model id" ) parser.add_argument( "--tokenizer_path", type=str, default=None, help="Tokenizer path (defaults to model_path)" ) # Dataset parser.add_argument( "--dataset_path", type=str, default=str(DEFAULT_DATA_PATH), help="Training dataset path" ) parser.add_argument( "--val_split", type=float, default=0.1, help="Validation split ratio" ) # Output parser.add_argument( "--output_dir", type=str, default=str(DEFAULT_OUTPUT_DIR), help="Root output directory" ) parser.add_argument( "--run_name", type=str, default=None, help="Run name for logging and saving" ) # Fine-tuning mode parser.add_argument( "--use_lora", action="store_true", default=True, help="Enable LoRA (recommended). Add --no-use-lora for full-parameter finetune" ) parser.add_argument("--no-use-lora", dest="use_lora", action="store_false", help="Disable LoRA, run full-parameter finetune") # LoRA (only when use_lora=True) parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank") parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha") parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout") parser.add_argument( "--target_modules", type=str, nargs="+", default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], help="Target modules for LoRA" ) # Training (aligned with FunctionGemma guidance) parser.add_argument("--num_train_epochs", type=int, default=6, help="Training epochs (official rec: 8)") parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 to use epochs)") parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Train batch size per device") parser.add_argument("--per_device_eval_batch_size", type=int, default=2, help="Eval batch size") parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Grad accumulation steps") parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay") parser.add_argument("--warmup_ratio", type=float, default=0.0, help="Warmup ratio (constant scheduler usually skips warmup)") parser.add_argument("--max_seq_length", type=int, default=2048, help="Max sequence length (model supports up to 32768)") parser.add_argument("--lr_scheduler_type", type=str, default="constant", help="LR scheduler type (default constant)") # Precision & optimization parser.add_argument("--bf16", action="store_true", help="Use BF16") parser.add_argument("--fp16", action="store_true", help="Use FP16") parser.add_argument("--use_4bit", action="store_true", help="Enable 4-bit quant (QLoRA)") parser.add_argument("--use_8bit", action="store_true", help="Enable 8-bit quant") parser.add_argument("--use_flash_attention", action="store_true", help="Enable Flash Attention 2") parser.add_argument("--gradient_checkpointing", action="store_true", help="Enable gradient checkpointing") # Logging & saving parser.add_argument("--logging_steps", type=int, default=10, help="Log every N steps") parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every N steps") parser.add_argument("--eval_steps", type=int, default=100, help="Eval every N steps") parser.add_argument("--save_total_limit", type=int, default=3, help="Max checkpoints to keep") # Misc parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint") parser.add_argument("--push_to_hub", action="store_true", help="Push to Hugging Face Hub") parser.add_argument("--hub_model_id", type=str, default=None, help="Hub model id") return parser.parse_args() def load_and_prepare_dataset(dataset_path: str, val_split: float = 0.1): """Load and normalize dataset structure for SFT.""" logger.info(f"Loading dataset: {dataset_path}") # Load JSON dataset with open(dataset_path, 'r', encoding='utf-8') as f: data = json.load(f) logger.info(f"Dataset size: {len(data)} samples") # Normalize nested structures: # if an item has input.messages/tools, lift them to top-level processed_data = [] for idx, item in enumerate(data): if 'input' in item and 'messages' in item['input']: # Deep copy messages to avoid mutating original messages = json.loads(json.dumps(item['input']['messages'])) # Fix tool_calls formatting if present for msg in messages: if 'tool_calls' in msg and msg['tool_calls']: for tc in msg['tool_calls']: if 'function' in tc and 'arguments' in tc['function']: args = tc['function']['arguments'] # ensure arguments is a string if not isinstance(args, str): tc['function']['arguments'] = json.dumps(args) # Convert expected field into assistant response if present if 'expected' in item and item['expected']: expected = item['expected'] # If last message is not assistant, append one if messages[-1]['role'] != 'assistant': # Decide between function call or refusal function_name = expected.get('function_name') arguments = expected.get('arguments') response = expected.get('response', '') if function_name is not None and arguments is not None: # Case 1: function call -> add tool_calls arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments) assistant_msg = { "role": "assistant", "content": None, "tool_calls": [{ "id": f"call_{hash(function_name + arguments_str) % 1000000}", # generate unique id "type": "function", "function": { "name": function_name, "arguments": arguments_str } }] } messages.append(assistant_msg) logger.debug(f"Added assistant tool_calls: {function_name}") elif function_name is None and arguments is None and response: # Case 2: refusal -> plain text response assistant_msg = { "role": "assistant", "content": response } messages.append(assistant_msg) logger.debug(f"Added assistant refusal response: {response[:50]}") else: logger.warning(f"Unknown expected format: {expected}") processed_item = { 'messages': messages } # include tools if present if 'tools' in item['input']: processed_item['tools'] = item['input']['tools'] # preserve id if 'id' in item: processed_item['id'] = item['id'] # Final check: tool_calls arguments must be strings for msg in processed_item['messages']: if 'tool_calls' in msg and msg['tool_calls']: for tc in msg['tool_calls']: if 'function' in tc and 'arguments' in tc['function']: if not isinstance(tc['function']['arguments'], str): logger.error(f"Sample {idx} arguments not string: {type(tc['function']['arguments'])}") tc['function']['arguments'] = json.dumps(tc['function']['arguments']) processed_data.append(processed_item) elif 'messages' in item: # Already proper format, just normalize tool_calls messages = json.loads(json.dumps(item['messages'])) for msg in messages: if 'tool_calls' in msg and msg['tool_calls']: for tc in msg['tool_calls']: if 'function' in tc and 'arguments' in tc['function']: if not isinstance(tc['function']['arguments'], str): tc['function']['arguments'] = json.dumps(tc['function']['arguments']) item_copy = dict(item) item_copy['messages'] = messages processed_data.append(item_copy) else: logger.warning(f"Skip malformed item: {item.get('id', 'unknown')}") logger.info(f"Processed dataset size: {len(processed_data)}") # Validate format tool_calls_count = 0 for item in processed_data: for msg in item['messages']: if 'tool_calls' in msg and msg['tool_calls']: tool_calls_count += 1 for tc in msg['tool_calls']: if 'function' in tc and 'arguments' in tc['function']: if not isinstance(tc['function']['arguments'], str): logger.error(f"Found non-string arguments: {type(tc['function']['arguments'])}") logger.info(f"Messages containing tool_calls: {tool_calls_count}") # Convert to Hugging Face Dataset dataset = Dataset.from_list(processed_data) # Split train/val if val_split > 0: dataset = dataset.train_test_split(test_size=val_split, seed=42) train_dataset = dataset['train'] eval_dataset = dataset['test'] logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}") else: train_dataset = dataset eval_dataset = None logger.info(f"Train: {len(train_dataset)}, no eval split") return train_dataset, eval_dataset def get_quantization_config(use_4bit: bool, use_8bit: bool): """Build quantization config if requested.""" if use_4bit: logger.info("Using 4-bit quantization (QLoRA)") return BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) elif use_8bit: logger.info("Using 8-bit quantization") return BitsAndBytesConfig( load_in_8bit=True, ) return None def load_model_and_tokenizer(args): """Load model and tokenizer.""" logger.info(f"Loading model: {args.model_path}") tokenizer_path = args.tokenizer_path or args.model_path # Load tokenizer tokenizer = AutoTokenizer.from_pretrained( tokenizer_path, trust_remote_code=True, padding_side="right", ) # Ensure pad token exists if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id # Quantization config quantization_config = get_quantization_config(args.use_4bit, args.use_8bit) # Model kwargs model_kwargs = { "trust_remote_code": True, "device_map": "auto", } if quantization_config: model_kwargs["quantization_config"] = quantization_config # Precision if args.bf16 and not (args.use_4bit or args.use_8bit): model_kwargs["torch_dtype"] = torch.bfloat16 elif args.fp16 and not (args.use_4bit or args.use_8bit): model_kwargs["torch_dtype"] = torch.float16 # Flash Attention if args.use_flash_attention: model_kwargs["attn_implementation"] = "flash_attention_2" logger.info("Using Flash Attention 2") # Load model model = AutoModelForCausalLM.from_pretrained( args.model_path, **model_kwargs ) # Prepare for k-bit training when quantized if args.use_4bit or args.use_8bit: model = prepare_model_for_kbit_training(model) # Gradient checkpointing if args.gradient_checkpointing: model.gradient_checkpointing_enable() logger.info("Enabled gradient checkpointing") logger.info(f"Model parameters: {model.num_parameters():,}") return model, tokenizer def get_lora_config(args): """Build LoRA config.""" logger.info(f"LoRA config: r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}") logger.info(f"Target modules: {args.target_modules}") return LoraConfig( r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, target_modules=args.target_modules, bias="none", task_type=TaskType.CAUSAL_LM, ) def formatting_func(example): """ Format function: pass data through for SFTTrainer. Dataset format: { "messages": [ {"role": "developer", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."} ], "tools": [...] } """ # Return as-is; SFTTrainer applies chat template return example def main(): args = parse_args() # Set run name if args.run_name is None: args.run_name = f"functiongemma-lora-{datetime.now().strftime('%Y%m%d_%H%M%S')}" # Create output directory output_dir = os.path.join(args.output_dir, args.run_name) os.makedirs(output_dir, exist_ok=True) logger.info("=" * 60) logger.info("FunctionGemma SFT LoRA training") logger.info("=" * 60) logger.info(f"Output dir: {output_dir}") # Save config config_path = os.path.join(output_dir, "training_config.json") with open(config_path, 'w') as f: json.dump(vars(args), f, indent=2) logger.info(f"Config saved to: {config_path}") # Load dataset train_dataset, eval_dataset = load_and_prepare_dataset( args.dataset_path, args.val_split ) # Load model + tokenizer model, tokenizer = load_model_and_tokenizer(args) # Build LoRA config if enabled if args.use_lora: logger.info("=" * 60) logger.info("LoRA fine-tuning mode") logger.info("=" * 60) lora_config = get_lora_config(args) else: logger.info("=" * 60) logger.info("Full-parameter fine-tuning mode") logger.info("Warning: full fine-tuning needs more memory and time!") logger.info("=" * 60) lora_config = None # SFTTrainer config training_args = SFTConfig( output_dir=output_dir, run_name=args.run_name, # Sequence length / packing max_length=args.max_seq_length, packing=False, # Training num_train_epochs=args.num_train_epochs, max_steps=args.max_steps, per_device_train_batch_size=args.per_device_train_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, # Optimizer learning_rate=args.learning_rate, weight_decay=args.weight_decay, warmup_ratio=args.warmup_ratio, lr_scheduler_type=args.lr_scheduler_type, optim="adamw_torch_fused", # Precision bf16=args.bf16, fp16=args.fp16, # Logging / saving logging_steps=args.logging_steps, save_steps=args.save_steps, eval_steps=args.eval_steps if eval_dataset else None, eval_strategy="steps" if eval_dataset else "no", save_total_limit=args.save_total_limit, load_best_model_at_end=True if eval_dataset else False, # Misc seed=args.seed, report_to=["tensorboard"], # Hub push_to_hub=args.push_to_hub, hub_model_id=args.hub_model_id, # Gradient checkpointing gradient_checkpointing=args.gradient_checkpointing, gradient_checkpointing_kwargs={"use_reentrant": False} if args.gradient_checkpointing else None, ) # Create SFTTrainer # Dataset should include 'messages' and 'tools'; SFTTrainer applies chat template automatically trainer = SFTTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=tokenizer, # newer TRL uses processing_class instead of tokenizer peft_config=lora_config, ) # Parameter stats trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) total_params = sum(p.numel() for p in model.parameters()) trainable_percentage = 100 * trainable_params / total_params if total_params > 0 else 0 logger.info("=" * 60) logger.info("Model parameter stats:") logger.info(f" Total params: {total_params:,}") logger.info(f" Trainable params: {trainable_params:,}") logger.info(f" Trainable ratio: {trainable_percentage:.2f}%") logger.info(f" Mode: {'LoRA' if args.use_lora else 'Full fine-tune'}") logger.info("=" * 60) # Train logger.info("Start training...") if args.resume_from_checkpoint: trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) else: trainer.train() # Save final model logger.info("Saving final model...") final_model_path = os.path.join(output_dir, "final_model") trainer.save_model(final_model_path) tokenizer.save_pretrained(final_model_path) logger.info("=" * 60) logger.info("Training done.") logger.info(f"Model saved at: {final_model_path}") if args.use_lora: # LoRA: also save adapter lora_path = os.path.join(output_dir, "lora_adapter") model.save_pretrained(lora_path) tokenizer.save_pretrained(lora_path) logger.info(f"LoRA adapter saved to: {lora_path}") logger.info("") logger.info("Usage:") logger.info(f" 1. LoRA adapter: {lora_path}") logger.info(f" 2. Merge adapters with your base model before inference") else: # Full fine-tune: final_model is ready to use logger.info("") logger.info("Usage:") logger.info(f" Use model directly from: {final_model_path}") logger.info("=" * 60) if __name__ == "__main__": main()