import sys
from dataclasses import dataclass, field, fields, asdict
from typing import Any, Dict, List, Literal, Optional, Union

from omegaconf import OmegaConf
from transformers import TrainingArguments, Trainer


@dataclass
class ModelConfig:
    model_name: str = ""
    dropout: float = 0.0
    model_max_seq_length: int = field(default=512)
    data_collator_mode: str = field(default="fixed", metadata={"help": "'fixed' or 'dynamic' padding in the DataCollator."})
    lambda_reg: float = field(default=1e-4, metadata={"help": "Strength of the regularization term."})
    adapter_path: Optional[str] = field(default=None)

    merge_adapter_path: Optional[str] = field(default=None)
    merge_output_path: Optional[str] = field(default=None)
|
|
@dataclass
class RotationConfig:
    r: int = field(default=4)
    num_rotations: int = field(default=4)
    task_type: str = "CAUSAL_LM"
    target_modules: List[str] = field(default_factory=lambda: ["q_proj"])
|
|
@dataclass
class DataConfig:
    dataset_name: str = "math"
    split_ratio: float = field(default=0.01)
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    dataset_split: str = field(default="train[:1000]", metadata={"help": "Dataset split to load: 'train', 'test', or 'eval', optionally sliced (e.g. 'train[:1000]')."})
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])
    dataset_field: List[str] = field(default_factory=list, metadata={"help": "Fields of dataset input and output."})
|
|
@dataclass
class TrainingOverride:
    # Field names intentionally mirror transformers.TrainingArguments so this
    # section can be splatted into its constructor (see convert_to_trainer_args).
    optim: str = field(default="adamw_torch")
    eval_strategy: str = field(default="no")
    per_device_train_batch_size: int = field(default=8)
    per_device_eval_batch_size: int = field(default=8)

    learning_rate: float = field(default=1e-05)
    lr_scheduler_type: str = field(default="cosine")

    warmup_steps: int = field(default=0)

    gradient_checkpointing: bool = field(default=False)
    gradient_accumulation_steps: int = field(default=1)
    output_dir: str = field(default="runs")
    save_steps: float = field(default=0)
    save_strategy: str = field(default="no")

    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)

    report_to: Union[None, str, List[str]] = field(default="none")
    logging_steps: int = field(default=25)

    eval_steps: Optional[int] = field(default=None)

    dataloader_num_workers: int = field(default=1)
    dataloader_pin_memory: bool = field(default=True)
    dataloader_persistent_workers: bool = field(default=True)
    dataloader_prefetch_factor: int = field(default=1)

    num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)
    # Defaults to False: with eval_strategy and save_strategy both "no", there
    # would be no checkpoint for the Trainer to reload at the end of training.
    load_best_model_at_end: bool = field(default=False)
|
|
@dataclass
class GlueConfig:
    task_name: str = field(default="mnli")
    pad_to_max_length: bool = field(default=True)
|
|
@dataclass
class MainConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    rotation_adapter_config: RotationConfig = field(default_factory=RotationConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)

    glue: GlueConfig = field(default_factory=GlueConfig)
    project_name: str = "llm_rotation"
    seed: int = 42
    run_text: str = field(default="def")
|
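# --- Usage sketch (not part of the original module) ---------------------------
# MainConfig is shaped like an OmegaConf structured config, which is presumably
# why OmegaConf is imported above. Below is a minimal loader, assuming defaults
# merged with an optional YAML file and dotlist CLI overrides such as
# `trainer_args.learning_rate=2e-5`; the function name and the yaml_path
# argument are illustrative, not part of the original code.
def load_main_config(yaml_path: Optional[str] = None) -> MainConfig:
    schema = OmegaConf.structured(MainConfig)  # dataclass defaults + type checking
    merged = schema
    if yaml_path is not None:
        merged = OmegaConf.merge(merged, OmegaConf.load(yaml_path))
    merged = OmegaConf.merge(merged, OmegaConf.from_cli())
    # to_object() instantiates the backing dataclass, returning a plain MainConfig.
    return OmegaConf.to_object(merged)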
|
@dataclass
class HFTrainingArguments(TrainingArguments):
    extension: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Serialized MainConfig excluding training args"},
    )
|
|
def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """
    Maps MainConfig to HFTrainingArguments.

    Logic:
    1. Extract the 'trainer_args' fields -> pass to the TrainingArguments constructor.
    2. Pack the remaining sections ('model', 'data', etc.) -> store in 'extension'.
    """
    KEY = "trainer_args"

    # Recursively convert the nested dataclasses into plain dicts.
    full_dict = asdict(main_cfg)

    # Split off the HF-native arguments; everything else rides along in `extension`.
    train_args_dict = full_dict.pop(KEY)
    extension_payload = full_dict

    try:
        args = HFTrainingArguments(**train_args_dict)
    except TypeError as e:
        print(f"Error: your '{KEY}' config contains keys unknown to HF TrainingArguments: {e}", file=sys.stderr)
        sys.exit(1)

    args.extension = extension_payload
    return args
|
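# --- End-to-end sketch (illustrative assumption, kept as comments) ------------
# How the pieces are presumably wired together; `model`, `train_dataset`, and
# `data_collator` are hypothetical placeholders supplied by the caller:
#
#   main_cfg = load_main_config("config.yaml")
#   hf_args = convert_to_trainer_args(main_cfg)
#   # Non-HF sections stay reachable, e.g. hf_args.extension["model"]["model_name"].
#   trainer = Trainer(model=model, args=hf_args,
#                     train_dataset=train_dataset, data_collator=data_collator)
#   trainer.train()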
|
@dataclass
class Training:
    model_name_or_path: Optional[str] = field(default="huggyllama/llama-7b")
    adapter_name_or_path: Optional[str] = field(default=None)
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    dataset_split: str = field(
        default="train[:100000]",
        metadata={"help": "Dataset split to load: 'train', 'test', or 'eval', optionally sliced (e.g. 'train[:100000]')."},
    )
    dataset_field: Optional[List[str]] = field(
        default=None, metadata={"help": "Fields of dataset input and output."}
    )
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right-padded (and possibly truncated)."},
    )
    hrft_r: int = field(
        default=8,
        metadata={"help": "The rank of the adapter. When this is `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."},
    )
    init_a: float = field(default=1e-4, metadata={"help": "The initial weights."})
    eps: float = field(default=1e-4, metadata={"help": "The control strength of COFT (the freedom of rotation)."})
    lamda: float = field(default=1e-4, metadata={"help": "Strength of the regularization term."})
    add_orth: str = field(default="none", metadata={"help": ""})
    init_weights: Literal[True, "pissa"] = field(
        default=True,
        metadata={
            "help": (
                "Passing True (default) results in the LoRA initialization. "
                "Passing `pissa` results in PiSSA initialization."
            ),
        },
    )
    extension: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Serialized MainConfig excluding training args"},
    )
|
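# Minimal smoke test (illustrative, not part of the original module). Assumes a
# transformers version whose TrainingArguments accepts the field names used in
# TrainingOverride (e.g. `eval_strategy`); unknown keys exit via the handler above.
if __name__ == "__main__":
    hf_args = convert_to_trainer_args(MainConfig())
    print(hf_args.output_dir)                # "runs"
    print(sorted(hf_args.extension.keys()))  # data, glue, model, ... sections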