|
|
|
|
|
import sys |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import warnings |
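# Silence known, benign warnings from torch.compile, DataLoader, and transformers so the training logs stay readable.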
|
|
|
|
|
|
|
|
warnings.filterwarnings("ignore", category=FutureWarning, module='torch._inductor.lowering') |
|
|
warnings.filterwarnings("ignore", message=".*Online softmax is disabled on the fly.*", category=UserWarning) |
|
|
|
|
|
warnings.filterwarnings("ignore", message=".*Our suggested max number of worker in current system is 1.*", category=UserWarning) |
|
|
warnings.filterwarnings("ignore", message=".*will be initialized from a multivariate normal distribution.*") |
|
|
warnings.filterwarnings("ignore", message=".*that differ from the model config and generation config.*", category=UserWarning) |
|
|
warnings.filterwarnings("ignore", message=".*torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch..*", category=UserWarning) |
|
|
|
|
|
import torch |
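# Allow TF32 precision for fp32 matmuls on Ampere+ GPUs (newer API that supersedes torch.backends.cuda.matmul.allow_tf32).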
|
|
torch.backends.cuda.matmul.fp32_precision = 'tf32' |
|
|
|
|
|
import os |
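# Keep the main process single-threaded for CPU math libraries; parallelism comes from DataLoader workers instead.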
|
|
torch.set_num_threads(1) |
|
|
os.environ["OMP_NUM_THREADS"]="1" |
|
|
os.environ["MKL_NUM_THREADS"]="1" |
|
|
|
|
print(f"PyTorch version: {torch.__version__}") |
|
|
print(f"CUDA available: {torch.cuda.is_available()}") |
|
|
print(f"PyTorch built with CUDA version: {torch.version.cuda}") |
|
|
|
|
|
import yaml |
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
import time |
|
|
from datetime import datetime |
|
|
import math
import json
|
|
|
|
|
from typing import List, Tuple, Optional, Union, Callable
from torch import nn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
from dataclasses import field, dataclass, asdict |
|
|
from typing import Sequence, Literal, Dict |
|
|
|
|
|
import transformers |
|
|
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer |
|
|
from transformers import Trainer |
|
|
from transformers.modeling_utils import * |
|
|
from transformers.trainer import _is_peft_model |
|
|
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES |
|
|
from transformers.data.data_collator import DataCollator |
|
|
|
|
|
from transformers.training_args import TrainingArguments |
|
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase |
|
|
from transformers.trainer_callback import TrainerCallback |
|
|
from transformers.trainer_utils import EvalPrediction |
|
|
from torch.utils.data import Dataset, IterableDataset |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from rpeft.rotation import RotationTuner, RotationConfig |
|
|
from rpeft import get_peft_model, PeftModel |
|
|
from .config import MainConfig, convert_to_trainer_args |
|
|
import pyrallis |
|
|
from omegaconf import OmegaConf |
|
|
import torch.optim as optim |
|
|
import wandb |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
IGNORE_INDEX = -100 |
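# Alpaca-style instruction template; prompt tokens are later masked with IGNORE_INDEX so the loss covers only the response.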
|
|
PROMPT = ( |
|
|
"Below is an instruction that describes a task. " |
|
|
"Write a response that appropriately completes the request.\n\n" |
|
|
"### Instruction:\n{instruction}\n\n### Response:" |
|
|
) |
|
|
|
|
|
|
|
|
import platform |
|
|
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl |
|
|
class ExperimentMonitorCallback(TrainerCallback): |
|
|
""" |
|
|
Callback to monitor training performance and log system stats to a JSON file. |
|
|
It captures: |
|
|
1. Experiment Metadata (GPU info, Batch size, Learning rate, etc.) |
|
|
2. Runtime Metrics (Avg time/step, Throughput) |
|
|
3. Memory Metrics (Allocated, Reserved, and Peak usage) |
|
|
""" |
|
|
|
|
|
def __init__(self, log_file_path: str, run_name: str = "experiment", log_interval: int = 100): |
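        # log_file_path: where the JSON log is written; run_name: label stored in the metadata;
        # log_interval: how often (in optimizer steps) metrics are recorded.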
|
|
|
|
|
self.log_file_path = log_file_path |
|
|
self.run_name = run_name |
|
|
self.log_interval = log_interval |
|
|
|
|
|
|
|
|
self.start_time = None |
|
|
self.last_log_time = None |
|
|
|
|
|
|
|
|
self.log_data = { |
|
|
"metadata": {}, |
|
|
"metrics": [] |
|
|
} |
|
|
|
|
|
def _get_gpu_info(self): |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
return { |
|
|
"name": torch.cuda.get_device_name(0), |
|
|
"count": torch.cuda.device_count(), |
|
|
"capability": torch.cuda.get_device_capability(0) |
|
|
} |
|
|
return "CPU_ONLY" |
|
|
|
|
|
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): |
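        # Record experiment metadata and reset the CUDA peak-memory counter at the start of training.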
|
|
|
|
|
self.start_time = time.perf_counter() |
|
|
self.last_log_time = self.start_time |
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.reset_peak_memory_stats() |
|
|
|
|
|
|
|
|
self.log_data["metadata"] = { |
|
|
"run_name": self.run_name, |
|
|
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), |
|
|
"python_version": platform.python_version(), |
|
|
"pytorch_version": torch.__version__, |
|
|
"gpu_info": self._get_gpu_info(), |
|
|
"configuration": { |
|
|
"batch_size_per_device": args.per_device_train_batch_size, |
|
|
"learning_rate": args.learning_rate, |
|
|
"max_steps": args.max_steps, |
|
|
"num_train_epochs": args.num_train_epochs, |
|
|
"fp16": args.fp16, |
|
|
"bf16": args.bf16, |
|
|
"optim": args.optim, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
self._save_log() |
|
|
|
|
|
|
|
|
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): |
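        # Every `log_interval` steps, record average step time, throughput, and GPU memory usage.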
|
|
current_step = state.global_step |
|
|
|
|
|
|
|
|
if current_step > 0 and current_step % self.log_interval == 0: |
|
|
current_time = time.perf_counter() |
|
|
|
|
|
|
|
|
elapsed_since_last = current_time - self.last_log_time |
|
|
avg_time_per_step = elapsed_since_last / self.log_interval |
|
|
|
|
|
|
|
|
mem_stats = {} |
|
|
if torch.cuda.is_available(): |
|
|
|
|
|
mem_stats["allocated_gb"] = torch.cuda.memory_allocated() / 1024**3 |
|
|
mem_stats["reserved_gb"] = torch.cuda.memory_reserved() / 1024**3 |
|
|
|
|
|
mem_stats["peak_allocated_gb"] = torch.cuda.max_memory_allocated() / 1024**3 |
|
|
|
|
|
|
|
|
metric_entry = { |
|
|
"step": current_step, |
|
|
"epoch": state.epoch, |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
"performance": { |
|
|
"avg_time_per_step_s": round(avg_time_per_step, 4), |
|
|
"steps_per_second": round(1.0 / avg_time_per_step, 2) |
|
|
}, |
|
|
"memory": mem_stats |
|
|
} |
|
|
|
|
|
|
|
|
self.log_data["metrics"].append(metric_entry) |
|
|
self._save_log() |
|
|
|
|
|
|
|
|
self.last_log_time = current_time |
|
|
|
|
|
|
|
|
print(f" -> Step {current_step}: {avg_time_per_step*1000:.1f}s/step |"\ |
|
|
f"Peak Mem: {mem_stats.get('peak_allocated_gb', 0):.2f} GB |"\ |
|
|
f"Reserved: {mem_stats.get('reserved_gb', 0):.2f} GB") |
|
|
|
|
|
def _save_log(self): |
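        # Persist the accumulated metadata and metrics to the JSON log file.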
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
with open(self.log_file_path, 'w', encoding='utf-8') as f: |
|
|
json.dump(self.log_data, f, indent=4) |
|
|
except Exception as e: |
|
|
print(f"Error saving experiment log: {e}") |
|
|
|
|
|
def get_rank():
    # Default to rank 0 when LOCAL_RANK is unset (single-process runs).
    try:
        rank = int(os.environ.get("LOCAL_RANK", 0))
    except (TypeError, ValueError):
        rank = 0
    return rank
|
|
|
|
|
|
|
|
def get_config(): |
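    """Load the YAML experiment config from the path given in the OMINI_CONFIG environment variable."""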
|
|
config_path = os.environ.get("OMINI_CONFIG") |
|
|
assert config_path is not None, "Please set the OMINI_CONFIG environment variable" |
|
|
with open(config_path, "r") as f: |
|
|
config = yaml.safe_load(f) |
|
|
return config |
|
|
|
|
|
|
|
|
def init_wandb(wandb_config, run_name): |
|
|
import wandb |
|
|
|
|
|
try: |
|
|
assert os.environ.get("WANDB_API_KEY") is not None |
|
|
wandb.init( |
|
|
project=wandb_config["project"], |
|
|
name=run_name, |
|
|
config={}, |
|
|
) |
|
|
except Exception as e: |
|
|
print("Failed to initialize WanDB:", e) |
|
|
|
|
|
|
|
|
|
|
|
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): |
|
|
"""Collects the state dict and dump to disk.""" |
|
|
state_dict = trainer.model.state_dict() |
|
|
if trainer.args.should_save: |
|
|
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} |
|
|
del state_dict |
|
|
trainer._save(output_dir, state_dict=cpu_state_dict) |
|
|
|
|
|
|
|
|
def smart_tokenizer_and_embedding_resize( |
|
|
special_tokens_dict: Dict, |
|
|
tokenizer: transformers.PreTrainedTokenizer, |
|
|
model: transformers.PreTrainedModel, |
|
|
): |
|
|
"""Resize tokenizer and embedding. |
|
|
|
|
|
Note: This is the unoptimized version that may make your embedding size not be divisible by 64. |
|
|
""" |
|
|
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) |
|
|
model.resize_token_embeddings(len(tokenizer)) |
|
|
|
|
|
if num_new_tokens > 0: |
|
|
input_embeddings = model.get_input_embeddings().weight.data |
|
|
output_embeddings = model.get_output_embeddings().weight.data |
|
|
|
|
|
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) |
|
|
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) |
|
|
|
|
|
input_embeddings[-num_new_tokens:] = input_embeddings_avg |
|
|
output_embeddings[-num_new_tokens:] = output_embeddings_avg |
|
|
|
|
|
|
|
|
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: |
|
|
"""Tokenize a list of strings.""" |
|
|
tokenized_list = [ |
|
|
tokenizer( |
|
|
text, |
|
|
return_tensors="pt", |
|
|
padding="longest", |
|
|
max_length=tokenizer.model_max_length, |
|
|
truncation=True, |
|
|
) |
|
|
for text in strings |
|
|
] |
|
|
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] |
|
|
input_ids_lens = labels_lens = [ |
|
|
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list |
|
|
] |
|
|
return dict( |
|
|
input_ids=input_ids, |
|
|
labels=labels, |
|
|
input_ids_lens=input_ids_lens, |
|
|
labels_lens=labels_lens, |
|
|
) |
|
|
|
|
|
def preprocess( |
|
|
sources: Sequence[str], |
|
|
targets: Sequence[str], |
|
|
tokenizer: transformers.PreTrainedTokenizer, |
|
|
) -> Dict: |
|
|
"""Preprocess the data by tokenizing.""" |
|
|
examples = [s + t for s, t in zip(sources, targets)] |
|
|
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] |
|
|
input_ids = examples_tokenized["input_ids"] |
|
|
labels = copy.deepcopy(input_ids) |
|
|
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): |
|
|
label[:source_len] = IGNORE_INDEX |
|
|
return dict(input_ids=input_ids, labels=labels) |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
class DataCollatorForSupervisedDataset:
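    """Collate tokenized examples into padded batches; mode="dynamic" pads to the batch maximum, otherwise to max_length."""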
|
|
tokenizer: transformers.PreTrainedTokenizer |
|
|
max_length: int = field(default=512) |
|
|
mode: str = field(default="fixed") |
|
|
|
|
|
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: |
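        # Convert each instance to tensors, pad/truncate everything to a common length, and build the attention mask.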
|
|
|
|
|
|
|
|
input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances] |
|
|
labels_list = [torch.tensor(x["labels"], dtype=torch.long) for x in instances] |
|
|
|
|
|
|
|
|
if self.mode == "dynamic": |
|
|
|
|
|
|
|
|
batch_max_len = max([len(x) for x in input_ids_list]) |
|
|
target_len = min(batch_max_len, self.max_length) |
|
|
else: |
|
|
|
|
|
target_len = self.max_length |
|
|
|
|
|
|
|
|
def pad_and_truncate(tensors, padding_value): |
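            # Pad to the longest sequence in the batch, then truncate or right-pad to exactly `target_len`.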
|
|
|
|
|
padded = pad_sequence(tensors, batch_first=True, padding_value=padding_value) |
|
|
|
|
|
|
|
|
curr_len = padded.shape[1] |
|
|
if curr_len > target_len: |
|
|
|
|
|
return padded[:, :target_len] |
|
|
elif curr_len < target_len: |
|
|
|
|
|
diff = target_len - curr_len |
|
|
padding = torch.full((padded.shape[0], diff), padding_value, dtype=padded.dtype) |
|
|
return torch.cat([padded, padding], dim=1) |
|
|
else: |
|
|
return padded |
|
|
|
|
|
|
|
|
|
|
|
if self.tokenizer.pad_token_id is None: |
|
|
raise ValueError("Tokenizer.pad_token_id is None. Please set it to eos_token_id or unk_token_id.") |
|
|
|
|
|
input_ids = pad_and_truncate(input_ids_list, self.tokenizer.pad_token_id) |
|
|
labels = pad_and_truncate(labels_list, IGNORE_INDEX) |
|
|
|
|
|
|
|
|
|
|
|
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long() |
|
|
|
|
|
return { |
|
|
"input_ids": input_ids, |
|
|
"labels": labels, |
|
|
"attention_mask": attention_mask |
|
|
} |
|
|
|
|
|
def train_tokenize_function(examples, tokenizer, query, response): |
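    """Wrap each query in the instruction prompt, append EOS to each response, and tokenize with prompt masking."""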
|
|
sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]] |
|
|
targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]] |
|
|
data_dict = preprocess(sources, targets, tokenizer) |
|
|
return data_dict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_worker_init_fn(worker_id):
    """Per-worker setup: limit math-library threads, reseed numpy, and pin the worker to a CPU range."""
    # Keep each DataLoader worker single-threaded so workers do not oversubscribe the CPU.
    torch.set_num_threads(1)
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")
    os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")

    # Give numpy (if present) a distinct per-worker seed derived from the torch worker seed.
    try:
        import numpy as _np
        _np.random.seed(torch.initial_seed() % 2**32)
    except Exception:
        pass

    # Best-effort CPU affinity: split the cores into chunks (assuming at most 64 workers) and pin
    # this worker to its chunk. sched_setaffinity is Linux-only, so any failure is ignored.
    try:
        cpu_count = os.cpu_count() or 1
        chunk = max(1, cpu_count // max(1, min(64, cpu_count)))
        start = (worker_id * chunk) % cpu_count
        mask = set(range(start, min(start + chunk, cpu_count)))
        os.sched_setaffinity(0, mask)
    except Exception:
        pass
|
|
|
|
|
def set_seed(seed: int): |
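    """Seed torch (CPU and all CUDA devices) and transformers for reproducibility."""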
|
|
|
|
|
|
|
|
torch.manual_seed(seed) |
|
|
torch.cuda.manual_seed_all(seed) |
|
|
transformers.set_seed(seed) |
|
|
|
|
|
|
|
|
@pyrallis.wrap() |
|
|
def main(mainCfg: MainConfig): |
|
|
|
|
|
|
|
|
print('='*120) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
set_seed(mainCfg.seed) |
|
|
training_args = convert_to_trainer_args(mainCfg) |
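    # Number this run by counting the existing runs in the W&B project; fall back to 1 if the API call fails.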
|
|
|
|
|
|
|
|
ENTITY = "nvan-13-korea-university" |
|
|
PROJECT = os.environ.get("WANDB_PROJECT") |
|
|
api = wandb.Api() |
|
|
try: |
|
|
runs_list = api.runs(f"{ENTITY}/{PROJECT}") |
|
|
next_run_num = len(runs_list) + 1 |
|
|
except Exception as e: |
|
|
next_run_num = 1 |
|
|
|
|
|
training_args.run_name = f'[{next_run_num}]lr={mainCfg.trainer_args.learning_rate:.1e},b={mainCfg.trainer_args.per_device_train_batch_size},'\ |
|
|
f'n={mainCfg.rotation_adapter_config.num_rotations},r={mainCfg.rotation_adapter_config.r},'\ |
|
|
f'init={mainCfg.run_text}' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(mainCfg.model.model_name, |
|
|
device_map="auto", low_cpu_mem_usage=True, |
|
|
dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, |
|
|
|
|
|
) |
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
print("DEVICE", model.device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
total_params_now = sum(p.numel() for p in model.parameters()) |
|
|
print(f'#params of the pretrained model, {total_params_now:,}') |
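    # Attach an adapter: resume from a saved PEFT checkpoint when adapter_path is set; otherwise build a fresh
    # LoCo / BOFT / HRA / OFT adapter according to mainCfg.run_text; with no adapter rank, fall back to full fine-tuning.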
|
|
|
|
|
if mainCfg.model.adapter_path is not None: |
|
|
print('___ Loading from: ', mainCfg.model.adapter_path) |
|
|
        model = PeftModel.from_pretrained(model, mainCfg.model.adapter_path, is_trainable=True)
|
|
elif mainCfg.rotation_adapter_config.r is not None: |
|
|
import peft |
|
|
if mainCfg.run_text == 'loco': |
|
|
rotation_adapter_config = asdict(mainCfg.rotation_adapter_config) |
|
|
|
|
|
for adapter_name in mainCfg.data.adapter_names: |
|
|
rotation_config = RotationConfig(**rotation_adapter_config) |
|
|
model = get_peft_model(model, rotation_config, adapter_name=adapter_name) |
|
|
print('loaded a LoCo model, batch = ', training_args.per_device_train_batch_size) |
|
|
elif mainCfg.run_text == 'boft': |
|
|
from peft import BOFTConfig |
|
|
boft_config = BOFTConfig( |
|
|
boft_block_size=mainCfg.rotation_adapter_config.r, |
|
|
boft_n_butterfly_factor=2*mainCfg.rotation_adapter_config.num_rotations, |
|
|
target_modules=["q_proj", "v_proj",], |
|
|
boft_dropout=0.05, |
|
|
bias="none", |
|
|
|
|
|
) |
|
|
|
|
|
for adapter_name in mainCfg.data.adapter_names: |
|
|
model = peft.get_peft_model(model, boft_config, adapter_name=adapter_name) |
|
|
print('loaded a BOFT model, batch = ', training_args.per_device_train_batch_size) |
|
|
elif mainCfg.run_text == 'hra': |
|
|
from peft import HRAConfig |
|
|
hra_config = HRAConfig( |
|
|
r=2*mainCfg.rotation_adapter_config.r, |
|
|
target_modules=["q_proj", "v_proj",], |
|
|
init_weights=True, |
|
|
|
|
|
) |
|
|
|
|
|
for adapter_name in mainCfg.data.adapter_names: |
|
|
model = peft.get_peft_model(model, hra_config, adapter_name=adapter_name) |
|
|
            print('loaded an HRA model, batch = ', training_args.per_device_train_batch_size)
|
|
elif mainCfg.run_text == 'oft': |
|
|
from peft import HRAConfig, OFTConfig |
|
|
|
|
|
oft_config = OFTConfig( |
|
|
|
|
|
oft_block_size=4*mainCfg.rotation_adapter_config.r, |
|
|
use_cayley_neumann=True, |
|
|
target_modules=["q_proj", "v_proj",], |
|
|
module_dropout=0.05, |
|
|
|
|
|
bias="none", |
|
|
) |
|
|
|
|
|
for adapter_name in mainCfg.data.adapter_names: |
|
|
model = peft.get_peft_model(model, oft_config, adapter_name=adapter_name) |
|
|
            print('loaded an OFT model, batch = ', training_args.per_device_train_batch_size)
|
|
else: |
|
|
            raise KeyError(f"Unknown run_text '{mainCfg.run_text}'; expected one of: loco, boft, hra, oft")
|
|
|
|
|
else: |
|
|
print("Full Parameter Fine-Tuning") |
|
|
model = model.to(DEVICE) |
|
|
|
|
|
|
|
|
    # Full fine-tuning leaves a plain PreTrainedModel, which has no print_trainable_parameters().
    if hasattr(model, "print_trainable_parameters"):
        model.print_trainable_parameters()

    # Only the adapter parameters require gradients; gather them for the optimizer.
    rotation_layers = filter(
|
|
lambda p: p.requires_grad, model.parameters() |
|
|
) |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
|
mainCfg.model.model_name, |
|
|
model_max_length=mainCfg.model.model_max_seq_length, |
|
|
padding_side="right", |
|
|
use_fast=True, |
|
|
) |
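    # Ensure the tokenizer has a pad token: prefer UNK (keeps EOS meaningful as a stop token), otherwise reuse EOS.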
|
|
|
|
|
if tokenizer.pad_token is None: |
|
|
if tokenizer.unk_token_id is not None: |
|
|
tokenizer.pad_token_id = tokenizer.unk_token_id |
|
|
tokenizer.pad_token = tokenizer.unk_token |
|
|
print("Set PAD token to UNK token.") |
|
|
elif tokenizer.eos_token_id is not None: |
|
|
tokenizer.pad_token_id = tokenizer.eos_token_id |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
print("Set PAD token to EOS token.") |
|
|
|
|
|
if model is not None: |
|
|
model.config.pad_token_id = tokenizer.pad_token_id |
|
|
if model.config.pad_token_id != tokenizer.pad_token_id: |
|
|
raise ValueError("Failed to sync pad_token_id between tokenizer and model config") |
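    # Load the JSON instruction dataset and tokenize it once up front (results are cached by `datasets`).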
|
|
|
|
|
|
|
|
raw_datasets = load_dataset("json", data_files=mainCfg.data.path, split=mainCfg.data.dataset_split) |
|
|
|
|
|
train_dataset = raw_datasets.map( |
|
|
train_tokenize_function, |
|
|
batched=True, |
|
|
batch_size=30000, |
|
|
num_proc=32, |
|
|
remove_columns=raw_datasets.column_names, |
|
|
load_from_cache_file=True, |
|
|
desc="Running tokenizer on train dataset", |
|
|
fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0], |
|
|
"response": mainCfg.data.dataset_field[1]} |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print('- dataset size: ', len(train_dataset)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=mainCfg.model.model_max_seq_length, |
|
|
|
|
|
) |
|
|
data_module = dict(train_dataset=train_dataset, data_collator=data_collator) |
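    # AdamW over the trainable (adapter) parameters only; the learning-rate scheduler is left to the Trainer.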
|
|
|
|
|
optimizer = optim.AdamW( |
|
|
rotation_layers, |
|
|
lr=mainCfg.trainer_args.learning_rate, |
|
|
eps=1e-8 |
|
|
) |
|
|
|
|
|
start_time = datetime.now() |
|
|
print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S")) |
|
|
|
|
|
monitor = ExperimentMonitorCallback( |
|
|
log_file_path="./training_metrics_bs8.json", |
|
|
run_name="Experiment_BatchSize_8", |
|
|
log_interval=10 |
|
|
) |
|
|
training_args.remove_unused_columns = False |
|
|
    training_args.torch_compile = False
|
|
trainer = MyTrainer(model=model, processing_class=tokenizer, |
|
|
lamda=mainCfg.model.lambda_reg, |
|
|
optimizers=(optimizer, None), |
|
|
args=training_args, **data_module, |
|
|
callbacks=[monitor], |
|
|
) |
|
|
model.config.use_cache = False |
|
|
|
|
|
trainer.train() |
|
|
|
|
|
end_time = datetime.now() |
|
|
print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'ft')) |
|
|
|
|
|
trainer.save_state() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.save_pretrained(os.path.join(training_args.output_dir, 'ft2')) |
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
class MyTrainer(Trainer): |
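    """Trainer subclass that carries a regularization weight (`lamda`) and builds a custom training DataLoader."""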
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model: Union[PreTrainedModel, nn.Module] = None, |
|
|
args: TrainingArguments = None, |
|
|
data_collator: Optional[DataCollator] = None, |
|
|
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, |
|
|
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None, |
|
|
processing_class: Optional[PreTrainedTokenizerBase] = None, |
|
|
model_init: Optional[Callable[[], PreTrainedModel]] = None, |
|
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, |
|
|
callbacks: Optional[List[TrainerCallback]] = None, |
|
|
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), |
|
|
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, |
|
|
|
|
|
|
|
|
|
|
|
lamda: float = 1e-4 |
|
|
): |
|
|
super().__init__(model=model, args=args, data_collator=data_collator, |
|
|
train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, |
|
|
model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks, |
|
|
optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics, |
|
|
|
|
|
) |
|
|
self.lamda = lamda |
|
|
|
|
|
|
|
|
def get_train_dataloader(self): |
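        """Build the training DataLoader directly so worker count, pin_memory, and prefetching can be tuned."""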
|
|
|
|
|
train_dataset = self.train_dataset |
|
|
sampler = self._get_train_sampler() |
|
|
|
|
|
|
|
|
batch_size = self.args.train_batch_size if hasattr(self.args, "train_batch_size") else self.args.per_device_train_batch_size |
|
|
|
|
|
|
|
|
        num_workers = getattr(self.args, "dataloader_num_workers", 16)
        pin_memory = getattr(self.args, "dataloader_pin_memory", True)
        prefetch_factor = getattr(self.args, "dataloader_prefetch_factor", 2)
        persistent_workers = getattr(self.args, "dataloader_persistent_workers", True)

        # DataLoader rejects prefetch_factor / persistent_workers when no worker processes are used.
        if num_workers == 0:
            prefetch_factor = None
            persistent_workers = False
|
|
|
|
|
return DataLoader( |
|
|
train_dataset, |
|
|
batch_size=batch_size, |
|
|
sampler=sampler, |
|
|
collate_fn=self.data_collator, |
|
|
drop_last=self.args.dataloader_drop_last if hasattr(self.args, "dataloader_drop_last") else False, |
|
|
num_workers=num_workers, |
|
|
pin_memory=pin_memory, |
|
|
persistent_workers=persistent_workers, |
|
|
prefetch_factor=prefetch_factor, |
|
|
worker_init_fn=default_worker_init_fn, |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |