| import argparse |
| from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, AutoModelForSeq2SeqLM |
| from optimum.bettertransformer import BetterTransformer |
| import torch |
| import os |
| import json |
| import re |
| import ray |
| import warnings |
| from random import sample |
| warnings.filterwarnings("ignore") |
|
|
# Text placed before / after every prompt handed to the model.
q_pre = "<s>\n"
qa_link = "\n"

# Token budgets: MaxLen caps a whole sequence; TarLen is the room held back
# for generated text when a decoder-only prompt is truncated.
MaxLen = 2048
TarLen = 512

# Per-task generation budget in tokens, keyed by task type.
TaskTarLen = dict(
    chatting_dialogsum=MaxLen,
    chatting_alpacagpt4=MaxLen,
    writing_topiocqa=TarLen // 2,
    writing_dialogsum=TarLen,
    retrieval_dialogsum=32,
    retrieval_topiocqa=32,
)
|
|
def get_gpu_memory(num_gpus):
    """Estimate the memory still available on each of the first GPUs.

    For every device index in ``range(num_gpus)`` the estimate is the
    device's total memory minus what ``torch.cuda.memory_allocated``
    reports for it, expressed in GiB.

    Args:
        num_gpus: Number of CUDA devices to inspect, starting at index 0.

    Returns:
        A list of floats (GiB), one per inspected device, in index order.
    """
    bytes_per_gib = 1024**3
    free_per_gpu = []
    for gpu_index in range(num_gpus):
        # Make the GPU current so the un-parameterised queries target it.
        with torch.cuda.device(gpu_index):
            active = torch.cuda.current_device()
            total_gib = torch.cuda.get_device_properties(active).total_memory / bytes_per_gib
            used_gib = torch.cuda.memory_allocated() / bytes_per_gib
        free_per_gpu.append(total_gib - used_gib)
    return free_per_gpu
|
|
def normalize_model_outputs(model_text):
    """Post-process the raw text of the memo-writing task into records.

    The text is scanned for quoted strings (single or double quotes) and
    bare integers. Quotes are stripped and whitespace runs collapsed.
    Every run of eight consecutive tokens of the form
    ``topic, <v>, summary, <v>, start, <int>, end, <int>`` becomes one record.

    Args:
        model_text: Raw string produced by the model.

    Returns:
        A list of dicts with keys ``"topic"`` (str), ``"summary"`` (str),
        ``"start"`` (int) and ``"end"`` (int), in order of appearance.
        Candidates whose start/end values are not integers are skipped.
    """
    tokens = [
        re.sub(r"\s+", " ", raw.replace('"', "").replace("'", ""))
        for raw in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)
    ]
    expected_keys = ["topic", "summary", "start", "end"]
    records = []
    # A record spans eight tokens, so the last valid start is len(tokens) - 8.
    for pos in range(len(tokens) - 7):
        window = tokens[pos:pos + 8]
        if window[0::2] != expected_keys:
            continue
        try:
            start, end = int(window[5]), int(window[7])
        except ValueError:
            # start/end came back as a quoted non-number; drop this candidate.
            continue
        records.append({"topic": window[1], "summary": window[3], "start": start, "end": end})
    return records
|
|
def normalize_chatting_outputs(model_outputs):
    """Tidy the whitespace of a chat response while keeping its line breaks.

    Each ``"\\n"``-separated line has its whitespace runs collapsed to a
    single space and its leading/trailing whitespace removed.

    Args:
        model_outputs: Response text from the model.

    Returns:
        The cleaned text, with the same number of ``"\\n"`` separators.
    """
    return "\n".join(" ".join(row.split()) for row in model_outputs.split("\n"))
|
|
def gen_model_output(model_path, model, tokenizer, input_qs, local_check, task_type):
    """Generate the model's text answer for one prompt.

    Args:
        model_path: Model path/name; a ``"t5"`` substring selects the
            seq2seq handling, anything else the decoder-only handling.
        model: Loaded model exposing ``generate`` (unused when
            ``local_check`` is true).
        tokenizer: Tokenizer matching ``model``.
        input_qs: Full prompt string.
        local_check: When true, skip the model and return Faker text.
        task_type: Key of ``TaskTarLen``. A ``"writing"`` substring adds
            ``"]"`` as a stop token; a ``"retrieval"`` substring adds
            newline and space.

    Returns:
        The decoded, stripped generation (prompt tokens removed for
        decoder-only models), or fake text in local-check mode.
    """
    if local_check:
        from faker import Faker
        return Faker(locale="en").text(2000)

    def first_token_id(text):
        # Id of the first token the tokenizer yields for ``text``.
        return tokenizer.encode(text, add_special_tokens=False)[0]

    stop_ids = [tokenizer.eos_token_id]
    if "writing" in task_type:
        stop_ids.append(first_token_id("]"))
    elif "retrieval" in task_type:
        stop_ids.append(first_token_id("\n"))
        stop_ids.append(first_token_id(" "))

    is_t5 = "t5" in model_path
    if is_t5:
        # Seq2seq: the prompt may fill MaxLen; max_length bounds only the output.
        prompt_ids = tokenizer([input_qs], max_length=MaxLen, truncation=True, add_special_tokens=False).input_ids
        length_cap = TaskTarLen[task_type]
    else:
        # Decoder-only: leave TarLen tokens of room; max_length counts the prompt too.
        prompt_ids = tokenizer([input_qs], max_length=(MaxLen - TarLen), truncation=True, add_special_tokens=False).input_ids
        length_cap = min(len(prompt_ids[0]) + TaskTarLen[task_type], MaxLen)

    sequences = model.generate(
        torch.as_tensor(prompt_ids).cuda(),
        do_sample=True,
        temperature=0.2,
        max_length=length_cap,
        eos_token_id=stop_ids,
        repetition_penalty=1.0
    )

    if is_t5:
        answer_ids = sequences[0]
    else:
        # Drop the echoed prompt so only newly generated tokens are decoded.
        answer_ids = sequences[0][len(prompt_ids[0]):]
    return tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
|
|
def run_summary(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts):
    """Summarize the recent dialogs into the memo and shrink the history.

    We assume there's no too long input from user, e.g. over 1500 tokens.

    The first two entries of ``history["Recent Dialogs"]`` are left out of
    the summarization input. The rest are numbered from 1 and sent to the
    model with the ``"writing_dialogsum"`` prompts. Every parsed record is
    stored under its topic in ``memo``. If nothing could be parsed, the
    dialogs are filed under the ``"NOTO"`` topic with a summary quoting
    two of them. Afterwards only the last two recent dialogs are kept.

    Args:
        history: Conversation state dict; ``"Recent Dialogs"`` is read and
            replaced.
        model_path, model, tokenizer: Forwarded to ``gen_model_output``.
        memo: Dict mapping topic -> list of ``{"summary", "dialogs"}``
            dicts; updated in place.
        local_check: When true, a fake test topic is also added.
        bot_thinking: Dict receiving the summarization input/output under
            ``"summarization"``.
        prompts: Prompt templates; ``prompts["writing_dialogsum"]`` must
            provide ``"system"`` and ``"instruction"`` containing ``LINE``.

    Returns:
        The tuple ``(history, memo, bot_thinking)``.
    """
    dialogs = history["Recent Dialogs"][2:]
    line_total = str(len(history["Recent Dialogs"]) - 2)

    system_instruction = prompts["writing_dialogsum"]["system"]
    task_instruction = prompts["writing_dialogsum"]["instruction"]
    history_log = "\n\n```\nTask Conversation:\n" + "\n".join(
        "(line {}) {}".format(number, text.replace("\n", " "))
        for number, text in enumerate(dialogs, start=1)
    )
    qs = (
        q_pre
        + system_instruction.replace("LINE", line_total)
        + history_log
        + "\n```"
        + task_instruction.replace("LINE", line_total)
        + qa_link
    )
    print("#" * 20 + "summarizing" + "#" * 20)
    print(qs)
    print("#" * 20 + "summarizing" + "#" * 20)
    sum_history = gen_model_output(model_path, model, tokenizer, qs, local_check, "writing_dialogsum")
    sum_history = normalize_model_outputs(sum_history)
    print("#" * 20 + "summarization" + "#" * 20)
    print(sum_history)
    print("#" * 20 + "summarization" + "#" * 20)

    for record in sum_history:
        # Line numbers are 1-based. Clamp so a start of 0 (or less) cannot
        # become a negative slice index that wraps to the end of the list.
        first = max(record["start"] - 1, 0)
        memo.setdefault(record["topic"], []).append(
            {"summary": record["summary"], "dialogs": dialogs[first:record["end"]]}
        )

    if local_check:
        topic_count = len(memo.keys())
        memo["test_topic{}".format(topic_count)] = [
            {"summary": "test_summary{}".format(topic_count), "dialogs": dialogs[2:4]}
        ]

    if not sum_history and dialogs:
        # Fallback: keep the raw dialogs under "NOTO". random.sample raises
        # ValueError for a population smaller than 2, so a lone dialog is
        # quoted in both slots instead.
        if len(dialogs) >= 2:
            si_0, si_1 = sample(range(len(dialogs)), 2)
        else:
            si_0 = si_1 = 0
        memo.setdefault("NOTO", []).append({
            "summary": "Partial dialogs about: {} or {}.".format(dialogs[si_0], dialogs[si_1]),
            "dialogs": dialogs,
        })

    history["Recent Dialogs"] = history["Recent Dialogs"][-2:]
    bot_thinking["summarization"] = {"input": qs, "output": sum_history}
    return history, memo, bot_thinking
|
|
def run_retrieval(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts):
    """Ask the model which memo topics relate to the current user input.

    All ``(topic, summary, dialogs)`` entries of ``memo`` are listed as
    numbered options, starting at 1. The model's answer is split on ``"#"``
    and each piece that parses as an in-range option number selects that
    entry. Options whose topic contains ``"NOTO"`` are never selected.

    Args:
        history: Conversation state dict; ``"User Input"`` is read (its
            first six characters are dropped — presumably the ``"user: "``
            prefix) and the three ``"Related ..."`` lists are overwritten.
        model_path, model, tokenizer: Forwarded to ``gen_model_output``.
        memo: Dict mapping topic -> list of ``{"summary", "dialogs"}``.
        local_check: When true, the selection is replaced by a random sample.
        bot_thinking: Dict receiving the retrieval input/output under
            ``"retrieval"``.
        prompts: Prompt templates; ``prompts["retrieval"]`` must provide
            ``"system"`` and ``"instruction"`` containing ``OPTION``.

    Returns:
        The tuple ``(history, bot_thinking)``.
    """
    topics = []
    for topic_name, entries in memo.items():
        for entry in entries:
            topics.append((topic_name, entry["summary"], entry["dialogs"]))

    system_instruction = prompts["retrieval"]["system"]
    task_instruction = prompts["retrieval"]["instruction"]
    option_lines = "\n".join(
        "({}) {}".format(number, option[0] + ". " + option[1])
        for number, option in enumerate(topics, start=1)
    )
    task_case = "```\nQuery Sentence:\n" + history["User Input"][6:] + "\nTopic Options:\n" + option_lines + "\n```"
    option_total = str(len(topics))
    qs = (
        q_pre
        + system_instruction.replace("OPTION", option_total)
        + task_case
        + task_instruction.replace("OPTION", option_total)
        + qa_link
    )
    print("#" * 20 + "retrieving" + "#" * 20)
    print(qs)
    print("#" * 20 + "retrieving" + "#" * 20)
    outputs = gen_model_output(model_path, model, tokenizer, qs, local_check, "retrieval_dialogsum")
    print("#" * 20 + "retrieval" + "#" * 20)
    print(outputs)
    print("#" * 20 + "retrieval" + "#" * 20)
    outputs = outputs.split("#")

    chosen_topics = []
    for output in outputs:
        try:
            index_ = int(output) - 1
        except ValueError:
            # Not a number: ignore this piece of the answer.
            continue
        # The lower bound matters: an answer of 0 or a negative number would
        # otherwise index from the end of the list and pick an unrelated topic.
        if 0 <= index_ < len(topics) and "NOTO" not in topics[index_][0]:
            chosen_topics.append(topics[index_])

    if local_check:
        # max(..., 0) keeps sample() from raising on an empty topic list.
        chosen_topics = sample(topics, max(min(len(topics) - 1, 2), 0))

    history["Related Topics"] = [ct[0] for ct in chosen_topics]
    history["Related Summaries"] = [ct[1] for ct in chosen_topics]
    history["Related Dialogs"] = [" ### ".join(ct[2]) for ct in chosen_topics]

    bot_thinking["retrieval"] = {"input": qs, "output": outputs}
    return history, bot_thinking
|
|
@torch.inference_mode()
def get_model_answers(model_path, num_gpus, local_check, load_in_8bit, ques_jsons, prompts):
    """Replay each conversation and fill in the bot turns with model replies.

    For every sample, a fresh dialog history and memo are created. At each
    odd index of ``sample["conversations"]`` the preceding entry's
    ``"value"`` is taken as the user message. The recent dialogs are
    summarized when they grow too long, related memo topics are retrieved
    once the memo holds more than the default topic, and the chat prompt is
    sent to the model. The reply is written to the entry's ``"value"`` and a
    JSON dump of the intermediate steps to its ``"thinking"``.

    Args:
        model_path: Model path; ``~`` is expanded, and a ``"t5"`` substring
            selects the seq2seq model class.
        num_gpus: Number of GPUs whose memory is used to size ``max_memory``.
        local_check: When true, no model is loaded (``model`` is ``None``).
        load_in_8bit: Forwarded to ``from_pretrained``.
        ques_jsons: Iterable of sample dicts, each with a ``"conversations"``
            list of dicts holding a ``"value"``.
        prompts: Prompt templates; ``prompts["chatting"]`` is read here, and
            the dict is passed on to ``run_summary`` / ``run_retrieval``.

    Returns:
        A list of the same sample dicts, modified in place.
    """
    model_path = os.path.expanduser(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, truncation_side='left')

    model = None
    if not local_check:
        # Let each GPU take up to 85% of its free memory; nothing goes to CPU.
        free_gib = get_gpu_memory(num_gpus)
        gpu_memory_dict = {gpu: str(int(free_gib[gpu] * 0.85)) + "GiB" for gpu in range(num_gpus)}
        gpu_memory_dict["cpu"] = "0GiB"

        loader = AutoModelForSeq2SeqLM if "t5" in model_path else AutoModelForCausalLM
        model = loader.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map="auto", max_memory=gpu_memory_dict, load_in_8bit=load_in_8bit
        )
        model = BetterTransformer.transform(model)
        model = model.eval()

    output_data = []
    for sample_json in ques_jsons:
        history = {
            "Recent Dialogs": ["user: Hi!", "bot: Hi! How can I help you today?"],
            "Related Topics": [],
            "Related Summaries": [],
            "Related Dialogs": [],
            "User Input": "",
        }
        memo = {
            "NOTO": [{"summary": "None of the others.", "dialogs": []}]
        }

        turns = sample_json["conversations"]
        # Odd positions are bot turns; the user message sits just before each.
        for bot_idx in range(1, len(turns), 2):
            turn_label = bot_idx // 2 + 1
            bot_thinking = {"retrieval": "", "summarization": ""}
            print("=" * 20 + "start of turn {}".format(turn_label) + "=" * 20)
            user = "user: " + turns[bot_idx - 1]["value"]
            print(user + "\n\n")

            # Summarize once the recent dialogs exceed half of MaxLen in
            # space-separated words, or reach ten entries.
            recent_words = len(" ### ".join(history["Recent Dialogs"]).split(" "))
            if recent_words > (MaxLen // 2) or len(history["Recent Dialogs"]) >= 10:
                history, memo, bot_thinking = run_summary(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts)

            history["User Input"] = user
            # Retrieval only makes sense once the memo holds a real topic.
            if len(memo.keys()) > 1:
                history, bot_thinking = run_retrieval(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts)

            system_instruction = prompts["chatting"]["system"]
            task_instruction = prompts["chatting"]["instruction"]
            evidence_lines = []
            for ev_idx, ev_topic in enumerate(history["Related Topics"]):
                evidence = {
                    "Related Topics": ev_topic,
                    "Related Summaries": history["Related Summaries"][ev_idx],
                    "Related Dialogs": history["Related Dialogs"][ev_idx]
                }
                evidence_lines.append("({}) {}".format(ev_idx + 1, evidence))
            recent_line = " ### ".join(entry.replace("\n", " ") for entry in history["Recent Dialogs"])
            task_case = (
                "```\nRelated Evidences:\n" + "\n".join(evidence_lines)
                + "\n\nRecent Dialogs:\n" + recent_line
                + "\n```\n\nUser Input:\n" + history["User Input"] + " ### bot: "
            )
            qs = q_pre + system_instruction + task_case + task_instruction + qa_link

            outputs = gen_model_output(model_path, model, tokenizer, qs, local_check, "chatting_dialogsum")
            outputs = normalize_chatting_outputs(outputs)
            history["Recent Dialogs"] += [user, "bot: " + outputs]
            print("bot: " + outputs + "\n")
            print("=" * 20 + "end of turn {}".format(turn_label) + "=" * 20)
            print("\n\n\n\n")
            turns[bot_idx]["thinking"] = json.dumps(bot_thinking)
            turns[bot_idx]["value"] = outputs

        output_data.append(sample_json)
    return output_data
|
|
def run_eval(model_path, num_gpus, local_check, load_in_8bit, question_file, ray_num_gpus, answer_file, prompt_path):
    """Run the evaluation over all questions, sharded across Ray workers.

    The questions are split into contiguous chunks, one Ray task per chunk
    with ``ray_num_gpus`` GPUs each. The answers are gathered in the
    original order and written to ``answer_file`` as indented JSON.

    Args:
        model_path: Forwarded to ``get_model_answers``.
        num_gpus: Total number of GPUs; must be a multiple of ``ray_num_gpus``.
        local_check: Forwarded to ``get_model_answers``.
        load_in_8bit: Forwarded to ``get_model_answers``.
        question_file: Path (``~`` allowed) of the JSON list of samples.
        ray_num_gpus: GPUs reserved per Ray task.
        answer_file: Path (``~`` allowed) of the JSON file to write.
        prompt_path: Path of the JSON file with the prompt templates.

    Raises:
        AssertionError: If ``num_gpus`` is not a multiple of ``ray_num_gpus``.
    """
    assert num_gpus % ray_num_gpus == 0
    with open(prompt_path, "r") as prompt_file:
        prompts = json.load(prompt_file)

    with open(os.path.expanduser(question_file), "r") as ques_file:
        ques_jsons = json.load(ques_file)

    # Ceiling division gives at most one chunk per worker; the max(1, ...)
    # guards keep range() from getting a zero step when there are fewer
    # questions than workers (or none at all).
    num_workers = max(1, num_gpus // ray_num_gpus)
    chunk_size = max(1, -(-len(ques_jsons) // num_workers))

    get_answers_func = ray.remote(num_gpus=ray_num_gpus)(
        get_model_answers
    ).remote
    ans_handles = [
        get_answers_func(
            model_path, ray_num_gpus, local_check, load_in_8bit, ques_jsons[i: i + chunk_size], prompts
        )
        for i in range(0, len(ques_jsons), chunk_size)
    ]

    ans_jsons = []
    for ans_handle in ans_handles:
        ans_jsons.extend(ray.get(ans_handle))

    with open(os.path.expanduser(answer_file), "w") as ans_file:
        json.dump(ans_jsons, ans_file, indent=2)
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model-path", type=str, required=True)
    cli.add_argument("--num-gpus", type=int, required=True)
    cli.add_argument("--ray-num-gpus", type=int, required=True)
    cli.add_argument(
        "--local-check",
        action="store_true",
        help="use faker to generate fake resposne, used for pipeline prompt checking",
    )
    cli.add_argument("--load-in-8bit", action="store_true")
    cli.add_argument("--question-file", type=str, required=True)
    cli.add_argument("--answer-file", type=str, required=True)
    cli.add_argument("--prompt-path", type=str, required=True)
    options = cli.parse_args()

    run_eval(
        model_path=options.model_path,
        num_gpus=options.num_gpus,
        local_check=options.local_check,
        load_in_8bit=options.load_in_8bit,
        question_file=options.question_file,
        ray_num_gpus=options.ray_num_gpus,
        answer_file=options.answer_file,
        prompt_path=options.prompt_path,
    )
|
|