| import json
|
| import os
|
| import torch
|
| import numpy as np
|
| from utils import *
|
| from sentence_transformers import SentenceTransformer
|
| from rapidfuzz import process
|
| from models import *
|
| import copy
|
|
|
| import warnings
|
# Run tensors and the embedding model on the GPU when one is visible to
# torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Local path tried first when loading the embedding model (presumably the
# location of the model inside the Docker image -- confirm with deployment).
docker_model_path = "/app/model/all-MiniLM-L6-v2"

# Silence FutureWarnings whose message mentions `clean_up_tokenization_spaces`.
# The pattern ends with `.*` so it reads as an explicit "contains" match; the
# previous trailing `*` only quantified the final letter "s".
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message=r".*clean_up_tokenization_spaces.*",
)
|
|
|
class CaseRepository:
    """Store of "good" and "bad" example cases with hybrid similarity search.

    The corpus is a JSON document shaped as
    ``{task: {"good": [case, ...], "bad": [case, ...]}}`` where every case is
    ``{"index": {"embed_index": str, "str_index": str}, "content": ...}``.
    ``embed_index`` is compared through sentence embeddings and ``str_index``
    through rapidfuzz string matching; both scores are blended 50/50.
    """

    # Name of the JSON file persisting the corpus; it lives next to this module.
    CORPUS_FILENAME = "case_repository.json"

    def __init__(self):
        """Load the embedding model, the corpus and the corpus embeddings."""
        try:
            self.embedder = SentenceTransformer(docker_model_path)
        except Exception:
            # Any load failure from the local path falls back to the model
            # named in the configuration. `Exception` (instead of a bare
            # `except`) lets KeyboardInterrupt / SystemExit propagate.
            self.embedder = SentenceTransformer(config['model']['embedding_model'])
        self.embedder.to(device)
        self.corpus = self.load_corpus()
        self.embedded_corpus = self.embed_corpus()

    def _corpus_path(self):
        """Return the path of the corpus JSON file next to this module."""
        return os.path.join(os.path.dirname(__file__), self.CORPUS_FILENAME)

    def load_corpus(self):
        """Read the corpus JSON file and return its parsed content."""
        with open(self._corpus_path(), encoding="utf-8") as file:
            return json.load(file)

    def update_corpus(self):
        """Write the in-memory corpus back to the corpus JSON file.

        Best effort: failures are printed, not raised.
        """
        try:
            # Serialise before opening the file: opening with "w" truncates
            # it, so a serialisation error must not happen after that point
            # or the persisted corpus would be lost.
            serialized = json.dumps(self.corpus, indent=2)
            with open(self._corpus_path(), "w", encoding="utf-8") as file:
                file.write(serialized)
        except Exception as e:
            print(f"Error when updating corpus: {e}")

    def embed_corpus(self):
        """Encode the ``embed_index`` of every stored case.

        Returns:
            ``{task: {"good": tensor, "bad": tensor}}`` with one embedding per
            case, in the same order as the cases in ``self.corpus``.
        """
        embedded_corpus = {}
        for key, content in self.corpus.items():
            embedded_corpus[key] = {}
            for case_type in ("good", "bad"):
                texts = [item['index']['embed_index'] for item in content[case_type]]
                embedded_corpus[key][case_type] = self.embedder.encode(
                    texts, convert_to_tensor=True
                ).to(device)
        return embedded_corpus

    def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
        """Rank the stored cases of ``task``/``case_type`` against a query.

        Args:
            task: Key of the task in the corpus.
            embed_index: Query text compared through sentence embeddings.
            str_index: Query text compared through fuzzy string matching.
            case_type: ``"good"`` or ``"bad"``.
            top_k: Maximum number of matches to return.

        Returns:
            ``(scores, indices, original_scores, original_indices)``.
            ``scores``/``indices`` rank by the blend of min-max normalised
            scores; ``original_scores``/``original_indices`` rank by the blend
            of the raw embedding similarity and the fuzzy score divided by 100.
            All four tensors are empty when no case is stored.
        """
        cases = self.corpus[task][case_type]
        if not cases:
            # Nothing to compare against: max()/min() on empty tensors would
            # raise, so report "no match" explicitly.
            return (
                torch.empty(0, device=device),
                torch.empty(0, dtype=torch.long, device=device),
                torch.empty(0, device=device),
                torch.empty(0, dtype=torch.long, device=device),
            )

        # Embedding similarity between the query and every stored case.
        encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
        embedding_similarity_matrix = self.embedder.similarity(
            encoded_embed_query, self.embedded_corpus[task][case_type]
        )
        embedding_similarity_scores = embedding_similarity_matrix[0].to(device)

        # Fuzzy string similarity. `process.extract` returns matches sorted by
        # score, so map them back to corpus order through the matched string.
        str_match_corpus = [item['index']['str_index'] for item in cases]
        str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
        scores_dict = {match[0]: match[1] for match in str_similarity_results}
        scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
        str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)

        # Min-max normalise each score family; when all scores are equal the
        # range is zero, so fall back to the raw value (embedding) or to the
        # score rescaled by 1/100 (string).
        embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
        str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
        if embedding_score_range > 0:
            embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
        else:
            embed_norm_scores = embedding_similarity_scores
        if str_score_range > 0:
            str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
        else:
            str_norm_scores = str_similarity_scores / 100

        combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
        original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100

        # Never ask topk for more entries than there are cases.
        k = min(top_k, combined_scores.size(0))
        scores, indices = torch.topk(combined_scores, k=k)
        original_scores, original_indices = torch.topk(original_combined_scores, k=k)
        return scores, indices, original_scores, original_indices

    def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
        """Return the ``content`` of the ``top_k`` best matching cases."""
        _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
        cases = self.corpus[task][case_type]
        # tolist() yields plain ints, which index the Python list directly.
        return [cases[idx]["content"] for idx in indices.tolist()]

    def update_case(self, task: TaskType, embed_index="", str_index="", content="", case_type=""):
        """Append a case to the in-memory corpus and to its embedding tensor.

        The change is not written to disk; call ``update_corpus`` for that.
        """
        self.corpus[task][case_type].append(
            {"index": {"embed_index": embed_index, "str_index": str_index}, "content": content}
        )
        new_embedding = self.embedder.encode([embed_index], convert_to_tensor=True).to(device)
        self.embedded_corpus[task][case_type] = torch.cat(
            [self.embedded_corpus[task][case_type], new_embedding], dim=0
        )
        print(f"A {case_type} case updated for {task} task.")
|
|
|
class CaseRepositoryHandler:
    """Builds, queries and extends the case repository for a data point.

    Wraps a ``CaseRepository`` and an LLM engine: the LLM writes the analysis
    (good cases) or reflection (bad cases) stored with each new case.
    """

    # A new case is skipped when the best raw (un-normalised) similarity to an
    # already stored case reaches this value.
    DUPLICATE_THRESHOLD = 0.9

    # Number of LLM attempts before giving up on an analysis/reflection.
    MAX_LLM_ATTEMPTS = 3

    def __init__(self, llm: BaseEngine):
        """Create the repository and remember the LLM engine to query."""
        self.repository = CaseRepository()
        self.llm = llm

    def __ask_llm(self, prompt):
        """Query the LLM up to ``MAX_LLM_ATTEMPTS`` times.

        Returns the first response that ``extract_json_dict`` does not turn
        into a dict, or ``None`` when every attempt produced a dict.
        NOTE(review): a dict result is treated as a failed attempt, presumably
        because the prompts expect free text -- confirm against the prompt
        templates.
        """
        for _ in range(self.MAX_LLM_ATTEMPTS):
            response = self.llm.get_chat_response(prompt)
            response = extract_json_dict(response)
            if not isinstance(response, dict):
                return response
        return None

    def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
        """Ask the LLM for an analysis of a correct answer."""
        prompt = good_case_analysis_instruction.format(
            instruction=instruction, text=text, result=result, additional_info=additional_info
        )
        return self.__ask_llm(prompt)

    def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
        """Ask the LLM for a reflection on a wrong answer."""
        prompt = bad_case_reflection_instruction.format(
            instruction=instruction, text=text, original_answer=original_answer,
            correct_answer=correct_answer, additional_info=additional_info
        )
        return self.__ask_llm(prompt)

    def __get_index(self, data: DataPoint, case_type: str):
        """Build the ``(embed_index, str_index)`` retrieval keys for ``data``.

        ``embed_index`` is the distilled text plus the first text chunk.
        ``str_index`` is the instruction for the "Base" task and the constraint
        otherwise; bad cases also append the JSON-dumped prediction.
        """
        embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"

        if data.task == "Base":
            str_index = f"**Task**: {data.instruction}"
        else:
            str_index = f"{data.constraint}"

        if case_type == "bad":
            str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"

        return embed_index, str_index

    def __is_duplicate(self, data: DataPoint, embed_index, str_index, case_type):
        """Return True when a sufficiently similar case is already stored."""
        _, _, original_scores, _ = self.repository.get_similarity_scores(
            data.task, embed_index, str_index, case_type, 1
        )
        original_scores = original_scores.tolist()
        # An empty result means there is nothing to collide with; indexing
        # [0] unconditionally would raise IndexError in that situation.
        if original_scores and original_scores[0] >= self.DUPLICATE_THRESHOLD:
            print(f"The similar {case_type} case is already in the corpus. Similarity Score: ", original_scores[0])
            return True
        return False

    def query_good_case(self, data: DataPoint):
        """Return the stored good cases most similar to ``data``."""
        embed_index, str_index = self.__get_index(data, "good")
        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")

    def query_bad_case(self, data: DataPoint):
        """Return the stored bad cases most similar to ``data``."""
        embed_index, str_index = self.__get_index(data, "bad")
        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")

    def update_good_case(self, data: DataPoint):
        """Store ``data`` as a good case unless a near-duplicate exists.

        Does nothing when ``data.truth`` is empty. The stored content holds
        the text, the LLM analysis and the correct answer, preceded by the
        instruction for the "Base" task or including the constraint otherwise.
        """
        if data.truth == "":
            print("No truth value provided.")
            return
        embed_index, str_index = self.__get_index(data, "good")
        if self.__is_duplicate(data, embed_index, str_index, "good"):
            return

        good_case_analysis = self.__get_good_case_analysis(
            instruction=data.instruction, text=data.distilled_text,
            result=data.truth, additional_info=data.constraint
        )
        wrapped_good_case_analysis = f"**Analysis**: {good_case_analysis}"
        wrapped_instruction = f"**Task**: {data.instruction}"
        wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
        wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
        if data.task == "Base":
            content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
        else:
            content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
        self.repository.update_case(data.task, embed_index, str_index, content, "good")

    def update_bad_case(self, data: DataPoint):
        """Store ``data`` as a bad case unless a near-duplicate exists.

        Does nothing when ``data.truth`` is empty or when the normalised
        prediction equals the normalised truth (the prediction was correct).
        """
        if data.truth == "":
            print("No truth value provided.")
            return
        if normalize_obj(data.pred) == normalize_obj(data.truth):
            return
        embed_index, str_index = self.__get_index(data, "bad")
        if self.__is_duplicate(data, embed_index, str_index, "bad"):
            return

        bad_case_reflection = self.__get_bad_case_reflection(
            instruction=data.instruction, text=data.distilled_text,
            original_answer=data.pred, correct_answer=data.truth,
            additional_info=data.constraint
        )
        wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
        wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
        wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
        wrapped_instruction = f"**Task**: {data.instruction}"
        wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
        if data.task == "Base":
            content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
        else:
            content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
        self.repository.update_case(data.task, embed_index, str_index, content, "bad")

    def update_case(self, data: DataPoint):
        """Record ``data`` as good and/or bad case, then persist the corpus."""
        self.update_good_case(data)
        self.update_bad_case(data)
        self.repository.update_corpus()