import sys
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional, Union

from omegaconf import OmegaConf
from transformers import Trainer, TrainingArguments


@dataclass
class ModelConfig:
    model_name: str = ""
    dropout: float = 0.0
    model_max_seq_length: int = field(default=512)
    data_collator_mode: str = field(default='fixed', metadata={"help": "fixed or dynamic padding in DataCollator"})
    lambda_reg: float = field(default=1e-4, metadata={"help": "Strength of the regularization term"})
    adapter_path: Optional[str] = field(default=None)

    merge_adapter_path: Optional[str] = field(default=None)
    merge_output_path: Optional[str] = field(default=None)

@dataclass
class RotationConfig:
    r: int = field(default=4)
    num_rotations: int = field(default=4)
    task_type: str = "CAUSAL_LM"
    target_modules: List[str] = field(default_factory=lambda: ["q_proj",])

@dataclass
class DataConfig:
    dataset_name: str = 'math'
    split_ratio: float = field(default=0.01)
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    dataset_split: str = field(default="train[:1000]", metadata={"help": "Dataset split to load, e.g. 'train', 'test', 'eval', or a slice such as 'train[:1000]'."})
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])
    dataset_field: List[str] = field(default_factory=list, metadata={"help": "Fields of dataset input and output."})


@dataclass
class TrainingOverride:
    optim: str = field(default="adamw_torch")
    eval_strategy: str = field(default='no')
    per_device_train_batch_size: int = field(default=8)
    per_device_eval_batch_size: int = field(default=8)

    learning_rate: float = field(default=1e-05)
    lr_scheduler_type: str = field(default='cosine')
    # warmup_ratio: float = field(default=0.1)
    warmup_steps: int = field(default=0)

    gradient_checkpointing: bool = field(default=False)
    gradient_accumulation_steps: int = field(default=1)
    output_dir: str = field(default="runs")
    save_steps: float = field(default=0)
    save_strategy: str = field(default='no')
    # save_total_limit: int = field(default=1)  # no longer needed
    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)

    report_to: Union[None, str, List[str]] = field(default="none")
    logging_steps: int = field(default=25)  # we use int only
    # logging_first_step: bool = field(default=False)
    eval_steps: Optional[int] = field(default=None)  # we use int only

    dataloader_num_workers: int = field(default=1)
    dataloader_pin_memory: bool = field(default=True)
    dataloader_persistent_workers: bool = field(default=True)
    dataloader_prefetch_factor: int = field(default=1)

    num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)
    load_best_model_at_end: bool = field(default=True)

@dataclass
class GlueConfig:
    task_name: str = field(default='mnli')
    pad_to_max_length: bool = field(default=True)


@dataclass
class MainConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    rotation_adapter_config: RotationConfig = field(default_factory=RotationConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)

    glue: GlueConfig = field(default_factory=GlueConfig)
    project_name: str = "llm_rotation"
    seed: int = 42
    run_text: str = field(default='def')
    # device: str = field(default='cpu')

@dataclass
class HFTrainingArguments(TrainingArguments):
    extension: Optional[Dict[str, Any]] = field(
        default=None, 
        metadata={"help": "Serialized MainConfig excluding training args"}
    )
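# Note (illustrative, not from the original code): `extension` carries the non-training
# sections of MainConfig as a plain dict, so a custom Trainer or callback could read
# them back later, e.g. args.extension["model"]["model_name"] or
# args.extension["data"]["dataset_name"], without re-parsing the config
# (see convert_to_trainer_args below for how the payload is built).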

def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """
    Maps MainConfig to MyTrainingArguments.
    Logic:
    1. Extract 'training' fields -> Pass to TrainingArguments constructor.
    2. Pack 'model', 'data', etc. -> Put into 'extension'.
    """
    KEY = "trainer_args"
    # 1. Convert OmegaConf/Dataclass to pure Python dict
    # resolve=True ensures variables like ${model.name} are interpolated
    full_dict = asdict(main_cfg)
    
    # 2. Extract Training Arguments
    # These will be unpack **kwargs to initialize the parent TrainingArguments
    train_args_dict = full_dict.pop(KEY)
    
    # 3. The rest (model, data, seed) goes into extension
    extension_payload = full_dict
    
    # 4. Initialize MyTrainingArguments
    # Note: We must ensure train_args_dict keys match TrainingArguments fields.
    try:
        args = HFTrainingArguments(**train_args_dict)
    except TypeError as e:
        print(f"Error: Your 'training' config contains keys unknown to HF TrainingArguments: {e}")
        sys.exit(1)
        
    # 5. Attach the extension
    args.extension = extension_payload
    
    return args

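# Minimal usage sketch (an assumption about how this module is driven, not code taken
# from the original entry point): build the MainConfig schema with OmegaConf, apply
# dotted CLI overrides, then convert the result into HF training arguments.
#
#   base = OmegaConf.structured(MainConfig)      # schema + defaults defined above
#   cli = OmegaConf.from_cli()                   # e.g. trainer_args.learning_rate=2e-5
#   cfg = OmegaConf.merge(base, cli)
#   hf_args = convert_to_trainer_args(cfg)
#   trainer = Trainer(model=model, args=hf_args, train_dataset=train_ds)  # model/train_ds are placeholders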



@dataclass
class Training:
    model_name_or_path: Optional[str] = field(default="huggyllama/llama-7b")
    adapter_name_or_path: Optional[str] = field(default=None)
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    dataset_split: str = field(
        default="train[:100000]",
        metadata={"help": "Dataset split to load, e.g. 'train', 'test', 'eval', or a slice such as 'train[:100000]'."},
    )
    dataset_field: Optional[List[str]] = field(
        default=None, metadata={"help": "Fields of dataset input and output."}
    )
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512, metadata={
        "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, )
    hrft_r: int = field(default=8, metadata={
        "help": "The rank of the adapter. When passing `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."})
    init_a: float = field(default=1e-4, metadata={"help": "The initial weights"})
    eps: float = field(default=1e-4, metadata={"help": "Control strength of COFT (the allowed freedom of rotation)."})
    lamda: float = field(default=1e-4, metadata={"help": "Strength of the regularization term"})
    add_orth: str = field(default='none', metadata={"help": ""})
    init_weights: Literal[True, "pissa"] = field(
        default=True,
        metadata={
            "help": (
                "Passing True (default) results in the LoRA initialization."
                "Passing `pissa` results in PiSSA initialization."
            ),
        },
    )
    extension: Optional[Dict[str, Any]] = field(
        default=None, 
        metadata={"help": "Serialized MainConfig excluding training args"}
    )

    # target_modules: str = (
    #     "(.*x_embedder"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0"
    #     "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.proj_out"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v"
    #     "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
    # )
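
# Sketch of one way the flat `Training` dataclass could be consumed (an assumption; the
# original entry point is not shown in this file). transformers.HfArgumentParser turns
# its fields into CLI flags:
#
#   from transformers import HfArgumentParser
#   (train_cfg,) = HfArgumentParser(Training).parse_args_into_dataclasses()
#   # e.g. python train.py --model_name_or_path huggyllama/llama-7b --hrft_r 8 --init_weights pissa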