import torch
# import wandb
import os
import yaml
from peft import LoraConfig, get_peft_model_state_dict
from torch.utils.data import DataLoader
import time
from typing import List, Tuple
# import prodigyopt
###
import copy
from dataclasses import field, dataclass, asdict
from typing import Sequence, Literal, Dict
import transformers
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from transformers import Trainer
from transformers.modeling_utils import *
from transformers.trainer import _is_peft_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.data.data_collator import DataCollator
from transformers.training_args import TrainingArguments
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from torch.utils.data import Dataset, IterableDataset
from datasets import load_dataset
##
#from ..pipeline.flux_omini import transformer_forward, encode_images
# from ...omini.rotation import RotationTuner, RotationConfig
from rpeft.rotation import RotationTuner, RotationConfig
from rpeft import get_peft_model, PeftModel
from .config import MainConfig, convert_to_trainer_args
import pyrallis
from omegaconf import OmegaConf
import argparse
# Label value ignored by the cross-entropy loss when masking prompt tokens.
IGNORE_INDEX = -100
# Fallback special tokens for tokenizers that lack them (LLaMA/Alpaca convention).
# NOTE(review): the original set BOS and UNK to the EOS string "</s>"; the three
# special tokens must be distinct, so BOS/UNK are restored to their conventional
# values here.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
# Alpaca-style instruction template; fill with PROMPT.format(instruction=...).
PROMPT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)
# parser = argparse.ArgumentParser(description='Merge Adapter to Base Model')
# parser.add_argument('--base_mode', type=str)
# parser.add_argument('--adapter_path', type=str)
# parser.add_argument('--output_path', type=str)
# args = parser.parse_args()
@pyrallis.wrap()
def main(mainCfg: MainConfig):
    """Merge a trained PEFT adapter into its base causal LM and save the result.

    Loads the base model named by ``mainCfg.model.model_name``, applies the
    adapter from ``mainCfg.model.merge_adapter_path`` via ``PeftModel``, folds
    the adapter weights into the base weights with ``merge_and_unload``, and
    writes the merged model and tokenizer to ``mainCfg.model.merge_output_path``.

    Args:
        mainCfg: Parsed run configuration (populated by ``pyrallis.wrap`` from
            the CLI / config file).
    """
    print('=' * 120)
    model_name = mainCfg.model.model_name
    adapter = mainCfg.model.merge_adapter_path
    output_path = mainCfg.model.merge_output_path

    # device_map="auto" lets accelerate place model shards on available devices.
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    # Tokenizers are CPU-only objects; the original passed device_map='auto'
    # here, which AutoTokenizer.from_pretrained does not accept as a parameter.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = PeftModel.from_pretrained(model, adapter)
    # Fold the adapter weights into the base model and strip the PEFT wrappers.
    model = model.merge_and_unload()

    # safe_serialization=False keeps the legacy pytorch_model.bin format —
    # presumably required by downstream tooling; confirm before switching to
    # safetensors.
    model.save_pretrained(output_path, safe_serialization=False)
    tokenizer.save_pretrained(output_path)
    print('merge.py ends', adapter, output_path)
if __name__ == "__main__":
    # pyrallis.wrap() on main() parses CLI/config into MainConfig before the call.
    main()