"""Merge a trained rpeft/PEFT adapter into its base causal-LM and save the result.

Entry point (driven by a pyrallis ``MainConfig``):
  1. load the base model named in ``cfg.model.model_name``,
  2. attach the adapter from ``cfg.model.merge_adapter_path`` via ``PeftModel``,
  3. fold the adapter weights into the base weights (``merge_and_unload``),
  4. write the merged model and tokenizer to ``cfg.model.merge_output_path``.
"""

# --- standard library ---------------------------------------------------------
import argparse
import copy
import os
import time
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Literal, Sequence, Tuple

# --- third-party --------------------------------------------------------------
import pyrallis
import torch
import transformers
import yaml
from datasets import load_dataset
from omegaconf import OmegaConf
from peft import LoraConfig, get_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset, IterableDataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer
from transformers.data.data_collator import DataCollator
# NOTE(review): wildcard import kept from the original file (removing it could
# drop names other code relies on); avoid `import *` in new code.
from transformers.modeling_utils import *  # noqa: F401,F403
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import _is_peft_model
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments

# --- local --------------------------------------------------------------------
from rpeft import PeftModel, get_peft_model
from rpeft.rotation import RotationConfig, RotationTuner

from .config import MainConfig, convert_to_trainer_args

# Constants inherited from the companion training script; unused by this
# merge-only entry point but kept for consistency across the package.
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
# NOTE(review): the three values below were stripped to "" in the original
# source (an HTML-tag-stripping artifact); restored to the standard
# Alpaca/LLaMA special tokens — confirm against the training script.
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Alpaca-style single-turn prompt template (``{instruction}`` placeholder).
PROMPT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)


@pyrallis.wrap()
def main(mainCfg: MainConfig):
    """Merge the configured adapter into its base model and persist both
    model and tokenizer to ``mainCfg.model.merge_output_path``.

    Args:
        mainCfg: project configuration; only ``model.model_name``,
            ``model.merge_adapter_path`` and ``model.merge_output_path``
            are read here.
    """
    print('=' * 120)

    model_name = mainCfg.model.model_name
    adapter = mainCfg.model.merge_adapter_path
    output_path = mainCfg.model.merge_output_path

    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    # Tokenizers are CPU objects: `from_pretrained` takes no `device_map`
    # (the original passed one; it was silently ignored).
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Attach the trained adapter, then fold its weights into the base model
    # so the saved checkpoint needs no PEFT wrapper at load time.
    model = PeftModel.from_pretrained(model, adapter)
    model = model.merge_and_unload()

    # `safe_serialization=False` kept from the original: emits .bin shards
    # rather than safetensors — presumably for a downstream loader; verify.
    model.save_pretrained(output_path, safe_serialization=False)
    tokenizer.save_pretrained(output_path)

    print('merge.py ends', adapter, output_path)


if __name__ == "__main__":
    main()