import sys
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Union

from omegaconf import OmegaConf
from transformers import Trainer, TrainingArguments


@dataclass
class ModelConfig:
    model_name: str = ""
    dropout: float = 0.0
    model_max_seq_length: int = field(default=512)
    data_collator_mode: str = field(default='fixed', metadata={"help": "'fixed' or 'dynamic' padding in the DataCollator."})
    lambda_reg: float = field(default=1e-4, metadata={"help": "Strength of the regularization term."})
    adapter_path: Optional[str] = field(default=None)
    merge_adapter_path: Optional[str] = field(default=None)
    merge_output_path: Optional[str] = field(default=None)


@dataclass
class RotationConfig:
    r: int = field(default=4)
    num_rotations: int = field(default=4)
    task_type: str = "CAUSAL_LM"
    target_modules: List[str] = field(default_factory=lambda: ["q_proj"])


@dataclass
class DataConfig:
    dataset_name: str = 'math'
    split_ratio: float = field(default=0.01)
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    dataset_split: str = field(default="train[:1000]", metadata={"help": "Dataset split to load, e.g. 'train', 'test', or 'eval'; slicing such as 'train[:1000]' is allowed."})
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])
    dataset_field: List[str] = field(default_factory=list, metadata={"help": "Fields of the dataset used as input and output."})


@dataclass
class TrainingOverride:
    optim: str = field(default="adamw_torch")
    eval_strategy: str = field(default='no')
    per_device_train_batch_size: int = field(default=8)
    per_device_eval_batch_size: int = field(default=8)
    learning_rate: float = field(default=1e-05)
    lr_scheduler_type: str = field(default='cosine')
    # warmup_ratio: float = field(default=0.1)
    warmup_steps: int = field(default=0)
    gradient_checkpointing: bool = field(default=False)
    gradient_accumulation_steps: int = field(default=1)
    output_dir: str = field(default="runs")
    save_steps: float = field(default=0)
    save_strategy: str = field(default='no')
    # save_total_limit: int = field(default=1)  # no longer needed
    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)
    report_to: Union[None, str, List[str]] = field(default="none")
    logging_steps: int = field(default=25)  # we use int only
    # logging_first_step: bool = field(default=False)
    eval_steps: Optional[int] = field(default=None)  # we use int only
    dataloader_num_workers: int = field(default=1)
    dataloader_pin_memory: bool = field(default=True)
    dataloader_persistent_workers: bool = field(default=True)
    dataloader_prefetch_factor: int = field(default=1)
    num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)
    load_best_model_at_end: bool = field(default=True)


@dataclass
class GlueConfig:
    task_name: str = field(default='mnli')
    pad_to_max_length: bool = field(default=True)


@dataclass
class MainConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    rotation_adapter_config: RotationConfig = field(default_factory=RotationConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)
    glue: GlueConfig = field(default_factory=GlueConfig)
    project_name: str = "llm_rotation"
    seed: int = 42
    run_text: str = field(default='def')
    # device: str = field(default='cpu')


@dataclass
class HFTrainingArguments(TrainingArguments):
    extension: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Serialized MainConfig excluding training args"}
    )


def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """
    Maps a MainConfig onto HFTrainingArguments.

    Logic:
        1. Extract the 'trainer_args' fields and pass them to the TrainingArguments constructor.
        2. Pack everything else ('model', 'data', etc.) into 'extension'.
    """
    KEY = "trainer_args"
    # 1. Convert the dataclass into a plain Python dict.
    full_dict = asdict(main_cfg)
    # 2. Extract the training arguments; they are unpacked as **kwargs to
    #    initialize the parent TrainingArguments.
    train_args_dict = full_dict.pop(KEY)
    # 3. The rest (model, data, seed, ...) goes into the extension payload.
    extension_payload = full_dict
    # 4. Initialize HFTrainingArguments.
    #    The keys in train_args_dict must match TrainingArguments fields.
    try:
        args = HFTrainingArguments(**train_args_dict)
    except TypeError as e:
        print(f"Error: your 'trainer_args' config contains keys unknown to HF TrainingArguments: {e}")
        sys.exit(1)
    # 5. Attach the extension payload.
    args.extension = extension_payload
    return args
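

# The sketch below is an illustrative example, not part of the original pipeline:
# it builds a MainConfig with a couple of hypothetical overrides and converts it
# into HFTrainingArguments, showing that the non-HF settings end up under
# `args.extension` (e.g. args.extension["model"]["lambda_reg"]).
def _example_build_trainer_args() -> HFTrainingArguments:
    cfg = MainConfig(
        model=ModelConfig(model_name="huggyllama/llama-7b"),
        trainer_args=TrainingOverride(per_device_train_batch_size=4, learning_rate=2e-5),
    )
    return convert_to_trainer_args(cfg)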


@dataclass
class Training:
    model_name_or_path: Optional[str] = field(default="huggyllama/llama-7b")
    adapter_name_or_path: Optional[str] = field(default=None)
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    dataset_split: str = field(
        default="train[:100000]", metadata={"help": "Dataset split to load, e.g. 'train', 'test', or 'eval'; slicing such as 'train[:100000]' is allowed."}
    )
    dataset_field: Optional[List[str]] = field(
        default=None, metadata={"help": "Fields of the dataset used as input and output."}
    )
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512, metadata={
        "help": "Maximum sequence length. Sequences will be right-padded (and possibly truncated)."})
    hrft_r: Optional[int] = field(default=8, metadata={
        "help": "The rank of the adapter. When this is `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."})
    init_a: float = field(default=1e-4, metadata={"help": "The initial weight value."})
    eps: float = field(default=1e-4, metadata={"help": "The constraint strength of COFT, i.e. the freedom of rotation."})
    lamda: float = field(default=1e-4, metadata={"help": "Strength of the regularization term."})
    add_orth: str = field(default='none', metadata={"help": ""})
    init_weights: Literal[True, "pissa"] = field(
        default=True,
        metadata={
            "help": (
                "Passing True (the default) results in the LoRA initialization. "
                "Passing `pissa` results in PiSSA initialization."
            ),
        },
    )
    extension: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Serialized MainConfig excluding training args"}
    )
    # target_modules: str = (
    #     "(.*x_embedder"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.proj_out"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
    # )
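

# The sketch below is an illustrative example, not part of the original pipeline:
# it shows one way the Training dataclass could be filled from command-line flags
# using transformers.HfArgumentParser; the helper name and the example flags
# (--data_path, --hrft_r) are hypothetical.
def _example_parse_training_cli() -> Training:
    from transformers import HfArgumentParser

    # e.g. python this_module.py --data_path ./data/train.json --hrft_r 8
    parser = HfArgumentParser(Training)
    (train_cfg,) = parser.parse_args_into_dataclasses()
    return train_cfg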