from dataclasses import dataclass, field, fields, asdict
from typing import Optional, List, Literal, Dict, Any, Union

from transformers import TrainingArguments, Trainer
from omegaconf import OmegaConf
import sys


@dataclass
class ModelConfig:
    model_name: str = ""
    dropout: float = 0.0
    model_max_seq_length: int = field(default=512)
    data_collator_mode: str = field(default='fixed', metadata={"help": "fixed or dynamic padding in DataCollator"})
    lambda_reg: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    adapter_path: Optional[str] = field(default=None)
    merge_adapter_path: Optional[str] = field(default=None)
    merge_output_path: Optional[str] = field(default=None)


@dataclass
class RotationConfig:
    r: int = field(default=4)
    num_rotations: int = field(default=4)
    task_type: str = "CAUSAL_LM"
    target_modules: List[str] = field(default_factory=lambda: ["q_proj"])


@dataclass
class DataConfig:
    dataset_name: str = 'math'
    split_ratio: float = field(default=0.01)
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    dataset_split: str = field(default="train[:1000]", metadata={"help": "(`['train', 'test', 'eval']`):"})
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])  ###
    dataset_field: List[str] = field(default_factory=list, metadata={"help": "Fields of dataset input and output."})


@dataclass
class TrainingOverride:
    optim: str = field(default="adamw_torch")  ##
    eval_strategy: str = field(default='no')
    per_device_train_batch_size: int = field(default=8)  ##
    per_device_eval_batch_size: int = field(default=8)  ##
    learning_rate: float = field(default=1e-05)
    lr_scheduler_type: str = field(default='cosine')
    # warmup_ratio: float = field(default=0.1)
    warmup_steps: int = field(default=0)
    gradient_checkpointing: bool = field(default=False)
    gradient_accumulation_steps: int = field(default=1)
    output_dir: str = field(default="runs")
    save_steps: float = field(default=0)
    save_strategy: str = field(default='no')
    # save_total_limit: int = field(default=1)  # No need any more
    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)
    report_to: Union[None, str, list[str]] = field(default="none")
    logging_steps: int = field(default=25)  # we use int only
    # logging_first_step: bool = field(default=False)
    eval_steps: Union[None, int] = field(default=None)  # we use int only
    dataloader_num_workers: int = field(default=1)
    dataloader_pin_memory: bool = field(default=True)  ###
    dataloader_persistent_workers: bool = field(default=True)  ###
    dataloader_prefetch_factor: int = field(default=1)  ###
    num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)
    load_best_model_at_end: bool = field(default=True)


@dataclass
class GlueConfig:
    task_name: str = field(default='mnli')
    pad_to_max_length: bool = field(default=True)


@dataclass
class MainConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    rotation_adapter_config: RotationConfig = field(default_factory=RotationConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)
    glue: GlueConfig = field(default_factory=GlueConfig)
    project_name: str = "llm_rotation"
    seed: int = 42
    run_text: str = field(default='def')
    # device: str = field(default='cpu')


@dataclass
class HFTrainingArguments(TrainingArguments):
    extension: Optional[Dict[str, Any]] = field(
        default=None, metadata={"help": "Serialized MainConfig excluding training args"}
    )
def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """
    Maps MainConfig to HFTrainingArguments.

    Logic:
        1. Extract the 'trainer_args' fields -> pass to the TrainingArguments constructor.
        2. Pack the rest ('model', 'data', etc.) -> put into 'extension'.
    """
    KEY = "trainer_args"

    # 1. Convert the dataclass to a pure Python dict.
    #    (If main_cfg still lives in OmegaConf, interpolations like ${model.name}
    #    must already be resolved before this point.)
    full_dict = asdict(main_cfg)

    # 2. Extract the training arguments.
    #    These are unpacked as **kwargs to initialize the parent TrainingArguments.
    train_args_dict = full_dict.pop(KEY)

    # 3. The rest (model, data, seed, ...) goes into `extension`.
    extension_payload = full_dict

    # 4. Initialize HFTrainingArguments.
    #    Note: the keys in train_args_dict must match TrainingArguments fields.
    try:
        args = HFTrainingArguments(**train_args_dict)
    except TypeError as e:
        print(f"Error: Your 'trainer_args' config contains keys unknown to HF TrainingArguments: {e}")
        sys.exit(1)

    # 5. Attach the extension.
    args.extension = extension_payload
    return args


@dataclass
class Training:
    model_name_or_path: Optional[str] = field(default="huggyllama/llama-7b")
    adapter_name_or_path: Optional[str] = field(default=None)
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    dataset_split: str = field(
        default="train[:100000]", metadata={"help": "(`['train', 'test', 'eval']`):"}
    )
    dataset_field: List[str] = field(
        default=None, metadata={"help": "Fields of dataset input and output."}
    )
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    hrft_r: int = field(
        default=8,
        metadata={"help": "The rank of the adapter. When passing `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."},
    )
    init_a: float = field(default=1e-4, metadata={"help": "The initial weights"})
    eps: float = field(default=1e-4, metadata={"help": "The control strength of COFT. The freedom of rotation."})
    lamda: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    add_orth: str = field(default='none', metadata={"help": ""})
    init_weights: Literal[True, "pissa"] = field(
        default=True,
        metadata={
            "help": (
                "Passing True (default) results in the LoRA initialization. "
                "Passing `pissa` results in PiSSA initialization."
            ),
        },
    )
    extension: Optional[Dict[str, Any]] = field(
        default=None, metadata={"help": "Serialized MainConfig excluding training args"}
    )
    # target_modules: str = (
    #     "(.*x_embedder"
    #     "|.*(?
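
# Minimal usage sketch (an illustration, not part of the project's entrypoint):
# it assumes the config is built by merging the MainConfig schema above with a
# YAML file and CLI dotlist overrides via OmegaConf, then converted with
# convert_to_trainer_args. The "config.yaml" path is hypothetical.
if __name__ == "__main__":
    schema = OmegaConf.structured(MainConfig)       # typed defaults from the dataclasses
    file_cfg = OmegaConf.load("config.yaml")        # hypothetical YAML overrides
    cli_cfg = OmegaConf.from_cli()                  # e.g. trainer_args.learning_rate=2e-5
    merged = OmegaConf.merge(schema, file_cfg, cli_cfg)  # later sources win

    # Back to a plain MainConfig instance so asdict() works inside convert_to_trainer_args.
    main_cfg: MainConfig = OmegaConf.to_object(merged)
    hf_args = convert_to_trainer_args(main_cfg)

    # Training args are regular TrainingArguments fields; everything else sits in `extension`.
    print(hf_args.output_dir, hf_args.extension["seed"])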