Wendy-Fly committed on
Commit
91c518d
·
verified ·
1 Parent(s): b8df358

Upload pairwise_comparison.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pairwise_comparison.py +1060 -0
pairwise_comparison.py ADDED
@@ -0,0 +1,1060 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pairwise comparison handler for Bradley-Terry ranking model.
3
+ """
4
+
5
+ import logging
6
+ import time
7
+ import random
8
+ import re
9
+ import os
10
+ import json
11
+ import base64
12
+ from pathlib import Path
13
+ from typing import Dict, List, Tuple, Union, Callable, Any
14
+ from threading import Lock
15
+ from concurrent.futures import ThreadPoolExecutor
16
+ from src.llm.base import BaseLLM
17
+ from lxml import etree
18
+ from tqdm import tqdm
19
+ from collections import defaultdict
20
+ import bytedes
21
+ import html
22
+ from functools import lru_cache
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
def format_comments_as_xml(comments: List[Dict]) -> str:
    """Serialize comments into a pretty-printed XML document.

    Every comment becomes a ``<comment>`` element with a ``<comment_text>``
    child. Non-empty video metadata (title, tag, description, in-video text)
    found on the comment is grouped under a ``<video_context>`` child.

    Args:
        comments: Comment dictionaries; each must provide a ``"text"`` key.

    Returns:
        str: XML text rooted at ``<comments>``.
    """
    context_fields = ['video_title', 'video_tag', 'video_description', 'text_in_video']
    logger.debug(f"Building XML input for {len(comments)} comments")
    root = etree.Element("comments")

    for entry in comments:
        node = etree.SubElement(root, "comment")
        etree.SubElement(node, "comment_text").text = entry["text"]

        # Collect only the video fields that actually carry a value,
        # in the order they appear on the comment.
        present = {}
        for field, content in entry.items():
            if field in context_fields and content is not None and content != '':
                present[field] = content
        if len(present) == 0:
            continue

        context_node = etree.SubElement(node, "video_context")
        for field, content in present.items():
            etree.SubElement(context_node, field).text = str(content)

    serialized = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{serialized}")
    return serialized
58
+
59
+
60
def format_convs_as_xml(convs: List[Dict]) -> str:
    """Serialize conversations into a pretty-printed XML document.

    Each conversation becomes a ``<conv>`` element whose children are, in
    order: ``conversation_id``, ``conv_text``, ``conv_ageinfo``, ``region``
    and, when ``lang_fasttext`` is present and not ``None``, ``language``.

    Args:
        convs: Conversation dictionaries providing ``alias2age_map``,
            ``conversation_id``, ``conv_text`` and ``store_region``.

    Returns:
        str: XML text rooted at ``<convs>``.
    """

    def _child(parent, tag, value):
        # Create <tag> under parent and set its text in one step.
        element = etree.SubElement(parent, tag)
        element.text = value
        return element

    logger.debug(f"Building XML input for {len(convs)} convs")
    root = etree.Element("convs")

    for entry in convs:
        # Age info is rendered as "user_<alias> is <age>, ...".
        age_parts = [f'user_{k} is {v}' for k, v in entry['alias2age_map'].items()]
        age_summary = ", ".join(age_parts)

        node = etree.SubElement(root, "conv")
        _child(node, "conversation_id", str(entry["conversation_id"]))
        _child(node, "conv_text", entry["conv_text"])
        _child(node, "conv_ageinfo", age_summary)
        _child(node, "region", entry["store_region"])
        if entry.get("lang_fasttext", None) is not None:
            _child(node, "language", entry["lang_fasttext"])

    serialized = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{serialized}")
    return serialized
91
+
92
+
93
@lru_cache(maxsize=int(os.environ.get('IMAGE_CACHE_SIZE', '512')))
def image_to_base64(image_path: str) -> str:
    """Return the file at ``image_path`` as a base64-encoded ASCII string.

    Results are memoized per path; the cache size comes from the
    ``IMAGE_CACHE_SIZE`` environment variable (default 512).
    """
    raw_bytes = Path(image_path).read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
97
+
98
+
99
def text_with_placeholders_to_content(
    text: str,
    image_paths: List[Union[str, None]],
    placeholder_pattern: str = r'\[IMG(\d+)\]'
) -> List[Dict]:
    """Split text containing image placeholders into an OpenAI-style content list.

    Args:
        text: Source text containing placeholders such as ``[IMG1]``, ``[IMG2]``.
        image_paths: Image paths indexed by placeholder number (placeholders
            are 1-based). ``None`` entries produce no image part.
        placeholder_pattern: Regex whose first group captures the placeholder
            number.

    Returns:
        List[Dict]: Interleaved ``{"type": "text"}`` and ``{"type": "image_url"}``
        parts in document order. Whitespace-only text segments are omitted, and
        a placeholder that does not resolve to an image is dropped.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    import mimetypes

    content: List[Dict] = []

    def _append_text(chunk: str) -> None:
        # Whitespace-only segments carry no information for the model.
        if chunk.strip():
            content.append({"type": "text", "text": chunk})

    pos = 0
    for match in re.finditer(placeholder_pattern, text):
        start, end = match.span()
        _append_text(text[pos:start])

        img_idx = int(match.group(1)) - 1  # placeholders are numbered from 1
        if 0 <= img_idx < len(image_paths) and image_paths[img_idx] is not None:
            path = image_paths[img_idx]
            # Label the data URL with the real image type instead of always
            # claiming JPEG; fall back to JPEG when the extension is unknown.
            mime, _ = mimetypes.guess_type(str(path))
            if not mime or not mime.startswith("image/"):
                mime = "image/jpeg"
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime};base64,{image_to_base64(path)}"
                }
            })
        pos = end

    _append_text(text[pos:])
    return content
135
+
136
+
137
def format_convs_as_xml_image(convs: List[Dict]) -> List[Dict]:
    """Build an OpenAI-compatible multimodal content list for conversations.

    Conversations are serialized to XML (same fields as the text-only
    formatter) and images are referenced through ``[IMGn]`` placeholders that
    are then expanded into ``image_url`` parts.

    Image sources read from each conversation:
      - single image: ``conv['img_path']`` or ``conv['img']``
      - several images: ``conv['image_paths']`` / ``conv['img_paths']`` /
        ``conv['imgs']`` / ``conv['images']`` (lists of path strings)

    Modes (chosen automatically per conversation):
      - Embedded: when ``conv['conv_text']`` contains ``[IMGk]``, ``k`` is a
        1-based index into that conversation's own image list. Each such
        placeholder is rewritten to a global ``[IMGn]`` number so that images
        of different conversations cannot collide. Placeholders whose index
        has no matching image are removed from the text.
      - Non-embedded: when the text has no placeholder it is emitted
        unchanged and the conversation's images are NOT attached.

    Args:
        convs: Conversation dictionaries providing ``alias2age_map``,
            ``conversation_id``, ``conv_text`` and ``store_region``.

    Returns:
        List[Dict]: Text and image parts in document order.
    """
    logger.debug(f"Building multimodal XML input for {len(convs)} convs")
    placeholder_re = re.compile(r'\[IMG(\d+)\]')
    xml_convs = etree.Element("convs")

    # Position i holds the image for global placeholder [IMG(i+1)].
    image_paths_flat: List[Union[str, None]] = []

    def _extract_image_paths(conv: Dict) -> List[str]:
        # Gather the single-image key first, then every list-valued key.
        paths: List[str] = []
        single = conv.get('img_path', conv.get('img', None))
        if isinstance(single, str) and single:
            paths.append(single)
        for key in ("image_paths", "img_paths", "imgs", "images"):
            value = conv.get(key, None)
            if isinstance(value, list):
                paths.extend(p for p in value if isinstance(p, str) and p)
        return paths

    def _globalize_placeholders(conv_text: str, img_paths: List[str], conv_id: str) -> str:
        # Rewrite conversation-local placeholder numbers to global ones.
        local_to_global: Dict[int, int] = {}

        def _replace(m):
            local_idx = int(m.group(1)) - 1  # local placeholders are 1-based
            if not (0 <= local_idx < len(img_paths)):
                # Leaving the token in place would let it be re-read as a
                # *global* index and pull in an unrelated image.
                logger.warning(
                    f"Dropping placeholder {m.group(0)} in conversation {conv_id}: no matching image"
                )
                return ''
            if local_idx not in local_to_global:
                image_paths_flat.append(img_paths[local_idx])
                local_to_global[local_idx] = len(image_paths_flat)
            return f'[IMG{local_to_global[local_idx]}]'

        return placeholder_re.sub(_replace, conv_text)

    for conv in convs:
        alias2age_map = conv['alias2age_map']
        alias2age_text = ", ".join([f'user_{k} is {v}' for k, v in alias2age_map.items()])
        xml_conv = etree.SubElement(xml_convs, "conv")
        xml_conv_id = etree.SubElement(xml_conv, "conversation_id")
        xml_conv_id.text = str(conv["conversation_id"])
        # Created here to keep the element order; its text is set below.
        xml_conv_text = etree.SubElement(xml_conv, "conv_text")
        conv_text = conv["conv_text"]
        xml_conv_ageinfo = etree.SubElement(xml_conv, "conv_ageinfo")
        xml_conv_ageinfo.text = alias2age_text
        xml_conv_region = etree.SubElement(xml_conv, "region")
        xml_conv_region.text = conv["store_region"]
        if conv.get("lang_fasttext", None) is not None:
            xml_conv_language = etree.SubElement(xml_conv, "language")
            xml_conv_language.text = conv["lang_fasttext"]

        if isinstance(conv_text, str) and placeholder_re.search(conv_text):
            xml_conv_text.text = _globalize_placeholders(
                conv_text, _extract_image_paths(conv), xml_conv_id.text
            )
        else:
            # Non-embedded: text is kept as is; no image parts are produced.
            xml_conv_text.text = conv_text

    xml_str = etree.tostring(xml_convs, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated multimodal XML: \n{xml_str}")
    return text_with_placeholders_to_content(xml_str, image_paths_flat)
218
+
219
+
220
def parse_unit_judgment(text_info: dict) -> Tuple[int, Dict[str, List[str]]]:
    """Parse the pairwise judgment and violative messages from an LLM response.

    The response must contain ``<result>...<judgment>UNIT X op UNIT Y</judgment>
    ...</result>`` where ``UNIT`` is the processed unit name, ``X``/``Y`` are
    ``A`` or ``B`` and ``op`` is ``>``, ``<``, ``=`` or ``==`` (ASCII or
    full-width). An optional ``<violative_messages>`` block holds
    single-uppercase-letter tags (e.g. ``<A>m1,m2</A>``) listing message ids.

    Args:
        text_info: dict
            - text: Response text from LLM
            - unit: unit the LLM is processing (defaults to ``"comment"``)

    Returns:
        Tuple[int, Dict[str, List[str]]]: ``(label, violative_msg_dict)``.
        ``label`` is 1 if Unit A > Unit B, -1 if Unit A < Unit B, 0 if equal.
        ``violative_msg_dict`` maps each tag to its list of message ids and is
        empty when no ``<violative_messages>`` block exists.

    Raises:
        ValueError: If the response text is missing or the judgment cannot
            be parsed.
    """
    text = text_info.get("text")
    unit = text_info.get("unit", "comment")
    if not isinstance(text, str):
        # Fail with the documented exception type instead of a TypeError
        # from the regex engine.
        logger.error("LLM response text is missing")
        raise ValueError("LLM response text is missing")

    m_res = re.search(r'<\s*result\s*>(.*?)<\s*/\s*result\s*>',
                      text, re.IGNORECASE | re.DOTALL)
    if not m_res:
        logger.error("No result block found in LLM response")
        raise ValueError("No result block found")
    result_body = m_res.group(1)

    m_j = re.search(r'<\s*judgment\s*>(.*?)<\s*/\s*judgment\s*>',
                    result_body, re.IGNORECASE | re.DOTALL)
    if not m_j:
        logger.error("No judgment found in result block")
        raise ValueError("No judgment found")
    judgment_text = html.unescape(m_j.group(1).strip())

    judgment_text = re.sub(r'\bFinal\s+Judgment\b', '',
                           judgment_text, flags=re.IGNORECASE).strip()

    # Map full-width punctuation to ASCII so that e.g. U+FF1E is read as '>'.
    # Written as escapes so the mapping survives any re-encoding of the file.
    fullwidth_to_ascii = str.maketrans({
        "\uFF1A": ":",
        "\uFF1E": ">",
        "\uFF1C": "<",
        "\uFF1D": "=",
    })
    judgment_text = judgment_text.translate(fullwidth_to_ascii)

    pattern = rf'(?i){re.escape(unit)}\s*([ab])\s*([<>]=?|==?)\s*{re.escape(unit)}\s*([ab])'
    m_cmp = re.search(pattern, judgment_text)
    if not m_cmp:
        logger.error(f"Invalid judgment format: {judgment_text!r}")
        raise ValueError(f"Invalid judgment format: {judgment_text!r}")

    left, op, _right = m_cmp.groups()
    left = left.upper()

    # Extract the violative messages, e.g. <A>m1,m2</A>.
    violative_msg_dict: Dict[str, List[str]] = {}
    violative_content = re.search(r'<violative_messages>(.*?)</violative_messages>', text, re.DOTALL)
    if violative_content:
        content = violative_content.group(1).strip()
        for tag, val in re.findall(r'<([A-Z])>(.*?)</\1>', content):
            # Clean every id on its own so "m1, m2" or ["m1", "m2"] both
            # yield ["m1", "m2"].
            violative_msg_dict[tag] = [
                item.strip().strip("<>[]\"").strip() for item in val.split(",")
            ]
    else:
        logger.warning("No violative_messages found.")

    if op == '>':
        label = 1 if left == 'A' else -1
    elif op == '<':
        label = -1 if left == 'A' else 1
    elif op in ('=', '=='):
        label = 0
    else:
        # The pattern also admits '>=' and '<=', which have no defined label.
        logger.error(f"Unsupported operator in judgment: {op!r}")
        raise ValueError(f"Unsupported operator: {op!r}")

    return label, violative_msg_dict
296
+
297
+ class PairwiseComparison:
298
+ """
299
+ Handles pairwise comparisons between items using LLM.
300
+ Manages comparison results, counts, and statistics.
301
+ """
302
+
303
def __init__(self,
             llm: BaseLLM,
             prompt_templates: Dict[str, str],
             format_items: Callable[[List[Dict]], Any],
             parse_judgment: Callable[[Dict[str, Any]], Tuple[int, Dict[str, List[str]]]],
             data_dir: str,
             es_index: str,
             es_psm: str = "byte.es.ranking_moderation_cmt.service.my",
             business: str = "comment",
             max_comparisons_per_pair: int = 3,
             max_workers: int = 4,
             max_retries: int = 8,
             initial_retry_delay: float = 2.0,
             max_backups: int = 3,
             detect_msg_violations: bool = True,
             log_gpt_io: bool = False,
             log_sample_rate: float = 0.01,
             local_cache_path: Union[str, Path, None] = None,
             local_cache_enabled: bool = False):
    """Initialize pairwise comparison handler.

    Args:
        llm: Language model for pairwise comparisons
        prompt_templates: Dictionary of prompt templates
        format_items: Function to format items into LLM input. Can return either str or List[Dict]
        parse_judgment: Function that receives ``{"text": ..., "unit": ...}``
            and returns ``(judgment, violative_msg_dict)``
        data_dir: Directory to store state files (created if missing)
        es_index: Elasticsearch index name; ``None`` or the string ``'None'``
            disables Elasticsearch
        es_psm: Service name passed to ``bytedes.make_client``
        business: ``"comment"``, ``"dm"`` or ``"dm_mm"`` (case-insensitive);
            selects the unit name used when parsing judgments
        max_comparisons_per_pair: Maximum number of comparisons for each pair
        max_workers: Maximum number of worker threads
        max_retries: Maximum number of retry attempts for failed comparisons
        initial_retry_delay: Initial delay in seconds before first retry
        max_backups: Maximum number of backup files to keep (default: 3)
        detect_msg_violations: Whether violative-message maps are stored
            with each comparison
        log_gpt_io: Whether sampled LLM inputs/outputs are written to disk
        log_sample_rate: Probability of logging a given comparison
        local_cache_path: JSONL file used as local cache; ``None``, ``True``
            or the string ``'None'`` selects
            ``<data_dir>/pairwise_comparisons.jsonl`` when the cache is enabled
        local_cache_enabled: Enables the local cache; also implied by a
            truthy ``local_cache_path``

    Raises:
        NotImplementedError: If ``business`` is not supported.
    """
    self.llm = llm
    self.prompt_templates = prompt_templates
    self.format_items = format_items
    self.parse_judgment = parse_judgment
    self.max_comparisons_per_pair = max_comparisons_per_pair
    self.max_workers = max_workers
    self.max_retries = max_retries
    self.initial_retry_delay = initial_retry_delay
    self.max_backups = max_backups
    self.data_dir = Path(data_dir)

    # The literal string 'None' (e.g. from a CLI or config value) means "unset".
    self.es_index = None if es_index == 'None' else es_index
    if local_cache_path == 'None':
        local_cache_path = None
    self.data_dir.mkdir(parents=True, exist_ok=True)

    self.local_cache_enabled = bool(local_cache_enabled or local_cache_path)
    self.local_cache_path = None
    if self.local_cache_enabled:
        if local_cache_path is None or local_cache_path is True:
            self.local_cache_path = self.data_dir / "pairwise_comparisons.jsonl"
        else:
            self.local_cache_path = Path(local_cache_path)
        self.local_cache_path.parent.mkdir(parents=True, exist_ok=True)
    self._local_cache_loaded = False

    if self.es_index is not None:
        # NOTE(review): certificate verification is disabled here; confirm
        # this is intended for the target cluster.
        self.client = bytedes.make_client(psm=es_psm, cluster="data", scheme="https",
                                          verify_certs=False, use_ssl=True, ssl_show_warn=False, maxsize=50)
    else:
        self.client = None

    business_key = business.lower()
    if business_key == "comment":
        self.process_unit = "comment"
    elif business_key in ["dm", "dm_mm"]:
        self.process_unit = "conversation"
    else:
        raise NotImplementedError(f"Ranking moderation for business {business} is not implemented!")

    # In-memory cache of comparison results per pair key, guarded by the lock.
    self._cache_lock = Lock()
    self.comparison_results = defaultdict(list)

    self.detect_msg_violations = detect_msg_violations
    self.log_gpt_io = log_gpt_io
    self.log_sample_rate = log_sample_rate
382
+
383
def get_state_file_path(self, filename: str) -> Path:
    """Resolve a state file name against the handler's data directory.

    Args:
        filename: Name of the state file

    Returns:
        Path: ``data_dir`` joined with ``filename``
    """
    base_dir = self.data_dir
    return base_dir / filename
393
+
394
def get_pair_key(self, item_id1: str, item_id2: str) -> str:
    """Build the order-independent key identifying a pair of items.

    Both ids are converted to strings and joined with ``_``, smaller string
    first (plain string comparison).

    Args:
        item_id1: First item ID
        item_id2: Second item ID

    Returns:
        str: Unique key for the pair
    """
    low, high = sorted((str(item_id1), str(item_id2)))
    return f"{low}_{high}"
406
+
407
def get_ordered_pair(self, item_id1: str, item_id2: str) -> Tuple[str, str]:
    """Get ordered pair of item IDs.

    IDs are converted to strings before ordering so the result always agrees
    with the order used in the pair key, and mixed int/str input cannot raise
    a ``TypeError``.

    Args:
        item_id1: First item ID
        item_id2: Second item ID

    Returns:
        Tuple[str, str]: Ordered pair of item IDs (min, max) as strings
    """
    first, second = str(item_id1), str(item_id2)
    return min(first, second), max(first, second)
418
+
419
def get_compare_result_from_es(self, pair_key: str) -> List[Dict[str, Any]]:
    """Return the stored comparison results for ``pair_key``.

    Lookup order: in-memory cache, then (when Elasticsearch is disabled) the
    local JSONL cache, otherwise an Elasticsearch term query on ``pair_key``
    retried with exponential backoff.

    Args:
        pair_key: Key identifying the pair of items.

    Returns:
        List[Dict[str, Any]]: Comparison records; empty if none are known.

    Raises:
        Exception: The last Elasticsearch error once ``max_retries`` attempts
            have failed.
    """
    with self._cache_lock:
        if pair_key in self.comparison_results:
            return self.comparison_results[pair_key]

    if self.es_index is None:
        if self.local_cache_enabled and not self._local_cache_loaded:
            # Load the whole file: the loader marks the cache as loaded, so a
            # load restricted to this pair would hide every other stored pair
            # from later lookups.
            self.load_data_to_cache_from_local(None, load_detail=True)
            with self._cache_lock:
                if pair_key in self.comparison_results:
                    return self.comparison_results[pair_key]
        return []

    query_body = {
        "query": {
            "term": {
                "pair_key": {
                    "value": pair_key
                }
            }
        }
    }
    trial = 0
    current_delay = self.initial_retry_delay

    while trial < self.max_retries:
        try:
            result = self.client.search(index=self.es_index, body=query_body, size=200)
            compare_result = []
            for hit in result['hits']['hits']:
                source = hit['_source']
                r = {
                    "judgment": source['judgment'],
                    "raw_response": source['raw_response'],
                    "timestamp": source["timestamp"],
                    "item_id_a": source['item_id_a'],
                    "item_id_b": source['item_id_b'],
                    'ordered_ids': self.get_ordered_pair(source['item_id_a'], source['item_id_b']),
                    'original_ids': (source['item_id_a'], source['item_id_b'])
                }
                # Documents without the field are tolerated instead of
                # failing the whole lookup with a KeyError.
                if self.detect_msg_violations and 'violative_msg_map_str' in source:
                    r["violative_msg_map_str"] = source['violative_msg_map_str']

                compare_result.append(r)
            with self._cache_lock:
                self.comparison_results[pair_key] = compare_result

            return compare_result

        except Exception as e:
            trial += 1
            if trial == self.max_retries:
                logger.error(f"Failed to get data from es after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to get data from es, Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            time.sleep(current_delay)
            current_delay *= 2

    # Only reached when no attempt was made (max_retries <= 0).
    return []
482
+
483
def load_data_to_cache_from_es(self, item_ids: List[str], load_detail=True):
    """Load stored comparisons involving ``item_ids`` into the in-memory cache.

    When Elasticsearch is disabled the call is delegated to the local-cache
    loader. Otherwise ids are queried in batches of 500 through the scroll
    API, matching documents whose ``item_id_a`` or ``item_id_b`` is in the
    batch. Records are de-duplicated per pair by ``timestamp``.

    Args:
        item_ids: Item ids whose comparisons should be loaded.
        load_detail: Whether to also fetch ``raw_response``.

    Returns:
        dict: ``pair_key -> list of records`` seen during this call.
    """
    if self.es_index is None:
        return self.load_data_to_cache_from_local(item_ids, load_detail=load_detail)

    def _chunks(seq, chunk_size):
        # Yield consecutive slices of at most chunk_size elements.
        for i in range(0, len(seq), chunk_size):
            yield seq[i:i + chunk_size]

    def _add_if_new(bucket, key, record):
        # Append record under key unless one with the same timestamp exists.
        existing = bucket.get(key)
        if existing is not None and any(
                record['timestamp'] == other.get('timestamp') for other in existing):
            return
        if existing is None:
            bucket[key] = existing = []
        existing.append(record)

    comparison_results_temp = {}

    scroll = "2m"
    batch_size = 500

    _source = ["pair_key", "judgment", "timestamp", "item_id_a", "item_id_b"]
    if self.detect_msg_violations:
        _source.append("violative_msg_map_str")
    if load_detail:
        _source.append("raw_response")

    for id_batch in _chunks(item_ids, batch_size):
        query = {
            "_source": _source,
            "query": {
                "bool": {
                    "should": [
                        {"terms": {"item_id_a.keyword": id_batch}},
                        {"terms": {"item_id_b.keyword": id_batch}}
                    ],
                    "minimum_should_match": 1
                }
            }
        }
        page = self.client.search(index=self.es_index, body=query, scroll=scroll, size=5000)
        # NOTE(review): the scroll context is never cleared explicitly; it is
        # left to expire after the scroll timeout.
        while len(page["hits"]["hits"]) > 0:
            for hit in page["hits"]["hits"]:
                source = hit['_source']
                pair_key = source['pair_key']
                r = {
                    "judgment": source['judgment'],
                    "timestamp": source["timestamp"],
                    "item_id_a": source['item_id_a'],
                    "item_id_b": source['item_id_b'],
                    'ordered_ids': self.get_ordered_pair(source['item_id_a'], source['item_id_b']),
                    'original_ids': (source['item_id_a'], source['item_id_b'])
                }
                # Documents without the field are tolerated, as in the
                # local-cache loader, instead of aborting with a KeyError.
                if self.detect_msg_violations and 'violative_msg_map_str' in source:
                    r["violative_msg_map_str"] = source['violative_msg_map_str']
                if 'raw_response' in source:
                    r['raw_response'] = source['raw_response']

                with self._cache_lock:
                    _add_if_new(self.comparison_results, pair_key, r)
                # Separate view holding only what this call returned.
                _add_if_new(comparison_results_temp, pair_key, r)
            page = self.client.scroll(scroll_id=page["_scroll_id"], scroll=scroll)
    return comparison_results_temp
562
+
563
def load_data_to_cache_from_local(self, item_ids: Union[List[str], None], load_detail=True):
    """Load comparisons from the local JSONL cache into the in-memory cache.

    Every line of the file is expected to be one JSON object with at least
    ``item_id_a`` and ``item_id_b``. Lines that are empty, are not valid
    JSON, or are not JSON objects with both ids are skipped. Records are
    de-duplicated per pair by ``timestamp``.

    Args:
        item_ids: When given, only records where either id is in this
            collection are loaded; ``None`` loads everything.
        load_detail: Whether ``raw_response`` is copied into the records.

    Returns:
        dict: ``pair_key -> list of records`` read during this call; empty
        when the local cache is disabled or the file does not exist.
    """
    if not self.local_cache_enabled or self.local_cache_path is None:
        return {}
    if not self.local_cache_path.exists():
        self._local_cache_loaded = True
        return {}

    def _add_if_new(bucket, key, record):
        # Append record under key unless one with the same timestamp exists.
        existing = bucket.get(key)
        if existing is not None and any(
                record["timestamp"] == other.get("timestamp") for other in existing):
            return
        if existing is None:
            bucket[key] = existing = []
        existing.append(record)

    item_id_set = None
    if item_ids is not None:
        item_id_set = set(str(x) for x in item_ids)
    comparison_results_temp = {}
    skipped = 0

    # NOTE(review): the flag is set even when item_ids restricts the load, so
    # it does not guarantee that the whole file is in memory.
    with self._cache_lock:
        self._local_cache_loaded = True

    with open(self.local_cache_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                doc = json.loads(line)
            except ValueError:
                # json.JSONDecodeError is a ValueError; keep loading the rest.
                skipped += 1
                continue
            # A line such as `5` or `[1, 2]` is valid JSON but not a record.
            if not isinstance(doc, dict) or "item_id_a" not in doc or "item_id_b" not in doc:
                skipped += 1
                continue
            item_id_a = str(doc["item_id_a"])
            item_id_b = str(doc["item_id_b"])
            if item_id_set is not None and item_id_a not in item_id_set and item_id_b not in item_id_set:
                continue
            pair_key = doc.get("pair_key", self.get_pair_key(item_id_a, item_id_b))
            r = {
                "judgment": doc.get("judgment"),
                "timestamp": doc.get("timestamp"),
                "item_id_a": item_id_a,
                "item_id_b": item_id_b,
                "ordered_ids": self.get_ordered_pair(item_id_a, item_id_b),
                "original_ids": (item_id_a, item_id_b)
            }
            if self.detect_msg_violations and "violative_msg_map_str" in doc:
                r["violative_msg_map_str"] = doc["violative_msg_map_str"]
            if load_detail and "raw_response" in doc:
                r["raw_response"] = doc["raw_response"]

            with self._cache_lock:
                _add_if_new(self.comparison_results, pair_key, r)
            # Separate view holding only what this call returned.
            _add_if_new(comparison_results_temp, pair_key, r)

    if skipped:
        logger.warning(f"Skipped {skipped} malformed line(s) in local cache {self.local_cache_path}")
    return comparison_results_temp
625
+
626
def write_compare_result_to_es(self, pair_key: str, comparison_result):
    """Record a comparison in memory, in the local cache file and in Elasticsearch.

    The in-memory cache and the local JSONL file are updated exactly once.
    Only the Elasticsearch write is retried (exponential backoff), so a
    failing write can no longer duplicate the record in the cache or file.

    Args:
        pair_key: Key identifying the pair of items.
        comparison_result: Record with ``judgment``, ``raw_response``,
            ``timestamp``, ``original_ids`` and, when message-violation
            detection is on, ``violative_msg_map_str``.

    Returns:
        The response of ``client.index`` when Elasticsearch is enabled,
        otherwise ``None``.

    Raises:
        Exception: The last Elasticsearch error once ``max_retries`` attempts
            have failed.
    """
    doc = {
        'pair_key': pair_key,
        'judgment': comparison_result['judgment'],
        'raw_response': comparison_result['raw_response'],
        'timestamp': comparison_result['timestamp'],
        'item_id_a': comparison_result['original_ids'][0],
        'item_id_b': comparison_result['original_ids'][1],
    }
    if self.detect_msg_violations:
        doc['violative_msg_map_str'] = comparison_result['violative_msg_map_str']

    with self._cache_lock:
        if pair_key in self.comparison_results:
            self.comparison_results[pair_key].append(comparison_result)
        else:
            self.comparison_results[pair_key] = [comparison_result]
        if self.local_cache_enabled and self.local_cache_path is not None:
            # Written under the lock so lines from concurrent worker threads
            # cannot interleave.
            with open(self.local_cache_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(doc, ensure_ascii=False) + "\n")

    if self.es_index is None:
        return

    trial = 0
    current_delay = self.initial_retry_delay

    while trial < self.max_retries:
        try:
            return self.client.index(index=self.es_index, body=doc)
        except Exception as e:
            trial += 1
            if trial == self.max_retries:
                logger.error(f"Failed to write data to es after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to write data to es, Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            time.sleep(current_delay)
            current_delay *= 2
665
+
666
+
667
def compare_single_pair(self, pair: Tuple[Dict, Dict], random_swap: bool = True, use_cache: bool = True) -> Union[Dict[str, Any], None]:
    """Compare a single pair of items and store the result.

    Args:
        pair: Tuple of (item1, item2); each item must provide ``item_id``.
        random_swap: Whether to randomly swap the order of items.
                    If False, keep the original order.
                    Default is True for backward compatibility.
        use_cache: Whether to use cached comparison results if available.
                  If False, always perform new comparison.
                  Default is True for backward compatibility.

    Returns:
        The comparison record. When ``use_cache`` is set and at least
        ``max_comparisons_per_pair`` results are stored, a random stored
        record is returned instead of calling the LLM. ``None`` is returned
        when the response lacks a result block on the second attempt.

    Raises:
        Exception: The last error once ``max_retries`` attempts have failed.
    """
    item1, item2 = pair
    # Only swap items if random_swap is True
    if random_swap and random.random() < 0.5:
        item1, item2 = item2, item1

    item_id1 = str(item1['item_id'])
    item_id2 = str(item2['item_id'])
    pair_key = self.get_pair_key(item_id1, item_id2)
    ordered_id1, ordered_id2 = self.get_ordered_pair(item_id1, item_id2)

    if use_cache:
        compare_result = self.get_compare_result_from_es(pair_key)
        # The truthiness check keeps random.choice away from an empty list
        # when max_comparisons_per_pair is 0 or negative.
        if compare_result and len(compare_result) >= self.max_comparisons_per_pair:
            return random.choice(compare_result)

    # Format items into LLM input (plain string or multimodal content list)
    formatted_input = self.format_items([item1, item2])

    messages = [
        {"role": "system", "content": self.prompt_templates['system_prompt']},
        {"role": "user", "content": formatted_input}
    ]

    trial = 0
    current_delay = self.initial_retry_delay

    while trial < self.max_retries:
        response = None
        try:
            response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
            result = response['choices'][0]['message']['content']
            result_info = {
                "text": result,
                "unit": self.process_unit
            }
            judgment, violative_msg_dict = self.parse_judgment(result_info)

            comparison_result = {
                'judgment': judgment,
                'raw_response': result,
                'timestamp': time.time(),
                'ordered_ids': (ordered_id1, ordered_id2),
                'original_ids': (item_id1, item_id2),
                "item_id_a": str(item_id1),
                "item_id_b": str(item_id2)
            }
            if self.detect_msg_violations:
                # e.g. {"a": [m0,m1], "b": [m0]}
                comparison_result['violative_msg_map_str'] = json.dumps(violative_msg_dict)

            if self.detect_msg_violations and self.log_gpt_io and random.random() < self.log_sample_rate:
                with self._cache_lock:
                    # Explicit UTF-8: the payload is written with
                    # ensure_ascii=False and must not depend on the locale.
                    with open(self.data_dir / "tmp_gpt_io.jsonl", "a", encoding="utf-8") as f:
                        if isinstance(formatted_input, str):
                            formatted_input_for_regex = formatted_input
                        else:
                            # multimodal content: concatenate all text blocks
                            formatted_input_for_regex = "".join(
                                c.get("text", "")
                                for c in formatted_input
                                if isinstance(c, dict) and c.get("type") == "text"
                            )
                        _convs = re.findall(r'<conv_text>(.*?)</conv_text>', formatted_input_for_regex.strip(), re.DOTALL)
                        _convs = [html.unescape(c).strip() for c in _convs]
                        _age_info = re.findall(r'<conv_ageinfo>(.*?)</conv_ageinfo>', formatted_input_for_regex, re.DOTALL)
                        _data = {
                            'timestamp': time.time(),
                            "conversations": [{"conv_text": c, "conv_ageinfo": a} for c, a in zip(_convs, _age_info)],
                            "judgment": judgment,
                            "violative_msg_map_str": violative_msg_dict
                        }
                        f.write(json.dumps(_data, ensure_ascii=False) + "\n")

            self.write_compare_result_to_es(pair_key, comparison_result)
            return comparison_result
        except Exception as e:
            trial += 1
            # A response without a result block gets one retry only.
            if str(e) == "No result block found" and trial == 2:
                logger.error(f"parse error, only try 2 times")
                return
            if trial == self.max_retries:
                logger.error(f"Failed to compare pair after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            if response is not None:
                logger.warning(f"Response: {response}")
            time.sleep(current_delay)
            current_delay *= 2
769
+
770
def compare_pairs(self, pairs: List[Tuple[Dict, Dict]], random_swap: bool = True, use_cache: bool = True, use_tqdm: bool = True) -> List[Tuple[Dict, Dict]]:
    """Run the single-pair comparison for many pairs on a thread pool.

    Args:
        pairs: List of (item1, item2) tuples to compare
        random_swap: Forwarded to the single-pair comparison; whether the
                    order of the two items may be swapped at random.
        use_cache: Forwarded to the single-pair comparison; whether stored
                  results may be reused.
        use_tqdm: Whether a progress bar is shown while collecting results.

    Returns:
        The pairs, in submission order, whose comparison returned a truthy
        result. Pairs whose comparison raised are logged and left out.
    """
    # A fresh thread pool is used for every batch of comparisons.
    with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
        submitted = [
            (pool.submit(self.compare_single_pair, candidate,
                         random_swap=random_swap, use_cache=use_cache), candidate)
            for candidate in pairs
        ]

        progress = tqdm(submitted, desc="Comparing pairs", total=len(submitted),
                        leave=False, disable=not use_tqdm)
        completed = []
        for task, candidate in progress:
            try:
                outcome = task.result()
            except Exception as e:
                logger.error(f"Error comparing pair: {str(e)}")
                continue
            if outcome:
                completed.append(candidate)

    return completed
803
+
804
def understand_by_pairs(self, item: Dict, compare_results: Dict[str, List[Dict]]):
    """Ask the LLM to explain a target conversation using stored comparisons.

    The first comparison record of every pair that involves the target item
    is sorted into ``rs_high`` or ``rs_low`` according to which side of the
    pair the target is on and the sign of ``judgment``. Up to two randomly
    chosen records from each group are embedded, together with the formatted
    target item, into an XML document that is sent as the user message.

    Args:
        item: The target item; ``item["item_id"]`` identifies it.
        compare_results: Mapping of pair key to the list of comparison
            records for that pair. Each record is read for ``item_id_a``,
            ``item_id_b``, ``judgment`` and ``raw_response``. Pairs with an
            empty record list are ignored.

    Returns:
        The ``content`` of the first choice in the LLM response.

    Raises:
        Exception: Re-raises the last error once ``self.max_retries``
            attempts have failed.
    """
    from lxml import etree
    import random

    rs_high, rs_low = [], []
    target_id = str(item["item_id"])
    for pair_results in compare_results.values():
        if not pair_results:
            # Nothing recorded for this pair; indexing [0] would raise.
            continue
        r = pair_results[0]
        # Compare ids as strings so int/str id mismatches still match.
        is_a = target_id == str(r["item_id_a"])
        is_b = target_id == str(r["item_id_b"])
        j = r["judgment"]
        if (is_a and j >= 0) or (is_b and j < 0):
            rs_low.append(r)
        elif (is_a and j < 0) or (is_b and j >= 0):
            rs_high.append(r)
    random.shuffle(rs_high)
    random.shuffle(rs_low)
    examples = rs_high[:2] + rs_low[:2]

    root = etree.Element("analysis_input")

    target_block = etree.SubElement(root, "target_conversation")
    xml_target = etree.fromstring(
        self.format_items([item]).encode("utf-8")
    )
    target_block.append(xml_target)

    comps_block = etree.SubElement(root, "comparisons")
    for ex in examples:
        # The raw model response is embedded verbatim as CDATA rather than
        # being parsed as XML.
        txt = etree.SubElement(comps_block, "raw_response")
        txt.text = etree.CDATA(ex["raw_response"])

    xml_input = etree.tostring(root, pretty_print=True,
                               encoding="utf-8").decode("utf-8")

    # NOTE: "\\[" yields a literal backslash + "[" without relying on an
    # invalid "\[" escape sequence.
    system_prompt = """You are an expert in content safety and moderation.

You will receive a set of pairwise comparison results between a **target conversation** and multiple other conversations. Your primary goal is to clearly inform users about the **exact meaning and severity** of the target conversation based solely on these comparisons. Avoid any external assumptions or information.

When crafting your analysis, follow these guidelines:

* Clearly interpret the **exact meaning** of the target conversation, thoroughly and explicitly derived from the provided comparisons. Provide sufficient details so users clearly grasp the context, intention, and potential implications of the conversation.
* Identify explicitly which conversations the target conversation is **more severe than** and which ones it is **less severe than**. Clearly quote or paraphrase each comparison conversation, and provide explicit reasons derived from comparison data.
* Ensure all explanations rely **only** on the provided comparison data—**do not speculate or infer beyond what is explicitly given**.
* Structure your response clearly and flexibly so users immediately understand:

* The precise and detailed meaning of the target conversation.
* The explicit reasons behind the conversation's classification as severe or non-severe.
* Specific examples of comparisons supporting your severity assessments, clearly quoting or paraphrasing each messages in the conversation.

Your output should be in **Markdown format**, providing clear bilingual summaries in **English and Chinese** to ensure consistency for bilingual users.

*This detailed analysis aims to provide users with a clear and explicit understanding of the exact meaning of the conversation, its severity, and concrete evidence from comparative conversation or messages.* make sure your response can achieve this


---

**Detailed Output Example:**

## 📝 Conversation Analysis

### 🔹 English

**Direct translation:**
Translate this conversation into English. If the conversation is in English, simply copy the conversation as the output.
- Maintain all special tags message ids e.g. <m1> and user_ids e.g. user_0
- Do not remove existing new line characters
- Explain any cultural references / slang that could be unsafe or harmful in line e.g. user_0: Show me your "taye" [Note: "taye" means butt/ass in Algerian slang]
- Provide cultural context (within brackets inline - e.g. [Note: xxx]) when the translation in English sounds unnatural/incoherent
- Keep the explanations brief and concise (Keep a single note to a maximum of 10 words).

**Interpreted Meaning:**
The conversation explicitly expresses [detailed interpretation explicitly based on provided comparisons, clearly explaining context, intent, and implications].

**Severity Assessment:**

* **More Severe Than:**

* "Exact or paraphrased conversation A": because \\[explicit reason from comparison clearly indicating severity difference].
* "Exact or paraphrased conversation B": because \\[explicit reason from comparison clearly indicating severity difference].

* **Less Severe Than:**

* "Exact or paraphrased conversation C": because \\[explicit reason from comparison clearly indicating severity difference].

---

### 🔸 中文

**直接翻译:**
将此对话翻译成中文。如果对话是中文,只需复制对话作为输出即可。
- 保留所有特殊标签消息 ID,例如 <m1> 和用户 ID,例如 user_0
- 不要删除现有的换行符
- 解释任何可能不安全或有害的文化指涉/俚语,例如 user_0:给我看看你的“taye”[注:“taye”在阿尔及利亚俚语中是屁股的意思]
- 如果英语翻译听起来不自然/语无伦次,请提供文化背景(在括号内,例如 [注:xxx])。
- 解释请简洁明了(每条注释最多 10 个字)。

**解释性含义**:
对话明确表达了\\[基于提供的比较结果的详细解释,清晰地解释了上下文、意图和含义]。

**严重程度评估**:

* **严重程度高于**:

* “对话 A 完全一致或转述”:因为\\[比较结果的明确原因清楚地表明了严重程度的差异]。
* “对话 B 完全一致或转述”:因为\\[比较结果的明确原因清楚地表明了严重程度的差异]。

* **严重程度低于**:

* “对话 C 完全一致或转述”:因为\\[比较结果的明确原因清楚地表明了严重程度的差异]。

"""
    logger.debug(f"system_prompt: {system_prompt}")
    logger.debug(f"user_prompt: {xml_input}")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": xml_input}
    ]
    trial = 0
    current_delay = self.initial_retry_delay
    while trial < self.max_retries:
        # Reset per attempt so a failure never logs an earlier attempt's
        # response as if it belonged to this one.
        response = None
        try:
            response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
            llm_understanding = response['choices'][0]['message']['content']
            logger.debug(f"llm response: {llm_understanding}")
            return llm_understanding
        except Exception as e:
            trial += 1
            if trial == self.max_retries:
                logger.error(f"Failed to get data by llm api, {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to get data by llm api, Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            if response is not None:
                logger.warning(f"Response={response}")
            # Exponential backoff between attempts.
            time.sleep(current_delay)
            current_delay *= 2
949
+
950
def get_comparison_result_by_id(self, item_id: str) -> List[Dict]:
    """Return every stored comparison record that involves ``item_id``.

    When ``self.es_index`` is None the in-memory ``self.comparison_results``
    cache is scanned (after loading it from local storage if local caching
    is enabled and the cache has not been loaded yet). Otherwise
    Elasticsearch is queried for documents whose ``item_id_a`` or
    ``item_id_b`` equals the id, retrying with exponential backoff.

    Args:
        item_id: Identifier of the item; it is converted to ``str`` before
            any lookup so both code paths match ids the same way.

    Returns:
        List of comparison records. On the Elasticsearch path at most 200
        hits are requested, and each record carries ``judgment``,
        ``raw_response``, ``timestamp``, ``item_id_a``, ``item_id_b``,
        ``ordered_ids`` and ``original_ids`` (plus ``violative_msg_map_str``
        when ``self.detect_msg_violations`` is set).

    Raises:
        Exception: Re-raises the last search error once ``self.max_retries``
            attempts have failed.
    """
    item_id = str(item_id)

    if self.es_index is None:
        if self.local_cache_enabled and not self._local_cache_loaded:
            self.load_data_to_cache_from_local(None, load_detail=True)
        results = []
        with self._cache_lock:
            for pair_results in self.comparison_results.values():
                results.extend(
                    r for r in pair_results
                    if item_id in (str(r.get("item_id_a")), str(r.get("item_id_b")))
                )
        return results

    # The query does not change between attempts, so build it once.
    query_body = {
        "query": {
            "bool": {
                "should": [
                    {"term": {"item_id_a.keyword": item_id}},
                    {"term": {"item_id_b.keyword": item_id}}
                ],
                "minimum_should_match": 1
            }
        }
    }
    trial = 0
    current_delay = self.initial_retry_delay

    while trial < self.max_retries:
        try:
            result = self.client.search(index=self.es_index, body=query_body, size=200)
            compare_result = []
            for hit in result['hits']['hits']:
                source = hit['_source']
                r = {
                    "judgment": source['judgment'],
                    "raw_response": source['raw_response'],
                    "timestamp": source["timestamp"],
                    "item_id_a": source['item_id_a'],
                    "item_id_b": source['item_id_b'],
                    'ordered_ids': self.get_ordered_pair(source['item_id_a'], source['item_id_b']),
                    'original_ids': (source['item_id_a'], source['item_id_b'])
                }
                if self.detect_msg_violations:
                    r["violative_msg_map_str"] = source['violative_msg_map_str']

                compare_result.append(r)

            return compare_result

        except Exception as e:
            trial += 1
            if trial == self.max_retries:
                logger.error(f"Failed to get data from es after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to get data from es, Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            # Exponential backoff between attempts.
            time.sleep(current_delay)
            current_delay *= 2
1007
+
1008
+
1009
def get_comparison_count(self, pair_key: str) -> int:
    """Count the comparison records stored for one pair.

    Args:
        pair_key: Key identifying the pair.

    Returns:
        int: Length of the list returned by ``self.get_comparison_results``
        for ``pair_key``.
    """
    records_for_pair = self.get_comparison_results(pair_key)
    return len(records_for_pair)
1019
+
1020
def get_comparison_results(self, pair_key: str) -> List[Dict]:
    """Fetch the comparison records stored for one pair.

    This is a thin wrapper that delegates to
    ``self.get_compare_result_from_es``.

    Args:
        pair_key: Key identifying the pair.

    Returns:
        List[Dict]: Whatever ``self.get_compare_result_from_es`` returns for
        ``pair_key``.
    """
    stored_records = self.get_compare_result_from_es(pair_key)
    return stored_records
1030
+
1031
def get_compare_information(self, item_ids: List[str], use_tqdm: bool = True) -> Dict[str, Any]:
    """Get comprehensive comparison information between specified items.

    Scans the in-memory ``self.comparison_results`` cache (under
    ``self._cache_lock``) and keeps only the pairs whose two ids are both in
    ``item_ids``.

    Pair keys are read as ``"<id1>_<id2>"``. Because an id may itself contain
    an underscore, every split position is tried and the first one whose two
    halves are both requested ids is used. Keys that cannot be split that
    way are skipped instead of aborting the scan.

    Args:
        item_ids: List of item IDs to get comparison information for. Ids
            are converted to ``str`` before matching against pair keys.
        use_tqdm: Whether to show a progress bar while scanning the cache.

    Returns:
        Dict containing:
        - comparison_results: Dict mapping pair keys to a copy of their
          comparison result list
        - comparison_counts: Dict mapping item IDs to their total comparison
          counts
        - pair_comparison_counts: Dict mapping pair keys to their comparison
          counts
    """
    comparison_results = {}
    comparison_counts = {}
    pair_comparison_counts = {}
    item_id_set = {str(item_id) for item_id in item_ids}

    with self._cache_lock:
        cached_pairs = self.comparison_results.items()
        # Only build the progress-bar wrapper when one is actually wanted.
        if use_tqdm:
            cached_pairs = tqdm(cached_pairs, desc="Scan Comparing pairs", total=len(self.comparison_results), leave=False)

        for pair_key, pair_results in cached_pairs:
            parts = pair_key.split('_')
            for cut in range(1, len(parts)):
                id1 = '_'.join(parts[:cut])
                id2 = '_'.join(parts[cut:])
                if id1 in item_id_set and id2 in item_id_set:
                    break
            else:
                # No split yields two requested ids (or the key has no "_").
                continue

            n_results = len(pair_results)
            comparison_results[pair_key] = list(pair_results)
            comparison_counts[id1] = comparison_counts.get(id1, 0) + n_results
            comparison_counts[id2] = comparison_counts.get(id2, 0) + n_results
            pair_comparison_counts[pair_key] = n_results

    return {
        'comparison_results': comparison_results,
        'comparison_counts': comparison_counts,
        'pair_comparison_counts': pair_comparison_counts
    }