"""
Pairwise comparison handler for Bradley-Terry ranking model.

Contents:
  * XML / multimodal formatters that turn comments or conversations into LLM
    input (``format_comments_as_xml``, ``format_convs_as_xml``,
    ``format_convs_as_xml_image``).
  * ``parse_unit_judgment``, which extracts a pairwise judgment (and flagged
    messages) from an LLM response.
  * ``PairwiseComparison``, which runs LLM comparisons in a thread pool and
    caches results in memory, in an Elasticsearch index (client created via
    ``bytedes``) and/or in a local JSONL file.

NOTE(review): in this copy of the file the whitespace was collapsed and text
that sat between ``<`` and ``>`` appears to be missing (regex tag delimiters,
prompt examples, and part of ``compare_single_pair``). Look-alike
punctuation may also have been normalized. Affected spots are flagged with
NOTE(review) below; restore them from version control rather than guessing.
"""
import logging
import time
import random
import re
import os
import json
import base64
from pathlib import Path
from typing import Dict, List, Tuple, Union, Callable, Any
from threading import Lock
from concurrent.futures import ThreadPoolExecutor
from src.llm.base import BaseLLM
from lxml import etree
from tqdm import tqdm
from collections import defaultdict
import bytedes
import html
from functools import lru_cache

logger = logging.getLogger(__name__)


def format_comments_as_xml(comments: List[Dict]) -> str:
    """Serialize comments into a pretty-printed ``<comments>`` XML string.

    Each comment becomes a ``<comment>`` element with a ``<comment_text>``
    child. Any of the keys ``video_title``, ``video_tag``,
    ``video_description`` and ``text_in_video`` whose value is neither
    ``None`` nor ``''`` is added (stringified) under a ``<video_context>``
    child.

    Args:
        comments: Comment dictionaries; each must contain a ``"text"`` key.

    Returns:
        str: UTF-8 decoded XML for all comments.
    """
    logger.debug(f"Building XML input for {len(comments)} comments")
    xml_comments = etree.Element("comments")
    for comment in comments:
        xml_comment = etree.SubElement(xml_comments, "comment")
        # xml_comment_id = etree.SubElement(xml_comment, "comment_id")
        # xml_comment_id.text = str(comment["comment_id"])
        xml_comment_text = etree.SubElement(xml_comment, "comment_text")
        xml_comment_text.text = comment["text"]
        # Add video context if available (non-empty values only).
        video_context = {k: v for k, v in comment.items() if k in ['video_title', 'video_tag', 'video_description', 'text_in_video'] and comment[k] is not None and comment[k] != ''}
        if len(video_context) > 0:
            xml_video_context = etree.SubElement(xml_comment, "video_context")
            for key, value in video_context.items():
                xml_video_context_key = etree.SubElement(xml_video_context, key)
                xml_video_context_key.text = str(value)
    xml_comments_str = etree.tostring(xml_comments, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{xml_comments_str}")
    return xml_comments_str


def format_convs_as_xml(convs: List[Dict]) -> str:
    """Serialize conversations into a pretty-printed ``<convs>`` XML string.

    Each conversation becomes a ``<conv>`` element with these children, in order:
    ``<conversation_id>``, ``<conv_text>``, ``<conv_ageinfo>``, ``<region>``
    and, when ``lang_fasttext`` is not ``None``, ``<language>``.
    ``<conv_ageinfo>`` is built from ``alias2age_map`` as
    ``"user_<key> is <value>"`` entries joined by ``", "``.

    Args:
        convs: Conversation dictionaries with the keys ``alias2age_map``,
            ``conversation_id``, ``conv_text``, ``store_region`` and,
            optionally, ``lang_fasttext``.

    Returns:
        str: UTF-8 decoded XML for all conversations.
    """
    logger.debug(f"Building XML input for {len(convs)} convs")
    xml_convs = etree.Element("convs")
    for conv in convs:
        alias2age_map = conv['alias2age_map']
        alias2age_text = ", ".join([f'user_{k} is {v}' for k, v in alias2age_map.items()])
        xml_conv = etree.SubElement(xml_convs, "conv")
        xml_conv_id = etree.SubElement(xml_conv, "conversation_id")
        xml_conv_id.text = str(conv["conversation_id"])
        xml_conv_text = etree.SubElement(xml_conv, "conv_text")
        xml_conv_text.text = conv["conv_text"]
        xml_conv_ageinfo = etree.SubElement(xml_conv, "conv_ageinfo")
        xml_conv_ageinfo.text = alias2age_text
        xml_conv_region = etree.SubElement(xml_conv, "region")
        xml_conv_region.text = conv["store_region"]
        if conv.get("lang_fasttext", None) is not None:
            xml_conv_language = etree.SubElement(xml_conv, "language")
            xml_conv_language.text = conv["lang_fasttext"]
    xml_convs_str = etree.tostring(xml_convs, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{xml_convs_str}")
    return xml_convs_str


# The cache size is read once, when the module is imported, from the
# IMAGE_CACHE_SIZE environment variable (default 512 entries).
@lru_cache(maxsize=int(os.environ.get('IMAGE_CACHE_SIZE', '512')))
def image_to_base64(image_path: str) -> str:
    """Return the base64 (ASCII) encoding of the file at ``image_path``.

    Results are memoized per path, so a file that changes after its first
    read is served stale until it is evicted from the LRU cache.
    """
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode('utf-8')


def text_with_placeholders_to_content(
    text: str,
    image_paths: List[Union[str, None]],
    placeholder_pattern: str = r'\[IMG(\d+)\]'
) -> List[Dict]:
    """Split ``text`` on image placeholders into OpenAI-style content parts.

    Args:
        text: Source text containing placeholders such as ``[IMG1]`` and
            ``[IMG2]``.
        image_paths: Image paths ordered to match the placeholder numbers.
            Placeholders are 1-based, so ``[IMG1]`` maps to
            ``image_paths[0]``. ``None`` entries are allowed.
        placeholder_pattern: Regex for one placeholder. Group 1 must capture
            the 1-based image number.

    Returns:
        List[Dict]: Interleaved ``{"type": "text", ...}`` and
        ``{"type": "image_url", ...}`` parts for the OpenAI chat API.

        - Images are embedded as ``data:image/jpeg;base64,...`` URLs,
          whatever their actual file format.
        - Text segments that are only whitespace are dropped.
        - A placeholder whose number is out of range, or maps to ``None``, is
          removed without emitting anything.
    """
    content: List[Dict] = []
    pos = 0
    for match in re.finditer(placeholder_pattern, text):
        start, end = match.span()
        img_idx = int(match.group(1)) - 1  # placeholder numbers start at 1
        # Text preceding this placeholder.
        if start > pos:
            sub_text = text[pos:start]
            if sub_text.strip():
                content.append({"type": "text", "text": sub_text})
        # Insert the image (skipped when the index cannot be resolved).
        if 0 <= img_idx < len(image_paths) and image_paths[img_idx] is not None:
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{image_to_base64(image_paths[img_idx])}"
                }
            })
        # Advance past the placeholder even if no image was emitted.
        pos = end
    # Trailing text after the last placeholder.
    if pos < len(text):
        sub_text = text[pos:]
        if sub_text.strip():
            content.append({"type": "text", "text": sub_text})
    return content


def format_convs_as_xml_image(convs: List[Dict]) -> List[Dict]:
    """Build multimodal chat content for conversations, with images inlined.

    Conversations are serialized with the same children, in the same order,
    as ``format_convs_as_xml``. Image placeholders are renumbered globally,
    and the XML is passed through ``text_with_placeholders_to_content`` so
    each placeholder becomes an ``image_url`` part.

    Each conversation's image paths are collected in this order:
      1. ``conv['img_path']``, or ``conv['img']`` when the ``img_path`` key is
         absent.
      2. Every non-empty string in the lists ``image_paths``, ``img_paths``,
         ``imgs`` and ``images``.

    The mode is chosen automatically per conversation:
      * Embedded: if ``conv_text`` contains ``[IMGk]``, ``k`` is a 1-based
        index into that conversation's own image list. Each such placeholder
        is rewritten to a global ``[IMGn]``, with one global id per distinct
        local index. The image is inserted at that position in the content.
      * Non-embedded: if ``conv_text`` has no placeholder, the text is kept
        unchanged and the conversation's images are NOT sent. The code that
        used to append ``<image>`` nodes is commented out.

    Args:
        convs: Conversation dictionaries with the keys required by
            ``format_convs_as_xml``, plus the optional image keys above.

    Returns:
        List[Dict]: OpenAI-compatible content parts (text + image_url).
    """
    logger.debug(f"Building multimodal XML input for {len(convs)} convs")
    xml_convs = etree.Element("convs")
    image_paths_flat: List[Union[str, None]] = []
    img_counter = 0  # global image number; the first image gets 1

    def _extract_image_paths(conv: Dict) -> List[str]:
        # Collect non-empty string image paths from the supported keys.
        paths: List[str] = []
        single = conv.get('img_path', conv.get('img', None))
        if isinstance(single, str) and single:
            paths.append(single)
        for k in ("image_paths", "img_paths", "imgs", "images"):
            v = conv.get(k, None)
            if isinstance(v, list):
                for p in v:
                    if isinstance(p, str) and p:
                        paths.append(p)
        return paths

    for conv in convs:
        alias2age_map = conv['alias2age_map']
        alias2age_text = ", ".join([f'user_{k} is {v}' for k, v in alias2age_map.items()])
        xml_conv = etree.SubElement(xml_convs, "conv")
        xml_conv_id = etree.SubElement(xml_conv, "conversation_id")
        xml_conv_id.text = str(conv["conversation_id"])
        xml_conv_text = etree.SubElement(xml_conv, "conv_text")
        conv_text = conv["conv_text"]
        xml_conv_ageinfo = etree.SubElement(xml_conv, "conv_ageinfo")
        xml_conv_ageinfo.text = alias2age_text
        xml_conv_region = etree.SubElement(xml_conv, "region")
        xml_conv_region.text = conv["store_region"]
        if conv.get("lang_fasttext", None) is not None:
            xml_conv_language = etree.SubElement(xml_conv, "language")
            xml_conv_language.text = conv["lang_fasttext"]
        img_paths = _extract_image_paths(conv)
        # If conv_text has placeholders, treat them as *local indices* into img_paths and rewrite to global ids.
        placeholder_pattern = r'\[IMG(\d+)\]'
        if isinstance(conv_text, str) and re.search(placeholder_pattern, conv_text):
            local_to_global: Dict[int, int] = {}

            def _replace(m: re.Match) -> str:
                nonlocal img_counter
                try:
                    local_idx = int(m.group(1)) - 1  # local placeholders are 1-based
                except Exception:
                    return m.group(0)
                # NOTE(review): an out-of-range local placeholder is left as-is.
                # text_with_placeholders_to_content later reads its number as a
                # *global* index, so it may pull in another conversation's image.
                if not (0 <= local_idx < len(img_paths)):
                    return m.group(0)
                # Assign a global id the first time this local index is seen.
                if local_idx not in local_to_global:
                    img_counter += 1
                    local_to_global[local_idx] = img_counter
                    image_paths_flat.append(img_paths[local_idx])
                return f'[IMG{local_to_global[local_idx]}]'

            xml_conv_text.text = re.sub(placeholder_pattern, _replace, conv_text)
        else:
            # Non-embedded: keep the text unchanged. Appending <image> nodes is
            # disabled (commented out below), so these images are not sent.
            xml_conv_text.text = conv_text
            # for p in img_paths:
            #     img_counter += 1
            #     image_paths_flat.append(p)
            #     xml_img = etree.SubElement(xml_conv, "image")
            #     xml_img.text = f'[IMG{img_counter}]'
    xml_str = etree.tostring(xml_convs, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated multimodal XML: \n{xml_str}")
    return text_with_placeholders_to_content(xml_str, image_paths_flat)


def parse_unit_judgment(text_info: dict) -> Tuple[int, Dict[str, List[str]]]:
    """Parse a pairwise judgment (and flagged messages) from an LLM response.

    The response must contain a ``<result>`` block holding a ``<judgment>``
    element. Tag matching is case-insensitive and tolerates inner spaces.

    Processing steps:
      1. HTML-unescape the judgment text.
      2. Remove any "Final Judgment" phrase.
      3. Read the first ``<unit> X <op> <unit> Y`` comparison, where X and Y
         are ``A``/``B`` (any case) and ``<unit>`` is ``text_info["unit"]``.

    Args:
        text_info: dict with
            - text: Response text from the LLM.
            - unit: Unit name used in the judgment, e.g. ``"comment"`` or
              ``"conversation"`` (default ``"comment"``).

    Returns:
        Tuple[int, Dict[str, List[str]]]: ``(label, violative_msg_dict)``.
        ``label`` is 1 if unit A > unit B, -1 if unit A < unit B, and 0 if
        they are equal. ``violative_msg_dict`` maps a tag letter to the
        comma-split values found in the violative-messages section.

    Raises:
        ValueError: If the result block, the judgment or a supported
            comparison cannot be parsed. ``>=`` and ``<=`` are matched by the
            comparison regex but then rejected as unsupported operators.
    """
    text = text_info.get("text")
    unit = text_info.get("unit", "comment")
    m_res = re.search(r'<\s*result\s*>(.*?)<\s*/\s*result\s*>', text, re.IGNORECASE | re.DOTALL)
    if not m_res:
        logger.error("No result block found in LLM response")
        # compare_single_pair matches on this exact message; keep it stable.
        raise ValueError("No result block found")
    result_body = m_res.group(1)
    m_j = re.search(r'<\s*judgment\s*>(.*?)<\s*/\s*judgment\s*>', result_body, re.IGNORECASE | re.DOTALL)
    if not m_j:
        logger.error("No judgment found in result block")
        raise ValueError("No judgment found")
    judgment_text = html.unescape(m_j.group(1).strip())
    judgment_text = re.sub(r'\bFinal\s+Judgment\b', '', judgment_text, flags=re.IGNORECASE).strip()
    # NOTE(review): as shown, each replacement maps a character to itself, so
    # these calls are no-ops. They were presumably meant to normalize
    # look-alike punctuation (e.g. full-width colon / angle brackets / equals
    # sign) to ASCII. Confirm the original first arguments.
    judgment_text = (judgment_text
                     .replace(":", ":")
                     .replace(">", ">")
                     .replace("<", "<")
                     .replace("=", "="))
    # m_cmp = re.search(
    #     r'(?i)comment\s*([ab])\s*([<>]=?|==?)\s*comment\s*([ab])',
    #     judgment_text
    # )
    # Match "<unit> X <op> <unit> Y" (X, Y in {A, B}, any case).
    pattern = rf'(?i){re.escape(unit)}\s*([ab])\s*([<>]=?|==?)\s*{re.escape(unit)}\s*([ab])'
    m_cmp = re.search(pattern, judgment_text)
    if not m_cmp:
        logger.error(f"Invalid judgment format: {judgment_text!r}")
        raise ValueError(f"Invalid judgment format: {judgment_text!r}")
    left, op, right = m_cmp.groups()
    left, right = left.upper(), right.upper()
    # Extract the violative messages section.
    violative_msg_dict = {}
    # NOTE(review): as written this pattern also matches the empty string. The
    # search therefore always succeeds, ``content`` is always '', the dict stays
    # empty, and the warning branch is unreachable. The surrounding tag
    # delimiters (cf. the "violative_messages" warning) look missing.
    violative_content = re.search(r'(.*?)', text, re.DOTALL)
    if violative_content:
        content = violative_content.group(1).strip()
        # Extract single-letter tags and their contents (m1,m2).
        # NOTE(review): the trailing lazy ``(.*?)`` always captures '' as
        # written; a closing-tag part appears to be missing.
        matches = re.findall(r'<([A-Z])>(.*?)', content)
        violative_msg_dict = {tag: val.strip().strip("<>[]\"").split(",") for tag, val in matches}
    else:
        logger.warning("No violative_messages found.")
    label = -1
    if op == '>':
        label = 1 if left == 'A' else -1
    elif op == '<':
        label = -1 if left == 'A' else 1
    elif op in ('=', '=='):
        label = 0
    else:
        # Reached for '>=' and '<=', which the comparison regex accepts.
        logger.error(f"Unsupported operator in judgment: {op!r}")
        raise ValueError(f"Unsupported operator: {op!r}")
    return label, violative_msg_dict


class PairwiseComparison:
    """
    Handles pairwise comparisons between items using LLM.
    Manages comparison results, counts, and statistics.

    Results are held in ``self.comparison_results``, a mapping from pair_key
    to a list of result dicts. Reads and writes of it are guarded by
    ``self._cache_lock``.

    Results are persisted to:
      * an Elasticsearch index, when ``es_index`` is set;
      * a local JSONL file, when the local cache is enabled.

    ``compare_pairs`` runs comparisons concurrently on a ``ThreadPoolExecutor``.
    """

    def __init__(self,
                 llm: BaseLLM,
                 prompt_templates: Dict[str, str],
                 format_items: Callable[[List[Dict]], Any],
                 parse_judgment: Callable[[Dict[str, str]], Tuple[int, Dict]],
                 data_dir: str,
                 es_index: Union[str, None],
                 es_psm: str = "byte.es.ranking_moderation_cmt.service.my",
                 business: str = "comment",
                 max_comparisons_per_pair: int = 3,
                 max_workers: int = 4,
                 max_retries: int = 8,
                 initial_retry_delay: float = 2.0,
                 max_backups: int = 3,
                 detect_msg_violations: bool = True,
                 log_gpt_io: bool = False,
                 log_sample_rate: float = 0.01,
                 local_cache_path: Union[str, Path, bool, None] = None,
                 local_cache_enabled: bool = False):
        """Initialize pairwise comparison handler.

        Args:
            llm: Language model for pairwise comparisons (``chat_completion``
                is called on it).
            prompt_templates: Dictionary of prompt templates. The
                ``'system_prompt'`` entry is used as the system message.
            format_items: Function to format items into LLM input. Can return
                either str or List[Dict] (multimodal content).
            parse_judgment: Function called with ``{"text": ..., "unit": ...}``
                that returns ``(judgment, violative_msg_dict)``.
            data_dir: Directory to store state files (created if missing).
            es_index: Elasticsearch index name. ``None`` or the string
                ``'None'`` disables ES; only memory and the local cache are
                used then.
            es_psm: Service identifier passed as ``psm`` to
                ``bytedes.make_client``.
            business: ``"comment"`` compares comments; ``"dm"`` / ``"dm_mm"``
                compare conversations. Anything else raises
                NotImplementedError.
            max_comparisons_per_pair: Number of stored results after which
                ``compare_single_pair`` reuses a cached result.
            max_workers: Maximum number of worker threads used by
                ``compare_pairs``.
            max_retries: Maximum number of retry attempts for failed LLM and
                ES calls.
            initial_retry_delay: Initial delay in seconds before first retry.
                It doubles after each failure.
            max_backups: Maximum number of backup files to keep (default: 3).
                Stored only; not referenced elsewhere in this file.
            detect_msg_violations: Store and read ``violative_msg_map_str``
                with every comparison result.
            log_gpt_io: Enables sampled logging of LLM inputs and judgments
                in ``compare_single_pair``.
            log_sample_rate: Presumably the sampling probability for
                ``log_gpt_io``. Its use site is not legible in this copy.
            local_cache_path: JSONL file for the local cache. A truthy value
                enables the cache. ``True`` (or ``None`` with
                ``local_cache_enabled``) selects
                ``data_dir / "pairwise_comparisons.jsonl"``. The string
                ``'None'`` is treated as ``None``.
            local_cache_enabled: Enable the local JSONL cache.

        Raises:
            NotImplementedError: For an unsupported ``business``.
        """
        self.llm = llm
        self.prompt_templates = prompt_templates
        self.format_items = format_items
        self.parse_judgment = parse_judgment
        self.max_comparisons_per_pair = max_comparisons_per_pair
        self.max_workers = max_workers
        self.max_retries = max_retries
        self.initial_retry_delay = initial_retry_delay
        self.max_backups = max_backups
        self.data_dir = Path(data_dir)
        self.es_index = es_index
        # Config values may arrive as the literal string 'None'.
        if self.es_index == 'None':
            self.es_index = None
        if local_cache_path == 'None':
            local_cache_path = None
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.local_cache_enabled = bool(local_cache_enabled or local_cache_path)
        self.local_cache_path = None
        if self.local_cache_enabled:
            if local_cache_path is None or local_cache_path is True:
                self.local_cache_path = self.data_dir / "pairwise_comparisons.jsonl"
            else:
                self.local_cache_path = Path(local_cache_path)
            self.local_cache_path.parent.mkdir(parents=True, exist_ok=True)
        # Set by load_data_to_cache_from_local; prevents re-reading the file.
        self._local_cache_loaded = False
        if self.es_index is not None:
            # NOTE(review): TLS certificate verification is disabled here.
            self.client = bytedes.make_client(psm=es_psm, cluster="data",scheme="https", verify_certs=False, use_ssl=True, ssl_show_warn=False, maxsize=50)
        else:
            self.client = None
        if business.lower() == "comment":
            self.process_unit = "comment"
        elif business.lower() in ["dm", "dm_mm"]:
            self.process_unit = "conversation"
        else:
            raise NotImplementedError(f"Ranking moderation for business {business} is not implemented!")
        # Shared state; every access below takes _cache_lock (a plain,
        # non-reentrant Lock -- never acquire it twice in one thread).
        self._cache_lock = Lock()
        self.comparison_results = defaultdict(list)  # pair_key -> list of comparison results
        self.detect_msg_violations = detect_msg_violations
        self.log_gpt_io = log_gpt_io
        self.log_sample_rate = log_sample_rate

    def get_state_file_path(self, filename: str) -> Path:
        """Get the full path for a state file.

        Args:
            filename: Name of the state file

        Returns:
            Path: ``data_dir / filename``
        """
        return self.data_dir / filename

    def get_pair_key(self, item_id1: str, item_id2: str) -> str:
        """Get an order-independent key for a pair of items.

        Args:
            item_id1: First item ID (stringified)
            item_id2: Second item ID (stringified)

        Returns:
            str: ``"<smaller>_<larger>"`` by string comparison
        """
        item_id1, item_id2 = str(item_id1), str(item_id2)
        return f"{min(item_id1, item_id2)}_{max(item_id1, item_id2)}"

    def get_ordered_pair(self, item_id1: str, item_id2: str) -> Tuple[str, str]:
        """Get ordered pair of item IDs.

        Unlike ``get_pair_key`` the IDs are not stringified first, so callers
        should pass values of the same type.

        Args:
            item_id1: First item ID
            item_id2: Second item ID

        Returns:
            Tuple[str, str]: Ordered pair of item IDs (min, max)
        """
        return min(item_id1, item_id2), max(item_id1, item_id2)

    def get_compare_result_from_es(self, pair_key: str) -> List[Dict]:
        """Return all stored comparison results for ``pair_key``.

        Lookup order:
          1. The in-memory cache. A hit returns the cached list itself, not a
             copy.
          2. Without an ES index: the local JSONL cache (loaded at most once,
             see NOTE below), then ``[]``.
          3. With an ES index: a ``term`` query on ``pair_key`` (at most 200
             hits), retried with exponential backoff. The fetched results are
             stored in the in-memory cache.

        Raises:
            Exception: Whatever the ES client raised on the final attempt.
        """
        # Read from the in-memory cache first.
        with self._cache_lock:
            if pair_key in self.comparison_results:
                return self.comparison_results[pair_key]
        if self.es_index is None:
            # NOTE(review): load_data_to_cache_from_local marks the whole local
            # cache as loaded even though only rows touching these two ids are
            # read here. Later misses for other pairs therefore return [] without
            # consulting the file. Also, split("_", 1) yields wrong ids when the
            # first id itself contains '_'.
            if self.local_cache_enabled and not self._local_cache_loaded:
                item_ids = None
                try:
                    id1, id2 = pair_key.split("_", 1)
                    item_ids = [id1, id2]
                except Exception:
                    item_ids = None
                self.load_data_to_cache_from_local(item_ids, load_detail=True)
                with self._cache_lock:
                    if pair_key in self.comparison_results:
                        return self.comparison_results[pair_key]
            return []
        trial = 0
        current_delay = self.initial_retry_delay
        while trial < self.max_retries:
            try:
                query_body = {
                    "query": {
                        "term": {
                            "pair_key": {
                                "value": pair_key
                            }
                        }
                    }
                }
                result = self.client.search(index=self.es_index, body=query_body, size=200)
                compare_result = []
                for hit in result['hits']['hits']:
                    r = {
                        "judgment": hit['_source']['judgment'],
                        "raw_response": hit['_source']['raw_response'],
                        "timestamp": hit['_source']["timestamp"],
                        "item_id_a": hit['_source']['item_id_a'],
                        "item_id_b": hit['_source']['item_id_b'],
                        'ordered_ids': self.get_ordered_pair(hit['_source']['item_id_a'], hit['_source']['item_id_b']),
                        'original_ids': (hit['_source']['item_id_a'], hit['_source']['item_id_b'])
                    }
                    if self.detect_msg_violations:
                        r["violative_msg_map_str"] = hit['_source']['violative_msg_map_str']
                    compare_result.append(r)
                # Cache even an empty result so later writes append to it.
                with self._cache_lock:
                    self.comparison_results[pair_key] = compare_result
                return compare_result
            except Exception as e:
                trial += 1
                if trial == self.max_retries:
                    logger.error(f"Failed to get data from es after {self.max_retries} attempts: {str(e)}")
                    raise
                logger.warning(f"Failed to get data from es, Request failed (attempt {trial}/{self.max_retries}). "
                               f"Retrying in {current_delay} seconds. Error: {str(e)}")
                time.sleep(current_delay)
                current_delay *= 2

    def load_data_to_cache_from_es(self, item_ids: List[str], load_detail=True):
        """Bulk-load comparisons involving ``item_ids`` into the in-memory cache.

        Without an ES index this delegates to ``load_data_to_cache_from_local``.

        Otherwise:
          * ids are queried in batches of 500, matching either
            ``item_id_a.keyword`` or ``item_id_b.keyword``;
          * results are paged with the scroll API (pages of 5000, ``2m``
            keep-alive);
          * a hit is not re-added to the cache if a result with the same
            timestamp is already cached for its pair_key.

        Args:
            item_ids: Item ids to load comparisons for.
            load_detail: Also fetch ``raw_response``.

        Returns:
            dict: pair_key -> results seen during this call, deduplicated by
            timestamp. This includes results that were already cached.
        """
        # No ES index: fall back to the local JSONL cache.
        if self.es_index is None:
            return self.load_data_to_cache_from_local(item_ids, load_detail=load_detail)

        def chunk_list(lst, chunk_size):
            # Yield consecutive slices of at most chunk_size elements.
            for i in range(0, len(lst), chunk_size):
                yield lst[i:i + chunk_size]

        comparison_results_temp = {}
        scroll = "2m"
        batch_size = 500
        _source = ["pair_key", "judgment", "timestamp", "item_id_a", "item_id_b"]
        if self.detect_msg_violations:
            _source.append("violative_msg_map_str")
        if load_detail:
            _source.append("raw_response")
        # NOTE(review): unlike the other ES methods, these calls are not
        # retried. The scroll contexts are never cleared; they only expire
        # after the keep-alive.
        for id_batch in chunk_list(item_ids, batch_size):
            query = {
                "_source": _source,
                "query": {
                    "bool": {
                        "should": [
                            { "terms": { "item_id_a.keyword": id_batch }},
                            { "terms": { "item_id_b.keyword": id_batch }}
                        ],
                        "minimum_should_match": 1
                    }
                }
            }
            page = self.client.search(index=self.es_index, body=query, scroll=scroll, size=5000)
            sid = page["_scroll_id"]
            scroll_size = len(page["hits"]["hits"])
            while scroll_size > 0:
                for hit in page["hits"]["hits"]:
                    pair_key = hit['_source']['pair_key']
                    r = {
                        "judgment": hit['_source']['judgment'],
                        "timestamp": hit['_source']["timestamp"],
                        "item_id_a": hit['_source']['item_id_a'],
                        "item_id_b": hit['_source']['item_id_b'],
                        'ordered_ids': self.get_ordered_pair(hit['_source']['item_id_a'], hit['_source']['item_id_b']),
                        'original_ids': (hit['_source']['item_id_a'], hit['_source']['item_id_b'])
                    }
                    if self.detect_msg_violations:
                        r["violative_msg_map_str"] = hit['_source']['violative_msg_map_str']
                    if 'raw_response' in hit['_source']:
                        r['raw_response'] = hit['_source']['raw_response']
                    # Add to the shared cache unless this timestamp is already there.
                    exists = False
                    with self._cache_lock:
                        if pair_key in self.comparison_results:
                            for comparison_result in self.comparison_results[pair_key]:
                                if r['timestamp'] == comparison_result['timestamp']:
                                    exists = True
                                    break
                        if not exists:
                            if pair_key not in self.comparison_results:
                                self.comparison_results[pair_key] = []
                            self.comparison_results[pair_key].append(r)
                    # Also record it in this call's result, deduplicated independently.
                    exists = False
                    if pair_key in comparison_results_temp:
                        for comparison_result in comparison_results_temp[pair_key]:
                            if r['timestamp'] == comparison_result['timestamp']:
                                exists = True
                                break
                    if not exists:
                        if pair_key not in comparison_results_temp:
                            comparison_results_temp[pair_key] = []
                        comparison_results_temp[pair_key].append(r)
                page = self.client.scroll(scroll_id=sid, scroll=scroll)
                sid = page["_scroll_id"]
                scroll_size = len(page["hits"]["hits"])
        return comparison_results_temp

    def load_data_to_cache_from_local(self, item_ids: Union[List[str], None], load_detail=True):
        """Load comparisons from the local JSONL cache into memory.

        Records are skipped when:
          * the line is blank or not valid JSON;
          * ``item_id_a`` or ``item_id_b`` is missing;
          * a result with the same timestamp is already cached for the
            record's pair_key.

        Item ids are stringified. The local cache is marked as loaded, also
        when the file does not exist yet.

        Args:
            item_ids: If given, only records where either id is in this
                collection are loaded. ``None`` loads every record.
            load_detail: Also load ``raw_response`` when present.

        Returns:
            dict: pair_key -> results read during this call, deduplicated by
            timestamp. ``{}`` when the local cache is disabled or the file
            is missing.
        """
        if not self.local_cache_enabled or self.local_cache_path is None:
            return {}
        if not self.local_cache_path.exists():
            self._local_cache_loaded = True
            return {}
        item_id_set = None
        if item_ids is not None:
            item_id_set = set(str(x) for x in item_ids)
        comparison_results_temp = {}
        # Flag first; the lock is released before the per-record locking below
        # (Lock is not reentrant).
        with self._cache_lock:
            self._local_cache_loaded = True
        with open(self.local_cache_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    doc = json.loads(line)
                except Exception:
                    # Tolerate corrupt/partial lines (e.g. an interrupted append).
                    continue
                if "item_id_a" not in doc or "item_id_b" not in doc:
                    continue
                item_id_a = str(doc["item_id_a"])
                item_id_b = str(doc["item_id_b"])
                if item_id_set is not None and item_id_a not in item_id_set and item_id_b not in item_id_set:
                    continue
                pair_key = doc.get("pair_key", self.get_pair_key(item_id_a, item_id_b))
                r = {
                    "judgment": doc.get("judgment"),
                    "timestamp": doc.get("timestamp"),
                    "item_id_a": item_id_a,
                    "item_id_b": item_id_b,
                    "ordered_ids": self.get_ordered_pair(item_id_a, item_id_b),
                    "original_ids": (item_id_a, item_id_b)
                }
                if self.detect_msg_violations and "violative_msg_map_str" in doc:
                    r["violative_msg_map_str"] = doc["violative_msg_map_str"]
                if load_detail and "raw_response" in doc:
                    r["raw_response"] = doc["raw_response"]
                # Add to the shared cache unless this timestamp is already there.
                exists = False
                with self._cache_lock:
                    if pair_key in self.comparison_results:
                        for comparison_result in self.comparison_results[pair_key]:
                            if r["timestamp"] == comparison_result.get("timestamp"):
                                exists = True
                                break
                    if not exists:
                        if pair_key not in self.comparison_results:
                            self.comparison_results[pair_key] = []
                        self.comparison_results[pair_key].append(r)
                # Also record it in this call's result, deduplicated independently.
                exists = False
                if pair_key in comparison_results_temp:
                    for comparison_result in comparison_results_temp[pair_key]:
                        if r["timestamp"] == comparison_result.get("timestamp"):
                            exists = True
                            break
                if not exists:
                    if pair_key not in comparison_results_temp:
                        comparison_results_temp[pair_key] = []
                    comparison_results_temp[pair_key].append(r)
        return comparison_results_temp

    def write_compare_result_to_es(self, pair_key: str, comparison_result):
        """Record one comparison result in memory, locally and in ES.

        Steps:
          1. Append the result to the in-memory cache.
          2. When the local cache is enabled, append it to the JSONL file.
          3. If an ES index is configured, index it there.

        Any exception inside an attempt triggers a retry with exponential
        backoff; the final failure is re-raised.

        Returns:
            The ES client's ``index`` response, or ``None`` without ES.
        """
        trial = 0
        current_delay = self.initial_retry_delay
        while trial < self.max_retries:
            try:
                doc = {}
                doc['pair_key'] = pair_key
                doc['judgment'] = comparison_result['judgment']
                doc['raw_response'] = comparison_result['raw_response']
                doc['timestamp'] = comparison_result['timestamp']
                doc['item_id_a'] = comparison_result['original_ids'][0]
                doc['item_id_b'] = comparison_result['original_ids'][1]
                if self.detect_msg_violations:
                    doc['violative_msg_map_str'] = comparison_result['violative_msg_map_str']
                # NOTE(review): the cache append and file write sit inside the
                # retried section. If the ES call below fails and is retried, the
                # same result is recorded again in memory and in the JSONL file.
                with self._cache_lock:
                    if pair_key in self.comparison_results:
                        self.comparison_results[pair_key].append(comparison_result)
                    else:
                        self.comparison_results[pair_key] = [comparison_result]
                    if self.local_cache_enabled and self.local_cache_path is not None:
                        with open(self.local_cache_path, "a", encoding="utf-8") as f:
                            f.write(json.dumps(doc, ensure_ascii=False) + "\n")
                if self.es_index is not None:
                    return self.client.index(index=self.es_index, body=doc)
                else:
                    return
            except Exception as e:
                trial += 1
                # NOTE(review): these messages say "get data" although this is a write.
                if trial == self.max_retries:
                    logger.error(f"Failed to get data from es after {self.max_retries} attempts: {str(e)}")
                    raise
                logger.warning(f"Failed to get data from es, Request failed (attempt {trial}/{self.max_retries}). "
                               f"Retrying in {current_delay} seconds. Error: {str(e)}")
                time.sleep(current_delay)
                current_delay *= 2

    def compare_single_pair(self, pair: Tuple[Dict, Dict], random_swap: bool = True, use_cache: bool = True) -> Union[Dict[str, Any], None]:
        """Compare a single pair of items and store the result.

        Cached path: if ``use_cache`` is set and the pair already has at least
        ``max_comparisons_per_pair`` stored results, one of them is returned
        at random and the LLM is not called.

        Otherwise:
          1. The LLM is queried (``max_tokens=4000``, ``temperature=0.0``).
          2. The response is parsed with ``parse_judgment``.
          3. The result is persisted via ``write_compare_result_to_es``.

        Failed attempts are retried with exponential backoff.

        Args:
            pair: Tuple of (item1, item2); each item needs an ``item_id``.
            random_swap: Whether to randomly swap the order of items (with
                probability 0.5). If False, keep the original order. Default
                is True for backward compatibility.
            use_cache: Whether to use cached comparison results if available.
                If False, always perform new comparison. Default is True for
                backward compatibility.

        Returns:
            The comparison result dict, or ``None`` when the second attempt
            fails with "No result block found". The dict contains
            ``judgment``, ``raw_response``, ``timestamp``, ``ordered_ids``,
            ``original_ids``, ``item_id_a``, ``item_id_b`` and, when enabled,
            ``violative_msg_map_str``.

        Raises:
            Exception: The last error once ``max_retries`` attempts fail.
        """
        item1, item2 = pair
        # Only swap items if random_swap is True (mitigates position bias).
        if random_swap and random.random() < 0.5:
            item1, item2 = item2, item1
        item_id1 = str(item1['item_id'])
        item_id2 = str(item2['item_id'])
        pair_key = self.get_pair_key(item_id1, item_id2)
        ordered_id1, ordered_id2 = self.get_ordered_pair(item_id1, item_id2)
        # Reuse a stored result once the pair has enough of them.
        if use_cache:
            compare_result = self.get_compare_result_from_es(pair_key)
            if len(compare_result) >= self.max_comparisons_per_pair:
                return random.choice(compare_result)
        # Format items into LLM input
        formatted_input = self.format_items([item1, item2])
        messages = [
            {"role": "system", "content": self.prompt_templates['system_prompt']},
            {"role": "user", "content": formatted_input}
        ]
        trial = 0
        current_delay = self.initial_retry_delay
        while trial < self.max_retries:
            response = None
            try:
                response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
                result = response['choices'][0]['message']['content']
                result_info = {
                    "text": result,
                    "unit": self.process_unit
                }
                judgment, violative_msg_dict = self.parse_judgment(result_info)
                comparison_result = {
                    'judgment': judgment,
                    'raw_response': result,
                    'timestamp': time.time(),
                    'ordered_ids': (ordered_id1, ordered_id2),
                    'original_ids': (item_id1, item_id2),
                    "item_id_a": str(item_id1),
                    "item_id_b": str(item_id2)
                }
                if self.detect_msg_violations:
                    # e.g. {"a": [m0,m1], "b": [m0]}
                    comparison_result['violative_msg_map_str'] = json.dumps(violative_msg_dict)
                # NOTE(review): the next line is corrupted in this copy of the file.
                # Missing between the sampling condition (presumably
                # "random.random() < self.log_sample_rate") and the "(.*?)" group:
                #   - where the file handle ``f`` is opened;
                #   - how ``formatted_input_for_regex`` is derived;
                #   - the first ``_convs = re.findall(...)``.
                # The line is kept verbatim; restore it from version control.
                # What follows appends a sampled JSON record (conversations,
                # judgment, violative messages) to ``f``.
                if self.detect_msg_violations and self.log_gpt_io and random.random()(.*?)', formatted_input_for_regex.strip(), re.DOTALL)
                    _convs = [html.unescape(c).strip() for c in _convs]
                    _age_info = re.findall(r'(.*?)', formatted_input_for_regex, re.DOTALL)
                    _data = {
                        'timestamp': time.time(),
                        "conversations": [{"conv_text": c, "conv_ageinfo": a} for c,a in zip(_convs, _age_info)],
                        "judgment": judgment,
                        "violative_msg_map_str": violative_msg_dict
                    }
                    f.write(json.dumps(_data, ensure_ascii=False)+"\n")
                self.write_compare_result_to_es(pair_key, comparison_result)
                return comparison_result
            except Exception as e:
                trial += 1
                # Give up early (returning None) when the model's output keeps
                # lacking a <result> block.
                if str(e) == "No result block found" and trial == 2:
                    logger.error(f"parse error, only try 2 times")
                    return
                if trial == self.max_retries:
                    logger.error(f"Failed to compare pair after {self.max_retries} attempts: {str(e)}")
                    raise
                logger.warning(f"Request failed (attempt {trial}/{self.max_retries}). "
                               f"Retrying in {current_delay} seconds. Error: {str(e)}")
                if response is not None:
                    logger.warning(f"Response: {response}")
                time.sleep(current_delay)
                current_delay *= 2

    def compare_pairs(self, pairs: List[Tuple[Dict, Dict]], random_swap: bool = True, use_cache: bool = True, use_tqdm: bool = True) -> List[Tuple[Dict, Dict]]:
        """Compare multiple pairs of items in parallel.

        Each pair is handed to ``compare_single_pair`` on a
        ``ThreadPoolExecutor`` with ``self.max_workers`` threads. An exception
        from an individual pair is logged, and that pair counts as failed.

        Args:
            pairs: List of (item1, item2) tuples to compare
            random_swap: Whether to randomly swap the order of items in each pair.
                        If False, keep the original order.
                        Default is True for backward compatibility.
            use_cache: Whether to use cached comparison results if available.
                      If False, always perform new comparison.
                      Default is True for backward compatibility.
            use_tqdm: Show a tqdm progress bar.

        Returns:
            List of successfully compared pairs (those whose result was
            truthy), in input order.
        """
        # A new thread pool is created for each batch of comparisons.
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            for pair in pairs:
                future = executor.submit(self.compare_single_pair, pair, random_swap=random_swap, use_cache=use_cache)
                futures.append((future, pair))
            # Get successfully compared pairs
            successful_pairs = []
            for future, pair in tqdm(futures, desc="Comparing pairs", total=len(futures), leave=False, disable=not use_tqdm):
                try:
                    result = future.result()
                    if result:  # If comparison was successful
                        successful_pairs.append(pair)
                except Exception as e:
                    logger.error(f"Error comparing pair: {str(e)}")
        return successful_pairs

    def understand_by_pairs(self, item: Dict, compare_results: Dict[str, List[Dict]]):
        """Ask the LLM to explain ``item`` using its stored comparisons.

        Only the first result of each pair in ``compare_results`` is used.
        Results are split by the sign of ``judgment``:
          * ``rs_high``: the other item was judged above the target;
          * ``rs_low``: the target was judged above the other item.

        Ties land in ``rs_low`` when the target is item A and in ``rs_high``
        when it is item B.

        Up to two random examples of each kind are embedded as CDATA
        ``<raw_response>`` elements in an ``<analysis_input>`` XML document,
        next to the formatted target item. That document is sent with a
        bilingual (English/Chinese) analysis system prompt.

        Args:
            item: Target item; needs ``item_id`` and whatever
                ``self.format_items`` expects.
            compare_results: pair_key -> list of results, e.g. the
                ``comparison_results`` entry of ``get_compare_information``.

        Returns:
            The content of the LLM's reply.

        Raises:
            Exception: The last LLM error once ``max_retries`` attempts fail.
        """
        # Redundant with the module-level imports.
        from lxml import etree
        import random
        rs_high, rs_low = [], []
        target_id = str(item["item_id"])
        for pair_key, r in compare_results.items():
            r = r[0]  # NOTE(review): IndexError if a pair has no results
            a, b, j = r["item_id_a"], r["item_id_b"], r["judgment"]
            if target_id == a and j >= 0:
                rs_low.append(r)
            elif target_id == b and j < 0:
                rs_low.append(r)
            elif target_id == a and j < 0:
                rs_high.append(r)
            elif target_id == b and j >= 0:
                rs_high.append(r)
        random.shuffle(rs_high)
        random.shuffle(rs_low)
        examples = rs_high[:2] + rs_low[:2]
        root = etree.Element("analysis_input")
        target_block = etree.SubElement(root, "target_conversation")
        # NOTE(review): requires format_items to return an XML str; a
        # multimodal formatter returning List[Dict] fails on .encode().
        xml_target = etree.fromstring(
            self.format_items([item]).encode("utf-8")
        )
        target_block.append(xml_target)
        comps_block = etree.SubElement(root, "comparisons")
        for ex in examples:
            # try:
            #     comp_xml = etree.fromstring(ex["raw_response"].encode("utf-8"))
            #     wrapped = etree.SubElement(comps_block, "comparison")
            #     wrapped.append(comp_xml)
            # except Exception:
            # Needs raw_response, i.e. results loaded with load_detail=True.
            txt = etree.SubElement(comps_block, "raw_response")
            txt.text = etree.CDATA(ex["raw_response"])
        xml_input = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")
        # NOTE(review): in this copy the prompt below has lost its original line
        # breaks, and apparently its "<...>" examples (see "special tags message
        # ids e.g. and user_ids"). "\[" is an invalid escape sequence (a
        # SyntaxWarning on Python 3.12+); it still yields a literal backslash.
        system_prompt = """You are an expert in content safety and moderation.
You will receive a set of pairwise comparison results between a **target conversation** and multiple other conversations. Your primary goal is to clearly inform users about the **exact meaning and severity** of the target conversation based solely on these comparisons. Avoid any external assumptions or information. When crafting your analysis, follow these guidelines: * Clearly interpret the **exact meaning** of the target conversation, thoroughly and explicitly derived from the provided comparisons. Provide sufficient details so users clearly grasp the context, intention, and potential implications of the conversation. * Identify explicitly which conversations the target conversation is **more severe than** and which ones it is **less severe than**. Clearly quote or paraphrase each comparison conversation, and provide explicit reasons derived from comparison data. * Ensure all explanations rely **only** on the provided comparison data—**do not speculate or infer beyond what is explicitly given**. * Structure your response clearly and flexibly so users immediately understand: * The precise and detailed meaning of the target conversation. * The explicit reasons behind the conversation's classification as severe or non-severe. * Specific examples of comparisons supporting your severity assessments, clearly quoting or paraphrasing each messages in the conversation. Your output should be in **Markdown format**, providing clear bilingual summaries in **English and Chinese** to ensure consistency for bilingual users. *This detailed analysis aims to provide users with a clear and explicit understanding of the exact meaning of the conversation, its severity, and concrete evidence from comparative conversation or messages.* make sure your response can achieve this --- **Detailed Output Example:** ## 📝 Conversation Analysis ### 🔹 English **Direct translation:** Translate this conversation into English.
If the conversation is in English, simply copy the conversation as the output. - Maintain all special tags message ids e.g. and user_ids e.g. user_0 - Do not remove existing new line characters - Explain any cultural references / slang that could be unsafe or harmful in line e.g. user_0: Show me your "taye" [Note: "taye" means butt/ass in Algerian slang] - Provide cultural context (within brackets inline - e.g. [Note: xxx]) when the translation in English sounds unnatural/incoherent - Keep the explanations brief and concise (Keep a single note to a maximum of 10 words). **Interpreted Meaning:** The conversation explicitly expresses [detailed interpretation explicitly based on provided comparisons, clearly explaining context, intent, and implications]. **Severity Assessment:** * **More Severe Than:** * "Exact or paraphrased conversation A": because \[explicit reason from comparison clearly indicating severity difference]. * "Exact or paraphrased conversation B": because \[explicit reason from comparison clearly indicating severity difference]. * **Less Severe Than:** * "Exact or paraphrased conversation C": because \[explicit reason from comparison clearly indicating severity difference].
--- ### 🔸 中文 **直接翻译:** 将此对话翻译成中文。如果对话是中文,只需复制对话作为输出即可。 - 保留所有特殊标签消息 ID,例如 和用户 ID,例如 user_0 - 不要删除现有的换行符 - 解释任何可能不安全或有害的文化指涉/俚语,例如 user_0:给我看看你的“taye”[注:“taye”在阿尔及利亚俚语中是屁股的意思] - 如果英语翻译听起来不自然/语无伦次,请提供文化背景(在括号内,例如 [注:xxx])。 - 解释请简洁明了(每条注释最多 10 个字)。 **解释性含义**: 对话明确表达了\[基于提供的比较结果的详细解释,清晰地解释了上下文、意图和含义]。 **严重程度评估**: * **严重程度高于**: * “对话 A 完全一致或转述”:因为\[比较结果的明确原因清楚地表明了严重程度的差异]。 * “对话 B 完全一致或转述”:因为\[比较结果的明确原因清楚地表明了严重程度的差异]。 * **严重程度低于**: * “对话 C 完全一致或转述”:因为\[比较结果的明确原因清楚地表明了严重程度的差异]。 """
        logger.debug(f"system_prompt: {system_prompt}")
        logger.debug(f"user_prompt: {xml_input}")
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": xml_input}
        ]
        max_retries = self.max_retries
        trial = 0
        current_delay = self.initial_retry_delay
        # Not reset per attempt, so a previous attempt's response may be logged.
        response = None
        while trial < max_retries:
            try:
                response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
                llm_understanding = response['choices'][0]['message']['content']
                logger.debug(f"llm response: {llm_understanding}")
                return llm_understanding
            except Exception as e:
                trial += 1
                if trial == self.max_retries:
                    logger.error(f"Failed to get data by llm api, {self.max_retries} attempts: {str(e)}")
                    raise
                logger.warning(f"Failed to get data by llm api, Request failed (attempt {trial}/{self.max_retries}). "
                               f"Retrying in {current_delay} seconds. Error: {str(e)}")
                if response is not None:
                    logger.warning(f"Response={response}")
                time.sleep(current_delay)
                current_delay *= 2

    def get_comparison_result_by_id(self, item_id: str) -> List[Dict]:
        """Return every stored comparison that involves ``item_id``.

        Without an ES index:
          1. Load the whole local cache, if it has not been loaded yet.
          2. Scan the in-memory cache, comparing ids as strings.

        With an ES index: run a ``bool``/``should`` query on
        ``item_id_a.keyword`` / ``item_id_b.keyword`` (at most 200 hits),
        retried with exponential backoff. The results are not cached.

        Raises:
            Exception: Whatever the ES client raised on the final attempt.
        """
        if self.es_index is None:
            if self.local_cache_enabled and not self._local_cache_loaded:
                self.load_data_to_cache_from_local(None, load_detail=True)
            item_id = str(item_id)
            results = []
            with self._cache_lock:
                for pair_results in self.comparison_results.values():
                    for r in pair_results:
                        if str(r.get("item_id_a")) == item_id or str(r.get("item_id_b")) == item_id:
                            results.append(r)
            return results
        trial = 0
        current_delay = self.initial_retry_delay
        while trial < self.max_retries:
            try:
                query_body = {
                    "query": {
                        "bool": {
                            "should": [
                                { "term": { "item_id_a.keyword": item_id }},
                                { "term": { "item_id_b.keyword": item_id }}
                            ],
                            "minimum_should_match": 1
                        }
                    }
                }
                result = self.client.search(index=self.es_index, body=query_body, size=200)
                compare_result = []
                for hit in result['hits']['hits']:
                    r = {
                        "judgment": hit['_source']['judgment'],
                        "raw_response": hit['_source']['raw_response'],
                        "timestamp": hit['_source']["timestamp"],
                        "item_id_a": hit['_source']['item_id_a'],
                        "item_id_b": hit['_source']['item_id_b'],
                        'ordered_ids': self.get_ordered_pair(hit['_source']['item_id_a'], hit['_source']['item_id_b']),
                        'original_ids': (hit['_source']['item_id_a'], hit['_source']['item_id_b'])
                    }
                    if self.detect_msg_violations:
                        r["violative_msg_map_str"] = hit['_source']['violative_msg_map_str']
                    compare_result.append(r)
                return compare_result
            except Exception as e:
                trial += 1
                if trial == self.max_retries:
                    logger.error(f"Failed to get data from es after {self.max_retries} attempts: {str(e)}")
                    raise
                logger.warning(f"Failed to get data from es, Request failed (attempt {trial}/{self.max_retries}). "
                               f"Retrying in {current_delay} seconds. Error: {str(e)}")
                time.sleep(current_delay)
                current_delay *= 2

    def get_comparison_count(self, pair_key: str) -> int:
        """Get the number of comparisons for a pair.

        Args:
            pair_key: Key for the pair

        Returns:
            int: Number of comparisons for the pair
        """
        return len(self.get_comparison_results(pair_key))

    def get_comparison_results(self, pair_key: str) -> List[Dict]:
        """Get comparison results for a pair.

        Args:
            pair_key: Key for the pair

        Returns:
            List[Dict]: List of comparison results for the pair (see
            ``get_compare_result_from_es``)
        """
        return self.get_compare_result_from_es(pair_key)

    def get_compare_information(self, item_ids: List[str], use_tqdm: bool = True) -> Dict[str, Any]:
        """Get comparison information between the given items.

        Only the in-memory cache is scanned. Load data first, e.g. with
        ``load_data_to_cache_from_es``. A pair is included when both of its
        ids are in ``item_ids``.

        Args:
            item_ids: List of item IDs to get comparison information for.
                Presumably strings, since they are compared with the string
                parts of each pair_key.
            use_tqdm: Show a tqdm progress bar while scanning.

        Returns:
            Dict containing:
                - comparison_results: Dict mapping pair keys to a copy of
                  their comparison results
                - comparison_counts: Dict mapping item IDs to their total
                  number of comparison results
                - pair_comparison_counts: Dict mapping pair keys to their
                  comparison counts
        """
        comparison_results = {}
        comparison_counts = {}
        pair_comparison_counts = {}
        item_id_set = set(item_ids)
        with self._cache_lock:
            for pair_key, pair_results in tqdm(self.comparison_results.items(), desc="Scan Comparing pairs", total=len(self.comparison_results), leave=False, disable=not use_tqdm):
                # NOTE(review): raises ValueError if an item id contains '_'
                # (get_pair_key joins the two ids with '_').
                id1, id2 = pair_key.split('_')
                if id1 in item_id_set and id2 in item_id_set:
                    comparison_results[pair_key] = list(pair_results)
                    comparison_counts[id1] = comparison_counts.get(id1, 0) + len(pair_results)
                    comparison_counts[id2] = comparison_counts.get(id2, 0) + len(pair_results)
                    pair_comparison_counts[pair_key] = len(pair_results)
        return {
            'comparison_results': comparison_results,
            'comparison_counts': comparison_counts,
            'pair_comparison_counts': pair_comparison_counts
        }