| """ |
| Pairwise comparison handler for Bradley-Terry ranking model. |
| """ |
|
|
| import logging |
| import time |
| import random |
| import re |
| import os |
| import json |
| import base64 |
| from pathlib import Path |
| from typing import Dict, List, Tuple, Union, Callable, Any |
| from threading import Lock |
| from concurrent.futures import ThreadPoolExecutor |
| from src.llm.base import BaseLLM |
| from lxml import etree |
| from tqdm import tqdm |
| from collections import defaultdict |
| import bytedes |
| import html |
| from functools import lru_cache |
|
|
| logger = logging.getLogger(__name__) |
|
|
def format_comments_as_xml(comments: List[Dict]) -> str:
    """Serialize a batch of comments into a pretty-printed XML document.

    Every comment becomes a ``<comment>`` element holding a ``<comment_text>``
    child. The optional video fields (``video_title``, ``video_tag``,
    ``video_description``, ``text_in_video``) that are present and neither
    ``None`` nor the empty string are grouped under a ``<video_context>``
    child, in the order they appear in the comment dictionary.

    Args:
        comments: Comment dictionaries; each must provide a ``"text"`` entry.

    Returns:
        str: The ``<comments>`` document as text.
    """
    context_fields = ('video_title', 'video_tag', 'video_description', 'text_in_video')

    logger.debug(f"Building XML input for {len(comments)} comments")
    root = etree.Element("comments")

    for entry in comments:
        node = etree.SubElement(root, "comment")
        etree.SubElement(node, "comment_text").text = entry["text"]

        # Iterate the comment itself so the emitted order follows its key order.
        present = [
            (field, raw)
            for field, raw in entry.items()
            if field in context_fields and raw is not None and raw != ''
        ]
        if not present:
            continue

        context_node = etree.SubElement(node, "video_context")
        for field, raw in present:
            etree.SubElement(context_node, field).text = str(raw)

    serialized = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{serialized}")
    return serialized
|
|
|
|
def format_convs_as_xml(convs: List[Dict]) -> str:
    """Serialize a batch of conversations into a pretty-printed XML document.

    Each conversation becomes a ``<conv>`` element with, in order,
    ``<conversation_id>``, ``<conv_text>``, ``<conv_ageinfo>`` and ``<region>``
    children, plus ``<language>`` when ``lang_fasttext`` is present and not
    ``None``. The age info is a comma-separated ``user_<alias> is <age>`` list
    built from ``alias2age_map``.

    Args:
        convs: Conversation dictionaries providing ``alias2age_map``,
            ``conversation_id``, ``conv_text`` and ``store_region``.

    Returns:
        str: The ``<convs>`` document as text.
    """
    def _add_child(parent, tag, value):
        # Create <tag> under parent and set its text in one step.
        child = etree.SubElement(parent, tag)
        child.text = value
        return child

    logger.debug(f"Building XML input for {len(convs)} convs")
    root = etree.Element("convs")

    for conv in convs:
        age_summary = ", ".join(
            [f'user_{alias} is {age}' for alias, age in conv['alias2age_map'].items()]
        )
        node = etree.SubElement(root, "conv")
        _add_child(node, "conversation_id", str(conv["conversation_id"]))
        _add_child(node, "conv_text", conv["conv_text"])
        _add_child(node, "conv_ageinfo", age_summary)
        _add_child(node, "region", conv["store_region"])
        if conv.get("lang_fasttext", None) is not None:
            _add_child(node, "language", conv["lang_fasttext"])

    serialized = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{serialized}")
    return serialized
|
|
|
|
@lru_cache(maxsize=int(os.environ.get('IMAGE_CACHE_SIZE', '512')))
def image_to_base64(image_path: str) -> str:
    """Return the base64 text of the file stored at *image_path*.

    Results are memoised per path in an LRU cache whose capacity is read once,
    when the function is defined, from the ``IMAGE_CACHE_SIZE`` environment
    variable (default ``512``).
    """
    with open(image_path, mode="rb") as stream:
        payload = stream.read()
    encoded = base64.b64encode(payload)
    return encoded.decode('utf-8')
|
|
|
|
def text_with_placeholders_to_content(
    text: str,
    image_paths: List[Union[str, None]],
    placeholder_pattern: str = r'\[IMG(\d+)\]'
) -> List[Dict]:
    """Split *text* on image placeholders into an OpenAI-style content list.

    The text is scanned for placeholders (``[IMG1]``, ``[IMG2]``, ... by
    default). Text between placeholders becomes ``{"type": "text"}`` parts
    (whitespace-only stretches are skipped) and every resolvable placeholder
    becomes an ``{"type": "image_url"}`` part carrying the image as a base64
    data URL. A placeholder whose number is out of range, or whose path entry
    is ``None``, is removed without emitting an image part.

    The media type of the data URL is guessed from the file name; when the
    guess is missing or not an image type, ``image/jpeg`` is used.

    Args:
        text: Source text containing placeholders.
        image_paths: Image paths; placeholder ``[IMGk]`` refers to
            ``image_paths[k - 1]`` (placeholders are 1-based).
        placeholder_pattern: Regex whose first group captures the number.

    Returns:
        List[Dict]: Text and image parts in document order.
    """
    # Local import keeps this block self-contained; only needed for the data URL.
    import mimetypes

    content: List[Dict] = []

    def _append_text(chunk: str) -> None:
        # Whitespace-only fragments carry no information for the model.
        if chunk.strip():
            content.append({"type": "text", "text": chunk})

    cursor = 0
    for match in re.finditer(placeholder_pattern, text):
        start, end = match.span()
        _append_text(text[cursor:start])

        img_idx = int(match.group(1)) - 1  # placeholders are 1-based
        if 0 <= img_idx < len(image_paths) and image_paths[img_idx] is not None:
            path = image_paths[img_idx]
            mime, _ = mimetypes.guess_type(path)
            if not mime or not mime.startswith("image/"):
                mime = "image/jpeg"
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime};base64,{image_to_base64(path)}"
                }
            })
        cursor = end

    _append_text(text[cursor:])
    return content
|
|
|
|
def format_convs_as_xml_image(convs: List[Dict]) -> List[Dict]:
    """Build multimodal conversation XML and return OpenAI-style content parts.

    Images are referenced through ``[IMGn]`` placeholders and the final XML is
    split into text / image parts by ``text_with_placeholders_to_content``.

    Image sources read from each conversation (all that are present are used,
    in this order):
        - single image: ``conv['img_path']``, or ``conv['img']`` when
          ``img_path`` is absent;
        - multiple images: ``conv['image_paths']``, ``conv['img_paths']``,
          ``conv['imgs']``, ``conv['images']`` (lists of strings).

    Modes (chosen automatically per conversation):
        - Embedded: when ``conv['conv_text']`` contains ``[IMGk]``, ``k`` is a
          1-based index into that conversation's own image list. Each such
          placeholder is rewritten to a global ``[IMGn]`` so that the final
          pass can insert the right image at that position. A placeholder
          whose ``k`` does not address one of the conversation's images is
          removed; leaving it in place would let the final pass read it as a
          global index and attach an image from another conversation.
        - Non-embedded: when ``conv_text`` has no placeholder it is emitted
          unchanged and no image is attached for that conversation.
          NOTE(review): an earlier description mentioned appending
          ``<image>[IMGn]</image>`` nodes in this mode; the code does not do
          that. Confirm which behaviour is intended.

    Args:
        convs: Conversation dictionaries providing ``alias2age_map``,
            ``conversation_id``, ``conv_text`` and ``store_region``.

    Returns:
        List[Dict]: Content parts (text and ``image_url`` entries).
    """
    logger.debug(f"Building multimodal XML input for {len(convs)} convs")
    placeholder_re = re.compile(r'\[IMG(\d+)\]')
    xml_convs = etree.Element("convs")

    # Images in order of first use; global placeholder [IMGn] maps to index n - 1.
    image_paths_flat: List[Union[str, None]] = []

    def _extract_image_paths(conv: Dict) -> List[str]:
        # Collect every non-empty string path from the supported keys.
        paths: List[str] = []
        single = conv.get('img_path', conv.get('img', None))
        if isinstance(single, str) and single:
            paths.append(single)
        for key in ("image_paths", "img_paths", "imgs", "images"):
            candidates = conv.get(key, None)
            if isinstance(candidates, list):
                paths.extend(p for p in candidates if isinstance(p, str) and p)
        return paths

    def _rewrite_placeholders(conv_text: str, img_paths: List[str], conv_id: Any) -> str:
        # Map this conversation's local placeholder numbers to global ones.
        local_to_global: Dict[int, int] = {}

        def _replace(m) -> str:
            local_idx = int(m.group(1)) - 1
            if not 0 <= local_idx < len(img_paths):
                logger.warning(
                    f"Dropping unresolved image placeholder {m.group(0)} "
                    f"in conversation {conv_id}"
                )
                return ''
            if local_idx not in local_to_global:
                image_paths_flat.append(img_paths[local_idx])
                local_to_global[local_idx] = len(image_paths_flat)
            return f'[IMG{local_to_global[local_idx]}]'

        return placeholder_re.sub(_replace, conv_text)

    for conv in convs:
        alias2age_map = conv['alias2age_map']
        alias2age_text = ", ".join([f'user_{k} is {v}' for k, v in alias2age_map.items()])

        xml_conv = etree.SubElement(xml_convs, "conv")
        etree.SubElement(xml_conv, "conversation_id").text = str(conv["conversation_id"])
        xml_conv_text = etree.SubElement(xml_conv, "conv_text")
        conv_text = conv["conv_text"]
        etree.SubElement(xml_conv, "conv_ageinfo").text = alias2age_text
        etree.SubElement(xml_conv, "region").text = conv["store_region"]
        if conv.get("lang_fasttext", None) is not None:
            etree.SubElement(xml_conv, "language").text = conv["lang_fasttext"]

        if isinstance(conv_text, str) and placeholder_re.search(conv_text):
            xml_conv_text.text = _rewrite_placeholders(
                conv_text, _extract_image_paths(conv), conv["conversation_id"]
            )
        else:
            xml_conv_text.text = conv_text

    xml_str = etree.tostring(xml_convs, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated multimodal XML: \n{xml_str}")
    return text_with_placeholders_to_content(xml_str, image_paths_flat)
|
|
|
|
def parse_unit_judgment(text_info: dict) -> Tuple[int, Dict[str, List[str]]]:
    """Parse the pairwise judgment and violative messages from an LLM response.

    The response must contain a ``<result>`` block with a ``<judgment>`` such
    as ``comment A > comment B``. Full-width ``：＞＜＝`` are normalised to
    ASCII before matching. An optional ``<violative_messages>`` block with
    single-letter child tags (``<A>...</A>``) is parsed into a mapping from
    tag to the comma-split list of its content.

    Args:
        text_info: dict
            - text: Response text from LLM
            - unit: unit the LLM is processing (defaults to ``"comment"``)

    Returns:
        Tuple[int, Dict[str, List[str]]]: ``(label, violative_messages)`` where
        label is 1 if Unit A > Unit B, -1 if Unit A < Unit B, 0 if equal.

    Raises:
        ValueError: If the text is missing or the judgment cannot be parsed.
    """
    text = text_info.get("text")
    unit = text_info.get("unit", "comment")
    if not isinstance(text, str):
        # A distinct message keeps callers that match on
        # "No result block found" from treating this as a parse failure.
        logger.error("LLM response text is missing or not a string")
        raise ValueError("No response text to parse")

    m_res = re.search(r'<\s*result\s*>(.*?)<\s*/\s*result\s*>',
                      text, re.IGNORECASE | re.DOTALL)
    if not m_res:
        logger.error("No result block found in LLM response")
        raise ValueError("No result block found")
    result_body = m_res.group(1)

    m_j = re.search(r'<\s*judgment\s*>(.*?)<\s*/\s*judgment\s*>',
                    result_body, re.IGNORECASE | re.DOTALL)
    if not m_j:
        logger.error("No judgment found in result block")
        raise ValueError("No judgment found")
    judgment_text = html.unescape(m_j.group(1).strip())

    # Drop an optional "Final Judgment" label in front of the comparison.
    judgment_text = re.sub(r'\bFinal\s+Judgment\b', '',
                           judgment_text, flags=re.IGNORECASE).strip()

    # Normalise full-width punctuation so "A ＞ B" parses like "A > B".
    fullwidth_to_ascii = str.maketrans({
        "\uff1a": ":",
        "\uff1e": ">",
        "\uff1c": "<",
        "\uff1d": "=",
    })
    judgment_text = judgment_text.translate(fullwidth_to_ascii)

    pattern = rf'{re.escape(unit)}\s*([ab])\s*([<>]=?|==?)\s*{re.escape(unit)}\s*([ab])'
    m_cmp = re.search(pattern, judgment_text, re.IGNORECASE)
    if not m_cmp:
        logger.error(f"Invalid judgment format: {judgment_text!r}")
        raise ValueError(f"Invalid judgment format: {judgment_text!r}")

    left, op, _right = m_cmp.groups()
    left = left.upper()

    violative_msg_dict: Dict[str, List[str]] = {}
    violative_content = re.search(r'<violative_messages>(.*?)</violative_messages>', text, re.DOTALL)
    if violative_content:
        content = violative_content.group(1).strip()
        matches = re.findall(r'<([A-Z])>(.*?)</\1>', content)
        violative_msg_dict = {tag: val.strip().strip("<>[]\"").split(",") for tag, val in matches}
    else:
        logger.warning("No violative_messages found.")

    # The label is always expressed from Unit A's point of view.
    if op == '>':
        label = 1 if left == 'A' else -1
    elif op == '<':
        label = -1 if left == 'A' else 1
    elif op in ('=', '=='):
        label = 0
    else:
        # ">=" and "<=" match the pattern but are not accepted as judgments.
        logger.error(f"Unsupported operator in judgment: {op!r}")
        raise ValueError(f"Unsupported operator: {op!r}")

    return label, violative_msg_dict
|
|
| class PairwiseComparison: |
| """ |
| Handles pairwise comparisons between items using LLM. |
| Manages comparison results, counts, and statistics. |
| """ |
|
|
def __init__(self,
             llm: BaseLLM,
             prompt_templates: Dict[str, str],
             format_items: Callable[[List[Dict]], Any],
             parse_judgment: Callable[[Dict[str, Any]], Tuple[int, Dict[str, List[str]]]],
             data_dir: str,
             es_index: str,
             es_psm: str = "byte.es.ranking_moderation_cmt.service.my",
             business: str = "comment",
             max_comparisons_per_pair: int = 3,
             max_workers: int = 4,
             max_retries: int = 8,
             initial_retry_delay: float = 2.0,
             max_backups: int = 3,
             detect_msg_violations: bool = True,
             log_gpt_io: bool = False,
             log_sample_rate: float = 0.01,
             local_cache_path: Union[str, Path, None] = None,
             local_cache_enabled: bool = False):
    """Initialize pairwise comparison handler.

    Args:
        llm: Language model for pairwise comparisons
        prompt_templates: Dictionary of prompt templates
        format_items: Function to format items into LLM input. Can return either str or List[Dict]
        parse_judgment: Function that receives ``{"text": ..., "unit": ...}``
            and returns ``(judgment, violative_messages)``
        data_dir: Directory to store state files (created if missing)
        es_index: Elasticsearch index for comparison results; ``None`` or
            the string ``'None'`` disables Elasticsearch
        es_psm: PSM used to create the Elasticsearch client
        business: ``"comment"`` (unit "comment") or ``"dm"`` / ``"dm_mm"``
            (unit "conversation"), case-insensitive
        max_comparisons_per_pair: Maximum number of comparisons for each pair
        max_workers: Maximum number of worker threads
        max_retries: Maximum number of retry attempts for failed comparisons
        initial_retry_delay: Initial delay in seconds before first retry
        max_backups: Maximum number of backup files to keep (default: 3)
        detect_msg_violations: Whether violative-message maps are stored
            with each comparison
        log_gpt_io: Whether sampled LLM inputs/outputs are appended to
            ``tmp_gpt_io.jsonl`` in ``data_dir``
        log_sample_rate: Probability of logging a comparison when
            ``log_gpt_io`` is on
        local_cache_path: JSONL file used as a local result cache; the
            string ``'None'`` counts as ``None``. Giving a path enables
            the local cache.
        local_cache_enabled: Enable the local cache; without a path the
            file ``pairwise_comparisons.jsonl`` in ``data_dir`` is used

    Raises:
        NotImplementedError: If ``business`` is not supported.
    """
    # Resolve the business first so that an unsupported value fails before
    # any directory or Elasticsearch client has been created.
    business_key = business.lower()
    if business_key == "comment":
        process_unit = "comment"
    elif business_key in ("dm", "dm_mm"):
        process_unit = "conversation"
    else:
        raise NotImplementedError(f"Ranking moderation for business {business} is not implemented!")
    self.process_unit = process_unit

    self.llm = llm
    self.prompt_templates = prompt_templates
    self.format_items = format_items
    self.parse_judgment = parse_judgment
    self.max_comparisons_per_pair = max_comparisons_per_pair
    self.max_workers = max_workers
    self.max_retries = max_retries
    self.initial_retry_delay = initial_retry_delay
    self.max_backups = max_backups
    self.detect_msg_violations = detect_msg_violations
    self.log_gpt_io = log_gpt_io
    self.log_sample_rate = log_sample_rate

    # Configuration layers may pass the literal string 'None'.
    self.es_index = None if es_index == 'None' else es_index
    if local_cache_path == 'None':
        local_cache_path = None

    self.data_dir = Path(data_dir)
    self.data_dir.mkdir(parents=True, exist_ok=True)

    self.local_cache_enabled = bool(local_cache_enabled or local_cache_path)
    self.local_cache_path = None
    if self.local_cache_enabled:
        if not local_cache_path or local_cache_path is True:
            # No usable path given: fall back to the default file.
            self.local_cache_path = self.data_dir / "pairwise_comparisons.jsonl"
        else:
            self.local_cache_path = Path(local_cache_path)
        self.local_cache_path.parent.mkdir(parents=True, exist_ok=True)
    self._local_cache_loaded = False

    # In-memory results keyed by pair key; guarded by _cache_lock.
    self._cache_lock = Lock()
    self.comparison_results = defaultdict(list)

    if self.es_index is not None:
        self.client = bytedes.make_client(psm=es_psm, cluster="data", scheme="https",
                                          verify_certs=False, use_ssl=True, ssl_show_warn=False, maxsize=50)
    else:
        self.client = None
| |
def get_state_file_path(self, filename: str) -> Path:
    """Resolve a state file name against the handler's data directory.

    Args:
        filename: Name of the state file

    Returns:
        Path: ``data_dir`` joined with ``filename``
    """
    state_path = self.data_dir.joinpath(filename)
    return state_path
|
|
def get_pair_key(self, item_id1: str, item_id2: str) -> str:
    """Build the order-independent key identifying a pair of items.

    Both IDs are converted to strings and ordered lexicographically, so the
    same key results whichever item is given first.

    Args:
        item_id1: First item ID
        item_id2: Second item ID

    Returns:
        str: ``"<smaller>_<larger>"``
    """
    smaller, larger = sorted((str(item_id1), str(item_id2)))
    return f"{smaller}_{larger}"
|
|
def get_ordered_pair(self, item_id1: str, item_id2: str) -> Tuple[str, str]:
    """Get ordered pair of item IDs.

    IDs are converted to strings before ordering, the same way
    ``get_pair_key`` treats them, so the returned order always agrees with
    the pair key and mixed ID types cannot fail on comparison.

    Args:
        item_id1: First item ID
        item_id2: Second item ID

    Returns:
        Tuple[str, str]: Ordered pair of item IDs (min, max)
    """
    first, second = str(item_id1), str(item_id2)
    return min(first, second), max(first, second)
|
|
def get_compare_result_from_es(self, pair_key: str) -> List[Dict]:
    """Return the stored comparison results for one pair.

    Lookup order: the in-memory cache; then, without an Elasticsearch
    index, the local JSONL cache (loaded once); otherwise an Elasticsearch
    term query on ``pair_key`` with exponential-backoff retries. Results
    fetched from Elasticsearch are stored in the in-memory cache.

    Args:
        pair_key: Key as produced by ``get_pair_key``

    Returns:
        List[Dict]: Comparison records for the pair (possibly empty). The
        list is the cached object itself, not a copy.

    Raises:
        Exception: The last error once ``max_retries`` Elasticsearch
            attempts have failed.
    """
    with self._cache_lock:
        if pair_key in self.comparison_results:
            return self.comparison_results[pair_key]

    if self.es_index is None:
        if self.local_cache_enabled and not self._local_cache_loaded:
            # Load the whole file: the loader marks the cache as loaded, so
            # a load restricted to this pair's IDs would hide every other
            # pair's stored results from later lookups.
            self.load_data_to_cache_from_local(None, load_detail=True)
        with self._cache_lock:
            if pair_key in self.comparison_results:
                return self.comparison_results[pair_key]
        return []

    query_body = {
        "query": {
            "term": {
                "pair_key": {
                    "value": pair_key
                }
            }
        }
    }
    trial = 0
    current_delay = self.initial_retry_delay

    while trial < self.max_retries:
        try:
            result = self.client.search(index=self.es_index, body=query_body, size=200)
            compare_result = []
            for hit in result['hits']['hits']:
                src = hit['_source']
                r = {
                    "judgment": src['judgment'],
                    "raw_response": src['raw_response'],
                    "timestamp": src["timestamp"],
                    "item_id_a": src['item_id_a'],
                    "item_id_b": src['item_id_b'],
                    'ordered_ids': self.get_ordered_pair(src['item_id_a'], src['item_id_b']),
                    'original_ids': (src['item_id_a'], src['item_id_b'])
                }
                if self.detect_msg_violations:
                    r["violative_msg_map_str"] = src['violative_msg_map_str']
                compare_result.append(r)

            with self._cache_lock:
                self.comparison_results[pair_key] = compare_result
            return compare_result

        except Exception as e:
            trial += 1
            if trial == self.max_retries:
                logger.error(f"Failed to get data from es after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to get data from es, Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            time.sleep(current_delay)
            current_delay *= 2

    # Only reached when max_retries allows no attempt at all.
    return []
|
|
def load_data_to_cache_from_es(self, item_ids: List[str], load_detail=True):
    """Bulk-load comparison results that involve any of the given items.

    Without an Elasticsearch index the call is delegated to
    ``load_data_to_cache_from_local``. Otherwise the IDs are queried in
    batches of 500 with a scrolling search matching ``item_id_a`` or
    ``item_id_b``. Each record is added to the in-memory cache and to the
    returned mapping unless a record with the same timestamp is already
    stored under the same pair key.

    Args:
        item_ids: Item IDs whose comparisons should be loaded
        load_detail: Also fetch ``raw_response`` when True

    Returns:
        dict: Pair key -> list of records seen during this call
    """
    if self.es_index is None:
        return self.load_data_to_cache_from_local(item_ids, load_detail=load_detail)

    scroll_ttl = "2m"
    ids_per_query = 500

    fields = ["pair_key", "judgment", "timestamp", "item_id_a", "item_id_b"]
    if self.detect_msg_violations:
        fields.append("violative_msg_map_str")
    if load_detail:
        fields.append("raw_response")

    def _append_if_new(bucket_map, key, record):
        # Records are de-duplicated per pair key by their timestamp.
        bucket = bucket_map.setdefault(key, [])
        if not any(record['timestamp'] == known['timestamp'] for known in bucket):
            bucket.append(record)

    loaded = {}
    for offset in range(0, len(item_ids), ids_per_query):
        id_batch = item_ids[offset:offset + ids_per_query]
        query = {
            "_source": fields,
            "query": {
                "bool": {
                    "should": [
                        {"terms": {"item_id_a.keyword": id_batch}},
                        {"terms": {"item_id_b.keyword": id_batch}}
                    ],
                    "minimum_should_match": 1
                }
            }
        }
        page = self.client.search(index=self.es_index, body=query, scroll=scroll_ttl, size=5000)
        while True:
            scroll_id = page["_scroll_id"]
            hits = page["hits"]["hits"]
            if not hits:
                break
            for hit in hits:
                src = hit['_source']
                pair_key = src['pair_key']
                record = {
                    "judgment": src['judgment'],
                    "timestamp": src["timestamp"],
                    "item_id_a": src['item_id_a'],
                    "item_id_b": src['item_id_b'],
                    'ordered_ids': self.get_ordered_pair(src['item_id_a'], src['item_id_b']),
                    'original_ids': (src['item_id_a'], src['item_id_b'])
                }
                if self.detect_msg_violations:
                    record["violative_msg_map_str"] = src['violative_msg_map_str']
                if 'raw_response' in src:
                    record['raw_response'] = src['raw_response']

                with self._cache_lock:
                    _append_if_new(self.comparison_results, pair_key, record)
                _append_if_new(loaded, pair_key, record)
            page = self.client.scroll(scroll_id=scroll_id, scroll=scroll_ttl)
    return loaded
|
|
def load_data_to_cache_from_local(self, item_ids: Union[List[str], None], load_detail=True):
    """Load comparison results from the local JSONL cache file.

    Each non-empty line is expected to hold one JSON object with at least
    ``item_id_a`` and ``item_id_b``. Lines that are not valid JSON, are not
    JSON objects, or lack those keys are skipped. A record is added to the
    in-memory cache and to the returned mapping unless a record with the
    same timestamp is already stored under the same pair key.

    The cache is marked as loaded whenever the file is processed or found
    missing, also when ``item_ids`` restricts what is read.

    Args:
        item_ids: Keep only records where either item ID is in this
            collection (compared as strings); ``None`` keeps every record
        load_detail: Copy ``raw_response`` into the records when present

    Returns:
        dict: Pair key -> list of records read during this call; empty when
        the local cache is disabled or the file does not exist
    """
    if not self.local_cache_enabled or self.local_cache_path is None:
        return {}
    if not self.local_cache_path.exists():
        self._local_cache_loaded = True
        return {}

    wanted_ids = None if item_ids is None else {str(x) for x in item_ids}

    def _append_if_new(bucket_map, key, record):
        # Records are de-duplicated per pair key by their timestamp.
        bucket = bucket_map.setdefault(key, [])
        if not any(record["timestamp"] == known.get("timestamp") for known in bucket):
            bucket.append(record)

    loaded = {}
    with self._cache_lock:
        self._local_cache_loaded = True

    with open(self.local_cache_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                doc = json.loads(line)
            except (ValueError, RecursionError):
                # Best effort: a corrupted line must not abort the load.
                continue
            if not isinstance(doc, dict):
                # e.g. "null" or a bare number: valid JSON but not a record.
                continue
            if "item_id_a" not in doc or "item_id_b" not in doc:
                continue

            item_id_a = str(doc["item_id_a"])
            item_id_b = str(doc["item_id_b"])
            if wanted_ids is not None and item_id_a not in wanted_ids and item_id_b not in wanted_ids:
                continue

            pair_key = doc.get("pair_key", self.get_pair_key(item_id_a, item_id_b))
            record = {
                "judgment": doc.get("judgment"),
                "timestamp": doc.get("timestamp"),
                "item_id_a": item_id_a,
                "item_id_b": item_id_b,
                "ordered_ids": self.get_ordered_pair(item_id_a, item_id_b),
                "original_ids": (item_id_a, item_id_b)
            }
            if self.detect_msg_violations and "violative_msg_map_str" in doc:
                record["violative_msg_map_str"] = doc["violative_msg_map_str"]
            if load_detail and "raw_response" in doc:
                record["raw_response"] = doc["raw_response"]

            with self._cache_lock:
                _append_if_new(self.comparison_results, pair_key, record)
            _append_if_new(loaded, pair_key, record)
    return loaded
|
|
def write_compare_result_to_es(self, pair_key: str, comparison_result):
    """Persist one comparison result.

    The result is appended to the in-memory cache and, when the local cache
    is enabled, to the local JSONL file. Both happen exactly once. Only the
    Elasticsearch indexing call is retried (exponential backoff), so a
    failing request can no longer duplicate the record in memory or in the
    local file.

    Args:
        pair_key: Key as produced by ``get_pair_key``
        comparison_result: Record with ``judgment``, ``raw_response``,
            ``timestamp``, ``original_ids`` and, when message-violation
            detection is on, ``violative_msg_map_str``

    Returns:
        The Elasticsearch client's response, or ``None`` when no index is
        configured.

    Raises:
        KeyError: If ``comparison_result`` lacks a required field.
        Exception: The last error once ``max_retries`` indexing attempts
            have failed.
    """
    doc = {
        'pair_key': pair_key,
        'judgment': comparison_result['judgment'],
        'raw_response': comparison_result['raw_response'],
        'timestamp': comparison_result['timestamp'],
        'item_id_a': comparison_result['original_ids'][0],
        'item_id_b': comparison_result['original_ids'][1],
    }
    if self.detect_msg_violations:
        doc['violative_msg_map_str'] = comparison_result['violative_msg_map_str']

    write_local = self.local_cache_enabled and self.local_cache_path is not None
    # Serialise before touching any state so a failure leaves nothing half-written.
    local_line = json.dumps(doc, ensure_ascii=False) + "\n" if write_local else None

    with self._cache_lock:
        self.comparison_results.setdefault(pair_key, []).append(comparison_result)
        if write_local:
            # Appending under the lock keeps lines from concurrent workers intact.
            with open(self.local_cache_path, "a", encoding="utf-8") as f:
                f.write(local_line)

    if self.es_index is None:
        return None

    trial = 0
    current_delay = self.initial_retry_delay
    while trial < self.max_retries:
        try:
            return self.client.index(index=self.es_index, body=doc)
        except Exception as e:
            trial += 1
            if trial == self.max_retries:
                logger.error(f"Failed to write data to es after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to write data to es, Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            time.sleep(current_delay)
            current_delay *= 2
|
|
|
|
def compare_single_pair(self, pair: Tuple[Dict, Dict], random_swap: bool = True, use_cache: bool = True) -> Union[Dict[str, Any], None]:
    """Compare a single pair of items and store the result.

    When caching is used and the pair already has at least
    ``max_comparisons_per_pair`` stored results, one of them is returned at
    random and the LLM is not called. Otherwise the LLM is queried, the
    response is parsed with ``parse_judgment`` and the record is persisted
    through ``write_compare_result_to_es``. Failures are retried with
    exponential backoff.

    Args:
        pair: Tuple of (item1, item2); each item provides ``item_id``
        random_swap: Whether to randomly swap the order of items.
            If False, keep the original order.
            Default is True for backward compatibility.
        use_cache: Whether to use cached comparison results if available.
            If False, always perform new comparison.
            Default is True for backward compatibility.

    Returns:
        The comparison record (``judgment``, ``raw_response``,
        ``timestamp``, ``ordered_ids``, ``original_ids``, ``item_id_a``,
        ``item_id_b`` and optionally ``violative_msg_map_str``), or
        ``None`` when the second failed attempt was a response without a
        result block.

    Raises:
        Exception: The last error once ``max_retries`` attempts have failed.
    """
    item1, item2 = pair
    if random_swap and random.random() < 0.5:
        item1, item2 = item2, item1

    item_id1 = str(item1['item_id'])
    item_id2 = str(item2['item_id'])
    pair_key = self.get_pair_key(item_id1, item_id2)
    ordered_id1, ordered_id2 = self.get_ordered_pair(item_id1, item_id2)

    if use_cache:
        compare_result = self.get_compare_result_from_es(pair_key)
        if len(compare_result) >= self.max_comparisons_per_pair:
            return random.choice(compare_result)

    formatted_input = self.format_items([item1, item2])
    messages = [
        {"role": "system", "content": self.prompt_templates['system_prompt']},
        {"role": "user", "content": formatted_input}
    ]

    trial = 0
    current_delay = self.initial_retry_delay

    while trial < self.max_retries:
        response = None
        try:
            response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
            result = response['choices'][0]['message']['content']
            result_info = {
                "text": result,
                "unit": self.process_unit
            }
            judgment, violative_msg_dict = self.parse_judgment(result_info)

            comparison_result = {
                'judgment': judgment,
                'raw_response': result,
                'timestamp': time.time(),
                'ordered_ids': (ordered_id1, ordered_id2),
                'original_ids': (item_id1, item_id2),
                "item_id_a": str(item_id1),
                "item_id_b": str(item_id2)
            }
            if self.detect_msg_violations:
                comparison_result['violative_msg_map_str'] = json.dumps(violative_msg_dict)

            if self.detect_msg_violations and self.log_gpt_io and random.random() < self.log_sample_rate:
                if isinstance(formatted_input, str):
                    formatted_input_for_regex = formatted_input
                else:
                    # Multimodal input: only the text parts carry the XML.
                    formatted_input_for_regex = "".join(
                        c.get("text", "")
                        for c in formatted_input
                        if isinstance(c, dict) and c.get("type") == "text"
                    )
                _convs = re.findall(r'<conv_text>(.*?)</conv_text>', formatted_input_for_regex.strip(), re.DOTALL)
                _convs = [html.unescape(c).strip() for c in _convs]
                _age_info = re.findall(r'<conv_ageinfo>(.*?)</conv_ageinfo>', formatted_input_for_regex, re.DOTALL)
                _data = {
                    'timestamp': time.time(),
                    "conversations": [{"conv_text": c, "conv_ageinfo": a} for c, a in zip(_convs, _age_info)],
                    "judgment": judgment,
                    "violative_msg_map_str": violative_msg_dict
                }
                with self._cache_lock:
                    # ensure_ascii=False emits raw non-ASCII text, so the
                    # file must be opened as UTF-8 regardless of locale.
                    with open(self.data_dir / "tmp_gpt_io.jsonl", "a", encoding="utf-8") as f:
                        f.write(json.dumps(_data, ensure_ascii=False) + "\n")

            self.write_compare_result_to_es(pair_key, comparison_result)
            return comparison_result
        except Exception as e:
            trial += 1
            # A response without a result block is given up on after the
            # second failed attempt instead of using all retries.
            if str(e) == "No result block found" and trial == 2:
                logger.error("parse error, only try 2 times")
                return None
            if trial == self.max_retries:
                logger.error(f"Failed to compare pair after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            if response is not None:
                logger.warning(f"Response: {response}")
            time.sleep(current_delay)
            current_delay *= 2
|
|
def compare_pairs(self, pairs: List[Tuple[Dict, Dict]], random_swap: bool = True, use_cache: bool = True, use_tqdm: bool = True) -> List[Tuple[Dict, Dict]]:
    """Run ``compare_single_pair`` for many pairs on a thread pool.

    Args:
        pairs: List of (item1, item2) tuples to compare
        random_swap: Forwarded to ``compare_single_pair``; when False the
            original order of each pair is kept.
        use_cache: Forwarded to ``compare_single_pair``; when False a new
            comparison is always performed.
        use_tqdm: Show a progress bar while collecting the results.

    Returns:
        The input pairs whose comparison returned a truthy result, in
        submission order. Pairs whose comparison raised are logged and
        left out.
    """
    successful_pairs = []
    with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
        submitted = [
            (
                executor.submit(self.compare_single_pair, pair, random_swap=random_swap, use_cache=use_cache),
                pair,
            )
            for pair in pairs
        ]

        progress = tqdm(submitted, desc="Comparing pairs", total=len(submitted),
                        leave=False, disable=not use_tqdm)
        for future, pair in progress:
            try:
                outcome = future.result()
            except Exception as e:
                logger.error(f"Error comparing pair: {str(e)}")
                continue
            if outcome:
                successful_pairs.append(pair)

    return successful_pairs
|
|
def understand_by_pairs(self, item: Dict, compare_results: Dict[str, List[Dict]]):
    """Ask the LLM to explain a target item from its pairwise comparisons.

    Up to two comparisons from each of two groups (see the loop below) are
    picked at random. Their raw LLM responses are sent, together with the
    XML of the target item, to the LLM with a fixed analysis prompt.

    Args:
        item: Target item; must provide ``item_id`` and be accepted by
            ``format_items``.
        compare_results: Mapping of pair key to the list of comparison
            records of that pair. Only the first record of each list is
            used; each needs ``item_id_a``, ``item_id_b``, ``judgment`` and
            (for the selected examples) ``raw_response``.

    Returns:
        The message content of the LLM response.

    Raises:
        Exception: The last error once ``max_retries`` LLM attempts have
            failed.
    """
    # NOTE(review): both names are already imported at module level.
    from lxml import etree
    import random
    rs_high, rs_low = [], []
    target_id = str(item["item_id"])
    for pair_key, r in compare_results.items():
        # Only the first stored comparison of the pair is considered.
        # NOTE(review): an empty list would raise IndexError here.
        r = r[0]
        a, b, j = r["item_id_a"], r["item_id_b"], r["judgment"]
        # rs_low: target is A with judgment >= 0, or target is B with
        # judgment < 0. rs_high: the two remaining combinations.
        # Presumably rs_low holds counterparts ranked at or below the
        # target and rs_high those ranked above it (ties depend on which
        # side the target is on) - confirm against the judgment convention.
        if target_id == a and j >= 0:
            rs_low.append(r)
        elif target_id == b and j < 0:
            rs_low.append(r)
        elif target_id == a and j < 0:
            rs_high.append(r)
        elif target_id == b and j >= 0:
            rs_high.append(r)
    random.shuffle(rs_high)
    random.shuffle(rs_low)
    # At most two random examples from each group.
    examples = rs_high[:2] + rs_low[:2]

    root = etree.Element("analysis_input")

    target_block = etree.SubElement(root, "target_conversation")
    # format_items must return an XML string here; a multimodal content
    # list has no .encode().
    xml_target = etree.fromstring(
        self.format_items([item]).encode("utf-8")
    )
    target_block.append(xml_target)

    comps_block = etree.SubElement(root, "comparisons")
    for ex in examples:
        # Each example contributes only its raw LLM response, wrapped in
        # CDATA so embedded markup is passed through unescaped.
        txt = etree.SubElement(comps_block, "raw_response")
        txt.text = etree.CDATA(ex["raw_response"])

    xml_input = etree.tostring(root, pretty_print=True,
                               encoding="utf-8").decode("utf-8")

    # NOTE(review): the literal below contains "\[" sequences in a non-raw
    # string; Python keeps the backslash but newer versions warn about it.
    system_prompt = """You are an expert in content safety and moderation.

You will receive a set of pairwise comparison results between a **target conversation** and multiple other conversations. Your primary goal is to clearly inform users about the **exact meaning and severity** of the target conversation based solely on these comparisons. Avoid any external assumptions or information.

When crafting your analysis, follow these guidelines:

* Clearly interpret the **exact meaning** of the target conversation, thoroughly and explicitly derived from the provided comparisons. Provide sufficient details so users clearly grasp the context, intention, and potential implications of the conversation.
* Identify explicitly which conversations the target conversation is **more severe than** and which ones it is **less severe than**. Clearly quote or paraphrase each comparison conversation, and provide explicit reasons derived from comparison data.
* Ensure all explanations rely **only** on the provided comparison data—**do not speculate or infer beyond what is explicitly given**.
* Structure your response clearly and flexibly so users immediately understand:

* The precise and detailed meaning of the target conversation.
* The explicit reasons behind the conversation's classification as severe or non-severe.
* Specific examples of comparisons supporting your severity assessments, clearly quoting or paraphrasing each messages in the conversation.

Your output should be in **Markdown format**, providing clear bilingual summaries in **English and Chinese** to ensure consistency for bilingual users.

*This detailed analysis aims to provide users with a clear and explicit understanding of the exact meaning of the conversation, its severity, and concrete evidence from comparative conversation or messages.* make sure your response can achieve this


---

**Detailed Output Example:**

## 📝 Conversation Analysis

### 🔹 English

**Direct translation:**
Translate this conversation into English. If the conversation is in English, simply copy the conversation as the output.
- Maintain all special tags message ids e.g. <m1> and user_ids e.g. user_0
- Do not remove existing new line characters
- Explain any cultural references / slang that could be unsafe or harmful in line e.g. user_0: Show me your "taye" [Note: "taye" means butt/ass in Algerian slang]
- Provide cultural context (within brackets inline - e.g. [Note: xxx]) when the translation in English sounds unnatural/incoherent
- Keep the explanations brief and concise (Keep a single note to a maximum of 10 words).

**Interpreted Meaning:**
The conversation explicitly expresses [detailed interpretation explicitly based on provided comparisons, clearly explaining context, intent, and implications].

**Severity Assessment:**

* **More Severe Than:**

* "Exact or paraphrased conversation A": because \[explicit reason from comparison clearly indicating severity difference].
* "Exact or paraphrased conversation B": because \[explicit reason from comparison clearly indicating severity difference].

* **Less Severe Than:**

* "Exact or paraphrased conversation C": because \[explicit reason from comparison clearly indicating severity difference].

---

### 🔸 中文

**直接翻译:**
将此对话翻译成中文。如果对话是中文,只需复制对话作为输出即可。
- 保留所有特殊标签消息 ID,例如 <m1> 和用户 ID,例如 user_0
- 不要删除现有的换行符
- 解释任何可能不安全或有害的文化指涉/俚语,例如 user_0:给我看看你的“taye”[注:“taye”在阿尔及利亚俚语中是屁股的意思]
- 如果英语翻译听起来不自然/语无伦次,请提供文化背景(在括号内,例如 [注:xxx])。
- 解释请简洁明了(每条注释最多 10 个字)。

**解释性含义**:
对话明确表达了\[基于提供的比较结果的详细解释,清晰地解释了上下文、意图和含义]。

**严重程度评估**:

* **严重程度高于**:

* “对话 A 完全一致或转述”:因为\[比较结果的明确原因清楚地表明了严重程度的差异]。
* “对话 B 完全一致或转述”:因为\[比较结果的明确原因清楚地表明了严重程度的差异]。

* **严重程度低于**:

* “对话 C 完全一致或转述”:因为\[比较结果的明确原因清楚地表明了严重程度的差异]。

"""
    logger.debug(f"system_prompt: {system_prompt}")
    logger.debug(f"user_prompt: {xml_input}")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": xml_input}
    ]
    max_retries = self.max_retries
    trial = 0
    current_delay = self.initial_retry_delay
    response = None
    # Retry the LLM call with exponential backoff; the last error is
    # re-raised once the attempts are used up.
    while trial < max_retries:
        try:
            response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
            llm_understanding = response['choices'][0]['message']['content']
            logger.debug(f"llm response: {llm_understanding}")
            return llm_understanding
        except Exception as e:
            trial += 1
            if trial == self.max_retries:
                logger.error(f"Failed to get data by llm api, {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to get data by llm api, Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            if response is not None:
                logger.warning(f"Response={response}")
            time.sleep(current_delay)
            current_delay *= 2
|
|
def get_comparison_result_by_id(self, item_id: str) -> List[Dict]:
    """Return every stored comparison record in which ``item_id`` takes part.

    Two sources are supported:

    * ``self.es_index is None``: the in-memory ``self.comparison_results``
      cache is scanned under ``self._cache_lock``. If local caching is
      enabled and the cache has not been loaded yet, it is loaded first via
      ``self.load_data_to_cache_from_local``. IDs are compared as strings.
    * otherwise: ``self.client.search`` is called on ``self.es_index`` for
      documents whose ``item_id_a`` or ``item_id_b`` equals ``item_id``
      (at most 200 hits are requested). Failures are retried with a delay
      that starts at ``self.initial_retry_delay`` and doubles each time.

    Args:
        item_id: ID of the item whose comparisons are wanted.

    Returns:
        List[Dict]: Matching comparison records. Records built from search
        hits contain ``judgment``, ``raw_response``, ``timestamp``,
        ``item_id_a``, ``item_id_b``, ``ordered_ids`` and ``original_ids``,
        plus ``violative_msg_map_str`` when ``self.detect_msg_violations``
        is truthy.

    Raises:
        Exception: The error from the final search attempt is re-raised
            once ``self.max_retries`` attempts have failed.
    """
    if self.es_index is None:
        # No index configured: answer from the local cache.
        cache_needs_loading = self.local_cache_enabled and not self._local_cache_loaded
        if cache_needs_loading:
            self.load_data_to_cache_from_local(None, load_detail=True)
        wanted_id = str(item_id)
        with self._cache_lock:
            return [
                record
                for records in self.comparison_results.values()
                for record in records
                if str(record.get("item_id_a")) == wanted_id
                or str(record.get("item_id_b")) == wanted_id
            ]

    copied_fields = ("judgment", "raw_response", "timestamp", "item_id_a", "item_id_b")
    attempts = 0
    delay = self.initial_retry_delay

    # NOTE: if self.max_retries is not positive this loop never runs and the
    # method falls through, implicitly returning None.
    while attempts < self.max_retries:
        try:
            # Match documents where the item appears on either side of the pair.
            either_side = [
                {"term": {field: item_id}}
                for field in ("item_id_a.keyword", "item_id_b.keyword")
            ]
            query_body = {
                "query": {
                    "bool": {
                        "should": either_side,
                        "minimum_should_match": 1,
                    }
                }
            }
            response = self.client.search(index=self.es_index, body=query_body, size=200)

            records = []
            for hit in response['hits']['hits']:
                source = hit['_source']
                record = {field: source[field] for field in copied_fields}
                record['ordered_ids'] = self.get_ordered_pair(source['item_id_a'], source['item_id_b'])
                record['original_ids'] = (source['item_id_a'], source['item_id_b'])
                if self.detect_msg_violations:
                    record["violative_msg_map_str"] = source['violative_msg_map_str']
                records.append(record)
            return records

        except Exception as e:
            attempts += 1
            if attempts == self.max_retries:
                logger.error(f"Failed to get data from es after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to get data from es, Request failed (attempt {attempts}/{self.max_retries}). "
                           f"Retrying in {delay} seconds. Error: {str(e)}")
            time.sleep(delay)
            # Exponential backoff between attempts.
            delay *= 2
| |
| |
def get_comparison_count(self, pair_key: str) -> int:
    """Count the comparison results recorded for one pair.

    The count is the length of whatever ``self.get_comparison_results``
    returns for ``pair_key``.

    Args:
        pair_key: Key identifying the pair.

    Returns:
        int: How many comparison results exist for the pair.
    """
    results_for_pair = self.get_comparison_results(pair_key)
    return len(results_for_pair)
|
|
def get_comparison_results(self, pair_key: str) -> List[Dict]:
    """Fetch the comparison results recorded for one pair.

    This is a thin wrapper: the lookup itself is performed by
    ``self.get_compare_result_from_es``, whose return value is passed
    through unchanged.

    Args:
        pair_key: Key identifying the pair.

    Returns:
        List[Dict]: The comparison results for the pair.
    """
    pair_results = self.get_compare_result_from_es(pair_key)
    return pair_results
|
|
def get_compare_information(self, item_ids: List[str], use_tqdm: bool = True) -> Dict[str, Any]:
    """Summarise the cached comparisons that lie entirely within ``item_ids``.

    Scans ``self.comparison_results`` while holding ``self._cache_lock`` and
    keeps only the pairs whose two IDs are both in ``item_ids``.

    Pair keys have the form ``"<id1>_<id2>"``. Because an ID may itself
    contain an underscore, each underscore in the key is tried as the
    separator and the first split whose two halves are both requested IDs is
    used. Keys that cannot be split that way (including keys without any
    underscore) are skipped instead of aborting the whole scan.

    Args:
        item_ids: List of item IDs to get comparison information for.
        use_tqdm: Whether to show a progress bar while scanning the cache.

    Returns:
        Dict containing:
            - comparison_results: Dict mapping pair keys to a copy of their
              list of comparison results
            - comparison_counts: Dict mapping item IDs to the number of
              comparison results they have against other requested items
            - pair_comparison_counts: Dict mapping pair keys to their
              comparison counts
    """
    comparison_results = {}
    comparison_counts = {}
    pair_comparison_counts = {}
    item_id_set = set(item_ids)

    def split_pair_key(pair_key):
        # Return (id1, id2) if some underscore splits the key into two
        # requested IDs, otherwise None.
        search_from = 0
        while True:
            sep = pair_key.find('_', search_from)
            if sep == -1:
                return None
            left, right = pair_key[:sep], pair_key[sep + 1:]
            if left in item_id_set and right in item_id_set:
                return left, right
            search_from = sep + 1

    with self._cache_lock:
        for pair_key, pair_results in tqdm(self.comparison_results.items(), desc="Scan Comparing pairs", total=len(self.comparison_results), leave=False, disable=not use_tqdm):
            ids = split_pair_key(pair_key)
            if ids is None:
                continue
            id1, id2 = ids
            n_results = len(pair_results)
            # Copy the list so callers cannot mutate the cached one.
            comparison_results[pair_key] = list(pair_results)
            comparison_counts[id1] = comparison_counts.get(id1, 0) + n_results
            comparison_counts[id2] = comparison_counts.get(id2, 0) + n_results
            pair_comparison_counts[pair_key] = n_results

    return {
        'comparison_results': comparison_results,
        'comparison_counts': comparison_counts,
        'pair_comparison_counts': pair_comparison_counts
    }
|
|