import re
import json
import openai
import time
import sys
import tiktoken
from random import sample
input_data = sys.argv[1]
openai_modelid = sys.argv[2]
openai.api_key = sys.argv[3]
output_path = sys.argv[4]
prompt_path = sys.argv[5]
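# Expected CLI arguments (inferred from the argv reads above; file names below are hypothetical):
#   python <this_script>.py <input_data.json> <openai_model_id> <openai_api_key> <output.json> <prompts.json>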
encoding = tiktoken.encoding_for_model(openai_modelid)
q_pre = ""
qa_link = ""
MaxLen = 2048
TarLen = 512
TaskTarLen = {
    "chatting_dialogsum": MaxLen,
    "chatting_alpacagpt4": MaxLen,
    "writing_topiocqa": TarLen // 2,
    "writing_dialogsum": TarLen,
    "retrieval_dialogsum": 32,
    "retrieval_topiocqa": 32,
}
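# TaskTarLen maps each task type to the max_tokens budget passed to the chat API in
# gen_model_output: chatting tasks get the full window, writing tasks a mid-size budget,
# and retrieval only 32 tokens (just enough for a few option indices).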
prompts = json.load(open(prompt_path, "r"))
def normalize_model_outputs(model_text):
    # Extract quoted strings and bare integers, strip quotes, and collapse whitespace.
    extracted_elements = [re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", ""))
                          for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)]
    model_outputs = []
    ti = 0
    # Scan for (topic, summary, start, end) key/value runs in the extracted sequence.
    while ti + 7 < len(extracted_elements):
        if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "summary" \
                and extracted_elements[ti + 4] == "start" and extracted_elements[ti + 6] == "end":
            try:
                model_outputs.append({"topic": extracted_elements[ti + 1], "summary": extracted_elements[ti + 3],
                                      "start": int(extracted_elements[ti + 5]), "end": int(extracted_elements[ti + 7])})
            except ValueError:
                pass
        ti += 1
    return model_outputs
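# Illustration of the text this parser accepts (a hypothetical model reply, not taken from the dataset):
#   normalize_model_outputs('[{"topic": "travel plans", "summary": "User books a trip.", "start": 1, "end": 3}]')
#   -> [{"topic": "travel plans", "summary": "User books a trip.", "start": 1, "end": 3}]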
def normalize_chatting_outputs(model_outputs):
    # Collapse repeated whitespace within each line while preserving line breaks.
    def white_space_fix(text):
        return "\n".join(" ".join(line.split()) for line in text.split("\n"))
    return white_space_fix(model_outputs)
def gen_model_output(input_qs, task_type):
    # Estimate a words-per-token ratio and truncate the prompt from the left so that
    # the prompt plus the generation budget fits within MaxLen tokens.
    input_qs_token_l = len(encoding.encode(input_qs))  # token count
    input_qs_word_l = len(input_qs.split(" "))  # word count
    qs_w_t_ratio = input_qs_word_l / input_qs_token_l
    max_word_num = int((MaxLen - TarLen) * qs_w_t_ratio)
    input_qs = " ".join(input_qs.split(" ")[-max_word_num:])
    target_len = TaskTarLen[task_type]
    messages = [{"role": "system", "content": input_qs}]
    chat = None
    # Retry up to 5 times on transient API errors before giving up.
    for _ in range(5):
        try:
            chat = openai.ChatCompletion.create(
                model=openai_modelid, messages=messages, max_tokens=target_len, temperature=0.2
            )
            break
        except Exception:
            time.sleep(5)
    if chat is None:
        raise RuntimeError("OpenAI API call failed after 5 retries")
    return chat.choices[0].message.content
def run_summary(history, memo, bot_thinking):
    # Summarize overflowing recent dialogs into topic-keyed memos, then shrink the recent context.
    system_instruction = prompts["writing_dialogsum"]["system"]
    task_instruction = prompts["writing_dialogsum"]["instruction"]
    history_log = "\n\n```\nTask Conversation:\n" + "\n".join(
        ["(line {}) {}".format(h_i + 1, h.replace("\n", " ")) for h_i, h in enumerate(history["Recent Dialogs"][2:])])
    qs = q_pre + system_instruction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + history_log + "\n```" \
        + task_instruction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + qa_link
    sum_history = gen_model_output(qs, "writing_dialogsum")
    sum_history = normalize_model_outputs(sum_history)
    # Record each summarized topic together with the dialog lines it covers (1-based, inclusive).
    for s in sum_history:
        memo[s["topic"]] = memo.get(s["topic"], []) + [
            {"summary": s["summary"], "dialogs": history["Recent Dialogs"][2:][(s["start"] - 1):s["end"]]}]
    # Fallback when no summary could be parsed: file the dialogs under the reserved "NOTO" topic.
    if len(sum_history) == 0:
        si_0, si_1 = sample(list(range(len(history["Recent Dialogs"][2:]))), 2)
        memo["NOTO"].append({"summary": "Partial dialogs about: {} or {}.".format(
            history["Recent Dialogs"][2:][si_0], history["Recent Dialogs"][2:][si_1]),
            "dialogs": history["Recent Dialogs"][2:]})
    # Keep only the last two utterances as the new recent context.
    history["Recent Dialogs"] = history["Recent Dialogs"][-2:]
    bot_thinking["summarization"] = {"input": qs, "output": sum_history}
    return history, memo, bot_thinking
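# Memo layout produced by run_summary (derived from the code above):
#   memo[topic] -> list of {"summary": str, "dialogs": [utterance, ...]}
# The reserved "NOTO" entry collects recent dialogs that could not be summarized.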
def run_retrieval(history, memo, bot_thinking):
    # Ask the model which memorized topics are relevant to the current user input.
    topics = []
    for k, v in memo.items():
        for vv in v:
            topics.append((k, vv["summary"], vv["dialogs"]))
    system_instruction = prompts["retrieval"]["system"]
    task_instruction = prompts["retrieval"]["instruction"]
    # history["User Input"][6:] drops the leading "user: " prefix from the query.
    task_case = "```\nQuery Sentence:\n" + history["User Input"][6:] + "\nTopic Options:\n" + \
        "\n".join(["({}) {}".format(v_i + 1, v[0] + ". " + v[1]) for v_i, v in enumerate(topics)]) + "\n```"
    qs = q_pre + system_instruction.replace("OPTION", str(len(topics))) + task_case \
        + task_instruction.replace("OPTION", str(len(topics))) + qa_link
    outputs = gen_model_output(qs, "retrieval_dialogsum")
    # The model is expected to answer with "#"-separated 1-based option indices.
    outputs = outputs.split("#")
    chosen_topics = []
    for output in outputs:
        try:
            index_ = int(output) - 1
        except ValueError:
            continue
        if index_ < len(topics) and "NOTO" not in topics[index_]:
            chosen_topics.append(topics[index_])
    if len(chosen_topics) > 0:
        history["Related Topics"] = [ct[0] for ct in chosen_topics]
        history["Related Summaries"] = [ct[1] for ct in chosen_topics]
        history["Related Dialogs"] = [" ### ".join(ct[2]) for ct in chosen_topics]
    else:
        history["Related Topics"] = []
        history["Related Summaries"] = []
        history["Related Dialogs"] = []
    bot_thinking["retrieval"] = {"input": qs, "output": outputs}
    return history, bot_thinking
def run_eval():
    data = json.load(open(input_data, "r"))
    output_data = []
    for d in data:
        print("=" * 20 + "start of question {}".format(d["id"]) + "=" * 20)
        new_d = d
        history = {
            "Recent Dialogs": ["user: Hi!", "bot: Hi! How can I help you today?"],
            "Related Topics": [],
            "Related Summaries": [],
            "Related Dialogs": [],
            "User Input": "",
        }
        memo = {
            "NOTO": [{"summary": "None of the others.", "dialogs": []}]
        }
        for l_i in range(len(new_d["conversations"])):
            # Conversations alternate user/bot; odd indices are bot turns to regenerate.
            if l_i % 2 == 1:
                bot_thinking = {"retrieval": "", "summarization": ""}
                print("=" * 20 + "start of turn {}".format(l_i // 2 + 1) + "=" * 20)
                user = "user: " + new_d["conversations"][l_i - 1]["value"]
                print(user + "\n\n")
                # Create summaries if the recent dialogs exceed the length or turn thresholds.
                if len(" ### ".join(history["Recent Dialogs"]).split(" ")) > (MaxLen // 2) or len(history["Recent Dialogs"]) >= 10:
                    history, memo, bot_thinking = run_summary(history, memo, bot_thinking)
                # Retrieve the most related topics for every new user input.
                history["User Input"] = user
                if len(memo.keys()) > 1:
                    history, bot_thinking = run_retrieval(history, memo, bot_thinking)
                # Generate the bot response grounded on retrieved evidence and recent dialogs.
                system_instruction = prompts["chatting"]["system"]
                task_instruction = prompts["chatting"]["instruction"]
                task_case = "```\nRelated Evidences:\n" + "\n".join(["({}) {}".format(r_tsd_i + 1, {
                    "Related Topics": history["Related Topics"][r_tsd_i],
                    "Related Summaries": history["Related Summaries"][r_tsd_i],
                    "Related Dialogs": history["Related Dialogs"][r_tsd_i]
                }) for r_tsd_i in range(len(history["Related Topics"]))]) + "\n\nRecent Dialogs:\n" + \
                    " ### ".join([hrd.replace("\n", " ") for hrd in history["Recent Dialogs"]]) + "\n```\n\nUser Input:\n" + history["User Input"] + " ### bot: "
                qs = q_pre + system_instruction + task_case + task_instruction + qa_link
                outputs = gen_model_output(qs, "chatting_dialogsum")
                outputs = normalize_chatting_outputs(outputs)
                history["Recent Dialogs"] += [user, "bot: " + outputs]
                print("bot: " + outputs + "\n")
                print("=" * 20 + "end of turn {}".format(l_i // 2 + 1) + "=" * 20)
                # Store the intermediate reasoning and the generated response on the bot turn.
                new_d["conversations"][l_i]["thinking"] = json.dumps(bot_thinking)
                new_d["conversations"][l_i]["value"] = outputs
        output_data.append(new_d)
    json.dump(output_data, open(output_path, "w"), indent=2)


if __name__ == "__main__":
    run_eval()