"""
Pairwise comparison handler for Bradley-Terry ranking model.
"""
import logging
import time
import random
import re
import os
import json
import base64
from pathlib import Path
from typing import Dict, List, Tuple, Union, Callable, Any
from threading import Lock
from concurrent.futures import ThreadPoolExecutor
from src.llm.base import BaseLLM
from lxml import etree
from tqdm import tqdm
from collections import defaultdict
import bytedes
import html
from functools import lru_cache
# Module-level logger used by the formatting helpers, the judgment parser and PairwiseComparison.
logger = logging.getLogger(__name__)
def format_comments_as_xml(comments: List[Dict]) -> str:
    """Serialise comments (plus optional video context) into a pretty-printed XML string.

    Each comment becomes ``<comment><comment_text>...</comment_text></comment>``.
    When any of the keys video_title, video_tag, video_description or
    text_in_video holds a value that is neither None nor '', a
    ``<video_context>`` child is added with one element per such key
    (values converted with ``str``), in the comment dict's key order.

    Args:
        comments: List of comment dicts; each must have a "text" key.

    Returns:
        str: XML string representation of the comments.
    """
    logger.debug(f"Building XML input for {len(comments)} comments")
    context_keys = ('video_title', 'video_tag', 'video_description', 'text_in_video')
    root = etree.Element("comments")
    for entry in comments:
        node = etree.SubElement(root, "comment")
        etree.SubElement(node, "comment_text").text = entry["text"]
        context_items = [
            (key, val) for key, val in entry.items()
            if key in context_keys and val is not None and val != ''
        ]
        if context_items:
            context_node = etree.SubElement(node, "video_context")
            for key, val in context_items:
                etree.SubElement(context_node, key).text = str(val)
    serialized = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{serialized}")
    return serialized
def format_convs_as_xml(convs: List[Dict]) -> str:
    """Serialise conversations into a pretty-printed ``<convs>`` XML string.

    Per conversation the elements conversation_id, conv_text, conv_ageinfo
    (built from ``alias2age_map`` as "user_<alias> is <age>, ...") and region
    are written; a language element is added only when ``lang_fasttext`` is
    present and not None.

    Args:
        convs: List of conversation dicts with keys alias2age_map,
            conversation_id, conv_text, store_region and optionally lang_fasttext.

    Returns:
        str: XML string representation of the conversations.
    """
    logger.debug(f"Building XML input for {len(convs)} convs")
    root = etree.Element("convs")
    for record in convs:
        age_map = record['alias2age_map']
        age_line = ", ".join(f'user_{alias} is {age}' for alias, age in age_map.items())
        conv_node = etree.SubElement(root, "conv")
        etree.SubElement(conv_node, "conversation_id").text = str(record["conversation_id"])
        etree.SubElement(conv_node, "conv_text").text = record["conv_text"]
        etree.SubElement(conv_node, "conv_ageinfo").text = age_line
        etree.SubElement(conv_node, "region").text = record["store_region"]
        language = record.get("lang_fasttext", None)
        if language is not None:
            etree.SubElement(conv_node, "language").text = language
    serialized = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{serialized}")
    return serialized
@lru_cache(maxsize=int(os.environ.get('IMAGE_CACHE_SIZE', '512')))
def image_to_base64(image_path: str) -> str:
    """Return the bytes of the file at ``image_path`` as a base64 text string.

    Results are memoised per path by an LRU cache whose size is read from the
    IMAGE_CACHE_SIZE environment variable (default 512) once, at import time;
    a file that changes on disk after being cached is not re-read.
    """
    with open(image_path, "rb") as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload)
    return encoded.decode('utf-8')
def text_with_placeholders_to_content(
    text: str,
    image_paths: List[Union[str, None]],
    placeholder_pattern: str = r'\[IMG(\d+)\]'
) -> List[Dict]:
    """Split text containing image placeholders into an OpenAI-style content list.

    Args:
        text: Source text containing placeholders such as [IMG1], [IMG2].
        image_paths: Image paths in placeholder order; placeholder numbers are
            1-based, so [IMG1] refers to image_paths[0]. None entries are skipped.
        placeholder_pattern: Regex for a placeholder; group 1 must capture the number.

    Returns:
        List of {"type": "text", ...} and {"type": "image_url", ...} parts in
        text order. Whitespace-only text segments are dropped, and placeholders
        whose index is out of range or whose path is None are removed.
    """
    parts: List[Dict] = []

    def _push_text(chunk: str) -> None:
        # Only keep segments with visible content.
        if chunk.strip():
            parts.append({"type": "text", "text": chunk})

    cursor = 0
    for found in re.finditer(placeholder_pattern, text):
        _push_text(text[cursor:found.start()])
        slot = int(found.group(1)) - 1  # placeholders are 1-based
        path = image_paths[slot] if 0 <= slot < len(image_paths) else None
        if path is not None:
            parts.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{image_to_base64(path)}"
                }
            })
        cursor = found.end()
    _push_text(text[cursor:])
    return parts
def format_convs_as_xml_image(convs: List[Dict]) -> List[Dict]:
    """Build a multimodal (text + image) OpenAI-compatible content list for conversations.

    Each conversation is serialised to XML (conversation_id, conv_text,
    conv_ageinfo, region and optional language). Images are referenced inside
    the XML with ``[IMGn]`` placeholders, which ``text_with_placeholders_to_content``
    then replaces with image parts.

    Image sources per conv (compatible with existing data):
      - single image: conv['img_path'] or conv['img']
      - multiple images: conv['image_paths'] / conv['img_paths'] / conv['imgs'] /
        conv['images'] (list[str])

    Placeholder handling:
      - Embedded: if conv['conv_text'] contains ``[IMGk]``, k is a 1-based index
        into *that conv's* image list. Each resolvable placeholder is rewritten
        to a global ``[IMGn]`` (shared numbering across all convs), and its
        image is appended once to the global list.
      - Placeholders that cannot be resolved for their own conv (index out of
        range) are removed. Left in place, they would be read as global indices
        and could attach another conversation's image.
      - Non-embedded: if conv_text has no placeholders, the text is kept as-is
        and the conv's images are currently NOT attached (appending separate
        image nodes is disabled).

    Args:
        convs: List of conversation dicts (alias2age_map, conversation_id,
            conv_text, store_region, optional lang_fasttext and image keys).

    Returns:
        List[Dict]: content parts (text / image_url) for the chat API.
    """
    logger.debug(f"Building multimodal XML input for {len(convs)} convs")
    placeholder_re = re.compile(r'\[IMG(\d+)\]')
    xml_convs = etree.Element("convs")
    image_paths_flat: List[Union[str, None]] = []
    img_counter = 0  # last global image number handed out; global numbers start at 1

    def _extract_image_paths(conv: Dict) -> List[str]:
        """Collect non-empty string image paths from the supported conv keys, in key order."""
        paths: List[str] = []
        single = conv.get('img_path', conv.get('img', None))
        if isinstance(single, str) and single:
            paths.append(single)
        for key in ("image_paths", "img_paths", "imgs", "images"):
            value = conv.get(key, None)
            if isinstance(value, list):
                paths.extend(p for p in value if isinstance(p, str) and p)
        return paths

    for conv in convs:
        alias2age_text = ", ".join(f'user_{k} is {v}' for k, v in conv['alias2age_map'].items())
        xml_conv = etree.SubElement(xml_convs, "conv")
        etree.SubElement(xml_conv, "conversation_id").text = str(conv["conversation_id"])
        xml_conv_text = etree.SubElement(xml_conv, "conv_text")
        conv_text = conv["conv_text"]
        etree.SubElement(xml_conv, "conv_ageinfo").text = alias2age_text
        etree.SubElement(xml_conv, "region").text = conv["store_region"]
        if conv.get("lang_fasttext", None) is not None:
            etree.SubElement(xml_conv, "language").text = conv["lang_fasttext"]

        img_paths = _extract_image_paths(conv)
        if isinstance(conv_text, str) and placeholder_re.search(conv_text):
            local_to_global: Dict[int, int] = {}

            def _replace(m: re.Match) -> str:
                nonlocal img_counter
                local_idx = int(m.group(1)) - 1  # local placeholders are 1-based
                if not 0 <= local_idx < len(img_paths):
                    # Drop unresolvable placeholders so the global pass cannot
                    # bind them to another conversation's image.
                    return ""
                if local_idx not in local_to_global:
                    img_counter += 1
                    local_to_global[local_idx] = img_counter
                    image_paths_flat.append(img_paths[local_idx])
                return f'[IMG{local_to_global[local_idx]}]'

            xml_conv_text.text = placeholder_re.sub(_replace, conv_text)
        else:
            # Non-embedded: keep text unchanged; images are not attached (see docstring).
            xml_conv_text.text = conv_text

    xml_str = etree.tostring(xml_convs, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated multimodal XML: \n{xml_str}")
    return text_with_placeholders_to_content(xml_str, image_paths_flat)
def parse_unit_judgment(text_info: dict) -> Tuple[int, Dict[str, List[str]]]:
    """Parse the pairwise judgment and flagged messages from an LLM response.

    The response must contain a ``<result>`` block holding a ``<judgment>`` such
    as ``Comment A > Comment B`` (the unit word comes from ``text_info["unit"]``).
    HTML entities are unescaped, a leading "Final Judgment" label is removed, and
    full-width ``：＞＜＝`` are normalised to ASCII before matching.

    An optional ``<violative_messages>`` block maps single uppercase-letter tags
    to comma-separated message ids, e.g. ``<A>m0, m1</A><B>m0</B>``.

    Args:
        text_info: dict with
            - text: Response text from the LLM
            - unit: unit name used in the judgment (default "comment")

    Returns:
        (label, violative_msg_dict): label is 1 if Unit A > Unit B, -1 if
        Unit A < Unit B, 0 if equal; violative_msg_dict maps each tag to its list
        of message ids ({} when the block is absent).

    Raises:
        ValueError: If the text is missing or the judgment cannot be parsed
            (including the unsupported operators ``>=`` / ``<=``).
    """
    text = text_info.get("text")
    unit = text_info.get("unit", "comment")
    if not isinstance(text, str):
        logger.error("LLM response text is missing or not a string")
        raise ValueError("Missing response text")

    m_res = re.search(r'<\s*result\s*>(.*?)<\s*/\s*result\s*>',
                      text, re.IGNORECASE | re.DOTALL)
    if not m_res:
        logger.error("No result block found in LLM response")
        raise ValueError("No result block found")
    result_body = m_res.group(1)

    m_j = re.search(r'<\s*judgment\s*>(.*?)<\s*/\s*judgment\s*>',
                    result_body, re.IGNORECASE | re.DOTALL)
    if not m_j:
        logger.error("No judgment found in result block")
        raise ValueError("No judgment found")

    judgment_text = html.unescape(m_j.group(1).strip())
    judgment_text = re.sub(r'\bFinal\s+Judgment\b', '',
                           judgment_text, flags=re.IGNORECASE).strip()
    # LLMs writing CJK text often emit full-width punctuation; map it to ASCII.
    judgment_text = judgment_text.translate(
        str.maketrans({"：": ":", "＞": ">", "＜": "<", "＝": "="}))

    # e.g. "Comment A > Comment B" -> ("A", ">", "B")
    pattern = rf'(?i){re.escape(unit)}\s*([ab])\s*([<>]=?|==?)\s*{re.escape(unit)}\s*([ab])'
    m_cmp = re.search(pattern, judgment_text)
    if not m_cmp:
        logger.error(f"Invalid judgment format: {judgment_text!r}")
        raise ValueError(f"Invalid judgment format: {judgment_text!r}")
    left, op, _right = m_cmp.groups()
    left = left.upper()

    # Extract the violative messages: <A>m0, m1</A> -> {"A": ["m0", "m1"]}
    violative_msg_dict: Dict[str, List[str]] = {}
    m_viol = re.search(r'<\s*violative_messages\s*>(.*?)<\s*/\s*violative_messages\s*>',
                       text, re.IGNORECASE | re.DOTALL)
    if m_viol:
        block = m_viol.group(1).strip()
        for tag, val in re.findall(r'<([A-Z])>(.*?)</\1>', block, re.DOTALL):
            items = (item.strip().strip("<>[]\"'")
                     for item in val.strip().strip("<>[]\"").split(","))
            violative_msg_dict[tag] = [item for item in items if item]
    else:
        logger.warning("No violative_messages found.")

    if op == '>':
        label = 1 if left == 'A' else -1
    elif op == '<':
        label = -1 if left == 'A' else 1
    elif op in ('=', '=='):
        label = 0
    else:
        logger.error(f"Unsupported operator in judgment: {op!r}")
        raise ValueError(f"Unsupported operator: {op!r}")
    return label, violative_msg_dict
class PairwiseComparison:
    """
    Handles pairwise comparisons between items using an LLM.

    Results are cached in memory (``comparison_results``, keyed by pair key) and
    persisted to Elasticsearch when ``es_index`` is set and/or to a local JSONL
    file when the local cache is enabled.
    """

    # System prompt for ``understand_by_pairs``; sent to the LLM verbatim. The
    # ``\\[`` sequences produce a literal backslash without an invalid-escape warning.
    _UNDERSTANDING_SYSTEM_PROMPT = """You are an expert in content safety and moderation.
You will receive a set of pairwise comparison results between a **target conversation** and multiple other conversations. Your primary goal is to clearly inform users about the **exact meaning and severity** of the target conversation based solely on these comparisons. Avoid any external assumptions or information.
When crafting your analysis, follow these guidelines:
* Clearly interpret the **exact meaning** of the target conversation, thoroughly and explicitly derived from the provided comparisons. Provide sufficient details so users clearly grasp the context, intention, and potential implications of the conversation.
* Identify explicitly which conversations the target conversation is **more severe than** and which ones it is **less severe than**. Clearly quote or paraphrase each comparison conversation, and provide explicit reasons derived from comparison data.
* Ensure all explanations rely **only** on the provided comparison data—**do not speculate or infer beyond what is explicitly given**.
* Structure your response clearly and flexibly so users immediately understand:
* The precise and detailed meaning of the target conversation.
* The explicit reasons behind the conversation's classification as severe or non-severe.
* Specific examples of comparisons supporting your severity assessments, clearly quoting or paraphrasing each messages in the conversation.
Your output should be in **Markdown format**, providing clear bilingual summaries in **English and Chinese** to ensure consistency for bilingual users.
*This detailed analysis aims to provide users with a clear and explicit understanding of the exact meaning of the conversation, its severity, and concrete evidence from comparative conversation or messages.* make sure your response can achieve this
---
**Detailed Output Example:**
## 📝 Conversation Analysis
### 🔹 English
**Direct translation:**
Translate this conversation into English. If the conversation is in English, simply copy the conversation as the output.
- Maintain all special tags message ids e.g. and user_ids e.g. user_0
- Do not remove existing new line characters
- Explain any cultural references / slang that could be unsafe or harmful in line e.g. user_0: Show me your "taye" [Note: "taye" means butt/ass in Algerian slang]
- Provide cultural context (within brackets inline - e.g. [Note: xxx]) when the translation in English sounds unnatural/incoherent
- Keep the explanations brief and concise (Keep a single note to a maximum of 10 words).
**Interpreted Meaning:**
The conversation explicitly expresses [detailed interpretation explicitly based on provided comparisons, clearly explaining context, intent, and implications].
**Severity Assessment:**
* **More Severe Than:**
* "Exact or paraphrased conversation A": because \\[explicit reason from comparison clearly indicating severity difference].
* "Exact or paraphrased conversation B": because \\[explicit reason from comparison clearly indicating severity difference].
* **Less Severe Than:**
* "Exact or paraphrased conversation C": because \\[explicit reason from comparison clearly indicating severity difference].
---
### 🔸 中文
**直接翻译:**
将此对话翻译成中文。如果对话是中文,只需复制对话作为输出即可。
- 保留所有特殊标签消息 ID,例如 和用户 ID,例如 user_0
- 不要删除现有的换行符
- 解释任何可能不安全或有害的文化指涉/俚语,例如 user_0:给我看看你的“taye”[注:“taye”在阿尔及利亚俚语中是屁股的意思]
- 如果英语翻译听起来不自然/语无伦次,请提供文化背景(在括号内,例如 [注:xxx])。
- 解释请简洁明了(每条注释最多 10 个字)。
**解释性含义**:
对话明确表达了\\[基于提供的比较结果的详细解释,清晰地解释了上下文、意图和含义]。
**严重程度评估**:
* **严重程度高于**:
* “对话 A 完全一致或转述”:因为\\[比较结果的明确原因清楚地表明了严重程度的差异]。
* “对话 B 完全一致或转述”:因为\\[比较结果的明确原因清楚地表明了严重程度的差异]。
* **严重程度低于**:
* “对话 C 完全一致或转述”:因为\\[比较结果的明确原因清楚地表明了严重程度的差异]。
"""

    def __init__(self,
                 llm: BaseLLM,
                 prompt_templates: Dict[str, str],
                 format_items: Callable[[List[Dict]], Any],
                 parse_judgment: Callable[[Dict[str, str]], Tuple[int, Dict]],
                 data_dir: str,
                 es_index: str,
                 es_psm: str = "byte.es.ranking_moderation_cmt.service.my",
                 business: str = "comment",
                 max_comparisons_per_pair: int = 3,
                 max_workers: int = 4,
                 max_retries: int = 8,
                 initial_retry_delay: float = 2.0,
                 max_backups: int = 3,
                 detect_msg_violations: bool = True,
                 log_gpt_io: bool = False,
                 log_sample_rate: float = 0.01,
                 local_cache_path: Union[str, Path, None] = None,
                 local_cache_enabled: bool = False):
        """Initialize pairwise comparison handler.

        Args:
            llm: Language model for pairwise comparisons (``chat_completion`` is used).
            prompt_templates: Dictionary of prompt templates; ``'system_prompt'`` is
                the system message for comparisons.
            format_items: Function to format items into LLM input. Can return either
                str or List[Dict] (multimodal content).
            parse_judgment: Function mapping ``{"text": ..., "unit": ...}`` to
                ``(label, violative_msg_dict)``.
            data_dir: Directory to store state files (created if missing).
            es_index: Elasticsearch index for persisted results; the string 'None'
                disables Elasticsearch.
            es_psm: Service name passed to ``bytedes.make_client``.
            business: "comment", "dm" or "dm_mm"; selects the unit word used when
                parsing judgments.
            max_comparisons_per_pair: Once a pair has this many cached comparisons,
                a cached one is returned instead of calling the LLM.
            max_workers: Maximum number of worker threads in ``compare_pairs``.
            max_retries: Maximum number of attempts for LLM / Elasticsearch calls.
            initial_retry_delay: Delay in seconds before the first retry; doubled
                after every retry.
            max_backups: Maximum number of backup files to keep (default: 3).
            detect_msg_violations: Whether to store per-message violation maps.
            log_gpt_io: Whether to append sampled LLM input/judgment records to disk.
            log_sample_rate: Probability of sampling a comparison when ``log_gpt_io`` is on.
            local_cache_path: JSONL file for the local cache ('None' means unset).
                When unset but the cache is enabled,
                ``data_dir/pairwise_comparisons.jsonl`` is used.
            local_cache_enabled: Enable the local JSONL cache (also enabled by a
                truthy ``local_cache_path``).

        Raises:
            NotImplementedError: If ``business`` is not supported.
        """
        self.llm = llm
        self.prompt_templates = prompt_templates
        self.format_items = format_items
        self.parse_judgment = parse_judgment
        self.max_comparisons_per_pair = max_comparisons_per_pair
        self.max_workers = max_workers
        self.max_retries = max_retries
        self.initial_retry_delay = initial_retry_delay
        self.max_backups = max_backups
        self.data_dir = Path(data_dir)
        # Config files may pass the literal string 'None' to disable these options.
        self.es_index = None if es_index == 'None' else es_index
        if local_cache_path == 'None':
            local_cache_path = None
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.local_cache_enabled = bool(local_cache_enabled or local_cache_path)
        self.local_cache_path = None
        if self.local_cache_enabled:
            if local_cache_path is None or local_cache_path is True:
                self.local_cache_path = self.data_dir / "pairwise_comparisons.jsonl"
            else:
                self.local_cache_path = Path(local_cache_path)
            self.local_cache_path.parent.mkdir(parents=True, exist_ok=True)
        # True only after the *whole* local file has been loaded into memory.
        self._local_cache_loaded = False
        # Serialises the one-time full load so concurrent threads do not race it.
        self._local_load_lock = Lock()

        if self.es_index is not None:
            # NOTE(review): TLS certificate verification is disabled for this client.
            self.client = bytedes.make_client(psm=es_psm, cluster="data", scheme="https",
                                              verify_certs=False, use_ssl=True,
                                              ssl_show_warn=False, maxsize=50)
        else:
            self.client = None

        business_key = business.lower()
        if business_key == "comment":
            self.process_unit = "comment"
        elif business_key in ["dm", "dm_mm"]:
            self.process_unit = "conversation"
        else:
            raise NotImplementedError(f"Ranking moderation for business {business} is not implemented!")

        # In-memory cache shared by worker threads; guarded by _cache_lock.
        self._cache_lock = Lock()
        self.comparison_results = defaultdict(list)  # pair_key -> list of comparison results
        self.detect_msg_violations = detect_msg_violations
        self.log_gpt_io = log_gpt_io
        self.log_sample_rate = log_sample_rate

    # ------------------------------------------------------------------ helpers

    def _call_with_retries(self, operation: Callable[[], Any], action: str) -> Any:
        """Run ``operation`` with exponential-backoff retries.

        Returns the operation's result, or None when ``max_retries`` <= 0.
        Re-raises the last exception after ``max_retries`` failed attempts.
        """
        current_delay = self.initial_retry_delay
        for trial in range(1, self.max_retries + 1):
            try:
                return operation()
            except Exception as e:
                if trial >= self.max_retries:
                    logger.error(f"Failed to {action} after {self.max_retries} attempts: {str(e)}")
                    raise
                logger.warning(f"Failed to {action}, request failed (attempt {trial}/{self.max_retries}). "
                               f"Retrying in {current_delay} seconds. Error: {str(e)}")
                time.sleep(current_delay)
                current_delay *= 2
        return None

    def _source_to_result(self, source: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one stored comparison document (ES ``_source``) into a cache entry.

        ``raw_response`` and ``violative_msg_map_str`` are copied only when present,
        so documents fetched without them do not raise KeyError.
        """
        item_id_a, item_id_b = source['item_id_a'], source['item_id_b']
        result = {
            "judgment": source['judgment'],
            "timestamp": source["timestamp"],
            "item_id_a": item_id_a,
            "item_id_b": item_id_b,
            'ordered_ids': self.get_ordered_pair(item_id_a, item_id_b),
            'original_ids': (item_id_a, item_id_b),
        }
        if 'raw_response' in source:
            result['raw_response'] = source['raw_response']
        if self.detect_msg_violations and 'violative_msg_map_str' in source:
            result["violative_msg_map_str"] = source['violative_msg_map_str']
        return result

    @staticmethod
    def _append_if_new(store: Dict[str, List[Dict]], pair_key: str, result: Dict) -> None:
        """Append ``result`` under ``pair_key`` unless an entry with the same timestamp exists."""
        existing = store.get(pair_key)
        if existing is None:
            store[pair_key] = [result]
        elif all(r.get("timestamp") != result["timestamp"] for r in existing):
            existing.append(result)

    def _ensure_local_cache_loaded(self) -> None:
        """Load the whole local JSONL cache into memory once (thread-safe)."""
        with self._local_load_lock:
            if not self._local_cache_loaded:
                self.load_data_to_cache_from_local(None, load_detail=True)

    def _clear_scroll(self, scroll_id: Any) -> None:
        """Best-effort release of a server-side scroll context."""
        if not scroll_id:
            return
        try:
            self.client.clear_scroll(scroll_id=scroll_id)
        except Exception as e:
            logger.debug(f"Failed to clear es scroll context: {str(e)}")

    def _maybe_log_gpt_io(self, formatted_input: Any, judgment: Any, violative_msg_dict: Any) -> None:
        """Append a sampled record of the formatted conversations and judgment (best-effort).

        Sampling happens only when both ``detect_msg_violations`` and ``log_gpt_io``
        are set, with probability ``log_sample_rate``. Failures are logged, never raised.

        NOTE(review): the original body of this sampling block was corrupted in the
        source (only its regex tail survived). The output file below and the
        flattening of multimodal (list) input are reconstructions — confirm the
        intended log destination.
        """
        if not (self.detect_msg_violations and self.log_gpt_io
                and random.random() < self.log_sample_rate):
            return
        try:
            if isinstance(formatted_input, str):
                text = formatted_input
            else:
                # Multimodal input: keep only the text parts.
                text = "".join(part.get("text", "") for part in formatted_input
                               if isinstance(part, dict) and part.get("type") == "text")
            convs = re.findall(r'<conv_text>(.*?)</conv_text>', text.strip(), re.DOTALL)
            convs = [html.unescape(c).strip() for c in convs]
            age_info = re.findall(r'<conv_ageinfo>(.*?)</conv_ageinfo>', text, re.DOTALL)
            data = {
                'timestamp': time.time(),
                "conversations": [{"conv_text": c, "conv_ageinfo": a} for c, a in zip(convs, age_info)],
                "judgment": judgment,
                "violative_msg_map_str": violative_msg_dict,
            }
            with open(self.data_dir / "gpt_io_samples.jsonl", "a", encoding="utf-8") as f:
                f.write(json.dumps(data, ensure_ascii=False) + "\n")
        except Exception as e:
            logger.warning(f"Failed to log sampled GPT I/O: {str(e)}")

    # ---------------------------------------------------------------- pair keys

    def get_state_file_path(self, filename: str) -> Path:
        """Get the full path for a state file.

        Args:
            filename: Name of the state file

        Returns:
            Path: Full path to the state file
        """
        return self.data_dir / filename

    def get_pair_key(self, item_id1: str, item_id2: str) -> str:
        """Get an order-independent key for a pair of items ("<min>_<max>" of the str ids).

        Args:
            item_id1: First item ID
            item_id2: Second item ID

        Returns:
            str: Unique key for the pair
        """
        item_id1, item_id2 = str(item_id1), str(item_id2)
        return f"{min(item_id1, item_id2)}_{max(item_id1, item_id2)}"

    def get_ordered_pair(self, item_id1: str, item_id2: str) -> Tuple[str, str]:
        """Get ordered pair of item IDs.

        Args:
            item_id1: First item ID
            item_id2: Second item ID

        Returns:
            Tuple[str, str]: Ordered pair of item IDs (min, max)
        """
        return min(item_id1, item_id2), max(item_id1, item_id2)

    # ------------------------------------------------------------ cache / store

    def get_compare_result_from_es(self, pair_key: str) -> List[Dict]:
        """Return all stored comparison results for ``pair_key``.

        Looks in the in-memory cache first, then the local JSONL cache (loaded in
        full once) when Elasticsearch is disabled, otherwise queries Elasticsearch
        with retries and caches the result. Returns a copy of the cached list.
        """
        with self._cache_lock:
            if pair_key in self.comparison_results:
                return list(self.comparison_results[pair_key])

        if self.es_index is None:
            if self.local_cache_enabled and not self._local_cache_loaded:
                self._ensure_local_cache_loaded()
                with self._cache_lock:
                    if pair_key in self.comparison_results:
                        return list(self.comparison_results[pair_key])
            return []

        query_body = {
            "query": {
                "term": {
                    "pair_key": {
                        "value": pair_key
                    }
                }
            }
        }

        def _search() -> List[Dict]:
            result = self.client.search(index=self.es_index, body=query_body, size=200)
            return [self._source_to_result(hit['_source']) for hit in result['hits']['hits']]

        compare_result = self._call_with_retries(_search, "get data from es") or []
        with self._cache_lock:
            # Merge rather than overwrite so results written concurrently are kept;
            # an empty list still marks the pair as looked up.
            self.comparison_results[pair_key]
            for r in compare_result:
                self._append_if_new(self.comparison_results, pair_key, r)
            return list(self.comparison_results[pair_key])

    def load_data_to_cache_from_es(self, item_ids: List[str], load_detail=True) -> Dict[str, List[Dict]]:
        """Load all comparisons involving any of ``item_ids`` into the in-memory cache.

        Falls back to the local JSONL cache when Elasticsearch is disabled.

        Args:
            item_ids: Item ids to load (queried in batches of 500).
            load_detail: Also fetch ``raw_response``.

        Returns:
            Dict mapping pair key to the results loaded by this call (deduplicated
            by timestamp).
        """
        if self.es_index is None:
            return self.load_data_to_cache_from_local(item_ids, load_detail=load_detail)

        item_ids = list(item_ids)
        scroll = "2m"
        batch_size = 500
        _source = ["pair_key", "judgment", "timestamp", "item_id_a", "item_id_b"]
        if self.detect_msg_violations:
            _source.append("violative_msg_map_str")
        if load_detail:
            _source.append("raw_response")

        comparison_results_temp: Dict[str, List[Dict]] = {}
        for start in range(0, len(item_ids), batch_size):
            id_batch = item_ids[start:start + batch_size]
            query = {
                "_source": _source,
                "query": {
                    "bool": {
                        "should": [
                            {"terms": {"item_id_a.keyword": id_batch}},
                            {"terms": {"item_id_b.keyword": id_batch}}
                        ],
                        "minimum_should_match": 1
                    }
                }
            }
            page = self.client.search(index=self.es_index, body=query, scroll=scroll, size=5000)
            scroll_id = page.get("_scroll_id")
            try:
                while page["hits"]["hits"]:
                    for hit in page["hits"]["hits"]:
                        pair_key = hit['_source']['pair_key']
                        r = self._source_to_result(hit['_source'])
                        with self._cache_lock:
                            self._append_if_new(self.comparison_results, pair_key, r)
                        self._append_if_new(comparison_results_temp, pair_key, r)
                    page = self.client.scroll(scroll_id=scroll_id, scroll=scroll)
                    scroll_id = page["_scroll_id"]
            finally:
                # Free the server-side scroll context instead of waiting for it to expire.
                self._clear_scroll(scroll_id)
        return comparison_results_temp

    def load_data_to_cache_from_local(self, item_ids: Union[List[str], None], load_detail=True) -> Dict[str, List[Dict]]:
        """Load comparisons from the local JSONL cache into the in-memory cache.

        Args:
            item_ids: Only load records involving one of these ids; None loads all.
            load_detail: Also load ``raw_response``.

        Returns:
            Dict mapping pair key to the results loaded by this call. Empty when
            the local cache is disabled or the file does not exist.

        Only a full load (``item_ids is None``) marks the local cache as loaded,
        so a partial load never hides records of other pairs from later lookups.
        """
        if not self.local_cache_enabled or self.local_cache_path is None:
            return {}
        if not self.local_cache_path.exists():
            self._local_cache_loaded = True
            return {}

        item_id_set = None if item_ids is None else {str(x) for x in item_ids}
        comparison_results_temp: Dict[str, List[Dict]] = {}
        skipped = 0
        with open(self.local_cache_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    doc = json.loads(line)
                except ValueError:
                    # Tolerate partially written / corrupt lines, but report them.
                    skipped += 1
                    continue
                if not isinstance(doc, dict) or "item_id_a" not in doc or "item_id_b" not in doc:
                    continue
                item_id_a = str(doc["item_id_a"])
                item_id_b = str(doc["item_id_b"])
                if item_id_set is not None and item_id_a not in item_id_set and item_id_b not in item_id_set:
                    continue
                pair_key = doc.get("pair_key", self.get_pair_key(item_id_a, item_id_b))
                r = {
                    "judgment": doc.get("judgment"),
                    "timestamp": doc.get("timestamp"),
                    "item_id_a": item_id_a,
                    "item_id_b": item_id_b,
                    "ordered_ids": self.get_ordered_pair(item_id_a, item_id_b),
                    "original_ids": (item_id_a, item_id_b)
                }
                if self.detect_msg_violations and "violative_msg_map_str" in doc:
                    r["violative_msg_map_str"] = doc["violative_msg_map_str"]
                if load_detail and "raw_response" in doc:
                    r["raw_response"] = doc["raw_response"]
                with self._cache_lock:
                    self._append_if_new(self.comparison_results, pair_key, r)
                self._append_if_new(comparison_results_temp, pair_key, r)
        if skipped:
            logger.warning(f"Skipped {skipped} unreadable lines in {self.local_cache_path}")
        if item_id_set is None:
            self._local_cache_loaded = True
        return comparison_results_temp

    def write_compare_result_to_es(self, pair_key: str, comparison_result: Dict[str, Any]):
        """Record a new comparison result in memory, the local cache and Elasticsearch.

        The in-memory cache and local JSONL file are updated exactly once; only the
        Elasticsearch index call is retried, so retries cannot duplicate entries.

        Returns:
            The Elasticsearch ``index`` response, or None when Elasticsearch is disabled.
        """
        doc = {
            'pair_key': pair_key,
            'judgment': comparison_result['judgment'],
            'raw_response': comparison_result['raw_response'],
            'timestamp': comparison_result['timestamp'],
            'item_id_a': comparison_result['original_ids'][0],
            'item_id_b': comparison_result['original_ids'][1],
        }
        if self.detect_msg_violations:
            doc['violative_msg_map_str'] = comparison_result['violative_msg_map_str']

        with self._cache_lock:
            self.comparison_results[pair_key].append(comparison_result)
            # Written under the lock so concurrent writers do not interleave lines.
            if self.local_cache_enabled and self.local_cache_path is not None:
                with open(self.local_cache_path, "a", encoding="utf-8") as f:
                    f.write(json.dumps(doc, ensure_ascii=False) + "\n")

        if self.es_index is None:
            return None
        return self._call_with_retries(
            lambda: self.client.index(index=self.es_index, body=doc), "write data to es")

    # -------------------------------------------------------------- comparisons

    def compare_single_pair(self, pair: Tuple[Dict, Dict], random_swap: bool = True,
                            use_cache: bool = True) -> Union[Dict[str, Any], None]:
        """Compare a single pair of items and store the result.

        Args:
            pair: Tuple of (item1, item2); each item needs an 'item_id'.
            random_swap: Whether to randomly swap the order of items.
                If False, keep the original order.
                Default is True for backward compatibility.
            use_cache: Whether to use cached comparison results if available.
                If False, always perform new comparison.
                Default is True for backward compatibility.

        Returns:
            A comparison result dict (a random cached one when the pair already has
            ``max_comparisons_per_pair`` results), or None when the response had no
            result block on the 2nd attempt or ``max_retries`` <= 0.

        Raises:
            Exception: The last LLM/parse error after ``max_retries`` attempts, or
                an error raised while persisting the result.
        """
        item1, item2 = pair
        if random_swap and random.random() < 0.5:
            item1, item2 = item2, item1
        item_id1 = str(item1['item_id'])
        item_id2 = str(item2['item_id'])
        pair_key = self.get_pair_key(item_id1, item_id2)
        ordered_id1, ordered_id2 = self.get_ordered_pair(item_id1, item_id2)

        if use_cache:
            cached = self.get_compare_result_from_es(pair_key)
            # Non-empty guard: random.choice([]) would raise if max_comparisons_per_pair <= 0.
            if cached and len(cached) >= self.max_comparisons_per_pair:
                return random.choice(cached)

        formatted_input = self.format_items([item1, item2])
        messages = [
            {"role": "system", "content": self.prompt_templates['system_prompt']},
            {"role": "user", "content": formatted_input}
        ]

        comparison_result = None
        violative_msg_dict: Any = {}
        current_delay = self.initial_retry_delay
        for trial in range(1, self.max_retries + 1):
            response = None
            try:
                response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
                result = response['choices'][0]['message']['content']
                judgment, violative_msg_dict = self.parse_judgment({
                    "text": result,
                    "unit": self.process_unit
                })
                comparison_result = {
                    'judgment': judgment,
                    'raw_response': result,
                    'timestamp': time.time(),
                    'ordered_ids': (ordered_id1, ordered_id2),
                    'original_ids': (item_id1, item_id2),
                    "item_id_a": item_id1,
                    "item_id_b": item_id2
                }
                if self.detect_msg_violations:
                    # Map of tag -> flagged message ids, stored as a JSON string.
                    comparison_result['violative_msg_map_str'] = json.dumps(violative_msg_dict)
                break
            except Exception as e:
                # A response without a result block is unlikely to improve; give up early.
                if str(e) == "No result block found" and trial == 2:
                    logger.error("parse error, only try 2 times")
                    return None
                if trial == self.max_retries:
                    logger.error(f"Failed to compare pair after {self.max_retries} attempts: {str(e)}")
                    raise
                logger.warning(f"Request failed (attempt {trial}/{self.max_retries}). "
                               f"Retrying in {current_delay} seconds. Error: {str(e)}")
                if response is not None:
                    logger.warning(f"Response: {response}")
                time.sleep(current_delay)
                current_delay *= 2

        if comparison_result is None:
            return None
        # Persist outside the LLM retry loop so storage failures never re-query the LLM.
        self._maybe_log_gpt_io(formatted_input, comparison_result['judgment'], violative_msg_dict)
        self.write_compare_result_to_es(pair_key, comparison_result)
        return comparison_result

    def compare_pairs(self, pairs: List[Tuple[Dict, Dict]], random_swap: bool = True,
                      use_cache: bool = True, use_tqdm: bool = True) -> List[Tuple[Dict, Dict]]:
        """Compare multiple pairs of items in parallel.

        Args:
            pairs: List of (item1, item2) tuples to compare
            random_swap: Whether to randomly swap the order of items in each pair.
                If False, keep the original order.
                Default is True for backward compatibility.
            use_cache: Whether to use cached comparison results if available.
                If False, always perform new comparison.
                Default is True for backward compatibility.
            use_tqdm: Show a progress bar.

        Returns:
            List of successfully compared pairs (errors are logged, not raised).
        """
        # A fresh thread pool per batch; comparisons are I/O-bound LLM/ES calls.
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [
                (executor.submit(self.compare_single_pair, pair, random_swap=random_swap, use_cache=use_cache), pair)
                for pair in pairs
            ]
            successful_pairs = []
            for future, pair in tqdm(futures, desc="Comparing pairs", total=len(futures),
                                     leave=False, disable=not use_tqdm):
                try:
                    if future.result():  # None means the comparison was abandoned
                        successful_pairs.append(pair)
                except Exception as e:
                    logger.error(f"Error comparing pair: {str(e)}")
        return successful_pairs

    def understand_by_pairs(self, item: Dict, compare_results: Dict[str, List[Dict]]) -> str:
        """Ask the LLM to explain ``item`` using up to 2 higher and 2 lower comparisons.

        Args:
            item: The target item (needs 'item_id'); ``format_items([item])`` must
                return an XML string.
            compare_results: Dict mapping pair key to comparison results; the first
                result of each pair that has a ``raw_response`` is considered.

        Returns:
            The LLM's Markdown analysis (None only if ``max_retries`` <= 0).

        Raises:
            TypeError: If ``format_items`` does not return a string.
        """
        rs_high, rs_low = [], []
        target_id = str(item["item_id"])
        for pair_results in compare_results.values():
            if not pair_results:
                continue
            r = pair_results[0]
            judgment = r.get("judgment")
            if judgment is None or "raw_response" not in r:
                continue
            a, b = str(r["item_id_a"]), str(r["item_id_b"])
            # judgment >= 0 means A >= B: when the target is A the other item is "lower".
            if target_id == a:
                (rs_low if judgment >= 0 else rs_high).append(r)
            elif target_id == b:
                (rs_low if judgment < 0 else rs_high).append(r)
        random.shuffle(rs_high)
        random.shuffle(rs_low)
        examples = rs_high[:2] + rs_low[:2]
        if not examples:
            logger.warning(f"No comparisons with raw_response found for item {target_id}")

        target_xml = self.format_items([item])
        if not isinstance(target_xml, str):
            raise TypeError("understand_by_pairs requires format_items to return an XML string")
        root = etree.Element("analysis_input")
        target_block = etree.SubElement(root, "target_conversation")
        target_block.append(etree.fromstring(target_xml.encode("utf-8")))
        comps_block = etree.SubElement(root, "comparisons")
        for ex in examples:
            node = etree.SubElement(comps_block, "raw_response")
            node.text = etree.CDATA(ex["raw_response"])
        xml_input = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")

        system_prompt = self._UNDERSTANDING_SYSTEM_PROMPT
        logger.debug(f"system_prompt: {system_prompt}")
        logger.debug(f"user_prompt: {xml_input}")
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": xml_input}
        ]

        def _ask() -> str:
            response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
            try:
                return response['choices'][0]['message']['content']
            except (KeyError, IndexError, TypeError):
                logger.warning(f"Response={response}")
                raise

        llm_understanding = self._call_with_retries(_ask, "get data by llm api")
        logger.debug(f"llm response: {llm_understanding}")
        return llm_understanding

    def get_comparison_result_by_id(self, item_id: str) -> List[Dict]:
        """Return all stored comparison results involving ``item_id``.

        Uses the in-memory/local cache when Elasticsearch is disabled; otherwise
        queries Elasticsearch (at most 200 hits are returned).
        """
        item_id = str(item_id)
        if self.es_index is None:
            if self.local_cache_enabled and not self._local_cache_loaded:
                self._ensure_local_cache_loaded()
            with self._cache_lock:
                return [
                    r for pair_results in self.comparison_results.values() for r in pair_results
                    if str(r.get("item_id_a")) == item_id or str(r.get("item_id_b")) == item_id
                ]

        query_body = {
            "query": {
                "bool": {
                    "should": [
                        {"term": {"item_id_a.keyword": item_id}},
                        {"term": {"item_id_b.keyword": item_id}}
                    ],
                    "minimum_should_match": 1
                }
            }
        }

        def _search() -> List[Dict]:
            result = self.client.search(index=self.es_index, body=query_body, size=200)
            return [self._source_to_result(hit['_source']) for hit in result['hits']['hits']]

        return self._call_with_retries(_search, "get data from es") or []

    def get_comparison_count(self, pair_key: str) -> int:
        """Get the number of comparisons for a pair.

        Args:
            pair_key: Key for the pair

        Returns:
            int: Number of comparisons for the pair
        """
        return len(self.get_comparison_results(pair_key))

    def get_comparison_results(self, pair_key: str) -> List[Dict]:
        """Get comparison results for a pair.

        Args:
            pair_key: Key for the pair

        Returns:
            List[Dict]: List of comparison results for the pair
        """
        return self.get_compare_result_from_es(pair_key)

    def get_compare_information(self, item_ids: List[str], use_tqdm: bool = True) -> Dict[str, Any]:
        """Get comparison information among the given items from the in-memory cache.

        Args:
            item_ids: List of item IDs (compared as strings)
            use_tqdm: Show a progress bar while scanning the cache.

        Returns:
            Dict containing:
            - comparison_results: Dict mapping pair keys to their comparison results
            - comparison_counts: Dict mapping item IDs to their total comparison counts
            - pair_comparison_counts: Dict mapping pair keys to their comparison counts
        """
        comparison_results = {}
        comparison_counts = {}
        pair_comparison_counts = {}
        item_id_set = {str(i) for i in item_ids}
        with self._cache_lock:
            for pair_key, pair_results in tqdm(self.comparison_results.items(), desc="Scan Comparing pairs",
                                               total=len(self.comparison_results), leave=False,
                                               disable=not use_tqdm):
                # Prefer the stored ids: splitting the key breaks when ids contain '_'.
                first = pair_results[0] if pair_results else {}
                if first.get("item_id_a") is not None and first.get("item_id_b") is not None:
                    id1, id2 = str(first["item_id_a"]), str(first["item_id_b"])
                else:
                    id1, _, id2 = pair_key.partition('_')
                if id1 in item_id_set and id2 in item_id_set:
                    comparison_results[pair_key] = list(pair_results)
                    comparison_counts[id1] = comparison_counts.get(id1, 0) + len(pair_results)
                    comparison_counts[id2] = comparison_counts.get(id2, 0) + len(pair_results)
                    pair_comparison_counts[pair_key] = len(pair_results)
        return {
            'comparison_results': comparison_results,
            'comparison_counts': comparison_counts,
            'pair_comparison_counts': pair_comparison_counts
        }