Upload pairwise_comparison.py with huggingface_hub
Browse files- pairwise_comparison.py +1060 -0
pairwise_comparison.py
ADDED
|
@@ -0,0 +1,1060 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pairwise comparison handler for Bradley-Terry ranking model.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import time
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import base64
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, List, Tuple, Union, Callable, Any
|
| 14 |
+
from threading import Lock
|
| 15 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 16 |
+
from src.llm.base import BaseLLM
|
| 17 |
+
from lxml import etree
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
import bytedes
|
| 21 |
+
import html
|
| 22 |
+
from functools import lru_cache
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
def format_comments_as_xml(comments: List[Dict]) -> str:
    """Serialize comments into a pretty-printed ``<comments>`` XML document.

    Each comment becomes a ``<comment>`` element holding a ``<comment_text>``
    child. When the comment carries any non-empty video metadata
    (``video_title``, ``video_tag``, ``video_description``, ``text_in_video``)
    a ``<video_context>`` child is added with one sub-element per field, in
    the order the fields appear in the comment dict.

    Args:
        comments: List of comment dictionaries; each must have a ``"text"`` key.

    Returns:
        str: UTF-8 decoded XML string.
    """
    logger.debug(f"Building XML input for {len(comments)} comments")
    context_keys = ['video_title', 'video_tag', 'video_description', 'text_in_video']
    root = etree.Element("comments")

    for comment in comments:
        node = etree.SubElement(root, "comment")
        # NOTE: a <comment_id> child is not emitted (disabled in the original code).
        etree.SubElement(node, "comment_text").text = comment["text"]

        # Keep only the known context fields that actually carry a value.
        context_fields = [
            (name, value)
            for name, value in comment.items()
            if name in context_keys and value is not None and value != ''
        ]
        if not context_fields:
            continue

        context_node = etree.SubElement(node, "video_context")
        for name, value in context_fields:
            etree.SubElement(context_node, name).text = str(value)

    serialized = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{serialized}")
    return serialized
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def format_convs_as_xml(convs: List[Dict]) -> str:
    """Serialize conversations into a pretty-printed ``<convs>`` XML document.

    Every conversation yields a ``<conv>`` element with, in order:
    ``<conversation_id>``, ``<conv_text>``, ``<conv_ageinfo>`` (built from
    ``alias2age_map`` as ``"user_<alias> is <age>"`` joined by ``", "``),
    ``<region>`` (from ``store_region``) and, when ``lang_fasttext`` is not
    ``None``, ``<language>``.

    Args:
        convs: List of conversation dictionaries.

    Returns:
        str: UTF-8 decoded XML string.
    """
    logger.debug(f"Building XML input for {len(convs)} convs")
    root = etree.Element("convs")

    def _add_child(parent, tag, text):
        # Create <tag> under parent and set its text in one step.
        child = etree.SubElement(parent, tag)
        child.text = text
        return child

    for conv in convs:
        age_info = ", ".join(
            [f'user_{alias} is {age}' for alias, age in conv['alias2age_map'].items()]
        )
        node = etree.SubElement(root, "conv")
        _add_child(node, "conversation_id", str(conv["conversation_id"]))
        _add_child(node, "conv_text", conv["conv_text"])
        _add_child(node, "conv_ageinfo", age_info)
        _add_child(node, "region", conv["store_region"])

        language = conv.get("lang_fasttext", None)
        if language is not None:
            _add_child(node, "language", language)

    serialized = etree.tostring(root, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated XML: \n{serialized}")
    return serialized
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@lru_cache(maxsize=int(os.environ.get('IMAGE_CACHE_SIZE', '512')))
def image_to_base64(image_path: str) -> str:
    """Return the base64 (UTF-8 text) encoding of the file at ``image_path``.

    Results are memoized per path by ``lru_cache``; the cache size is read
    once, at import time, from the ``IMAGE_CACHE_SIZE`` environment variable
    (default 512). A file that changes on disk after its first read is
    therefore served from the cache.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def text_with_placeholders_to_content(
    text: str,
    image_paths: List[Union[str, None]],
    placeholder_pattern: str = r'\[IMG(\d+)\]'
) -> List[Dict]:
    """Split text on image placeholders into an interleaved content list.

    Args:
        text: Source text containing placeholders such as ``[IMG1]``, ``[IMG2]``.
        image_paths: Image paths ordered to match the placeholder numbers
            (placeholders are 1-based, so ``[IMG1]`` maps to ``image_paths[0]``).
        placeholder_pattern: Regex for a placeholder; group 1 must capture the
            1-based image number.

    Returns:
        A list mixing ``{"type": "text", ...}`` and ``{"type": "image_url", ...}``
        items in document order (the OpenAI-style multimodal content format).
        Whitespace-only text segments are skipped. A placeholder whose index
        is out of range, or whose path is ``None``, is removed without
        inserting an image.
    """
    parts: List[Dict] = []

    def _push_text(segment: str) -> None:
        # Skip empty / whitespace-only segments.
        if segment.strip():
            parts.append({"type": "text", "text": segment})

    cursor = 0
    for found in re.compile(placeholder_pattern).finditer(text):
        begin, finish = found.span()
        index = int(found.group(1)) - 1  # placeholder numbers start at 1

        # Text between the previous placeholder and this one.
        _push_text(text[cursor:begin])

        # The image itself, when the placeholder resolves to a real path.
        resolvable = 0 <= index < len(image_paths) and image_paths[index] is not None
        if resolvable:
            encoded = image_to_base64(image_paths[index])
            parts.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
            })
        cursor = finish

    # Whatever follows the last placeholder.
    _push_text(text[cursor:])
    return parts
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def format_convs_as_xml_image(convs: List[Dict]) -> List[Dict]:
    """Build a multimodal (text + image) content list for conversations.

    The conversations are serialized to the same XML layout as the text-only
    formatter (``<conv>`` with ``conversation_id``, ``conv_text``,
    ``conv_ageinfo``, ``region`` and optional ``language``). Images are marked
    in the text with ``[IMGn]`` placeholders, and the XML string is finally
    split on those placeholders into a content list of text and image items.

    Image sources read from each conv (kept compatible with existing data):
      - single image: ``conv['img_path']``, falling back to ``conv['img']``
      - multiple images: ``conv['image_paths']`` / ``conv['img_paths']`` /
        ``conv['imgs']`` / ``conv['images']`` (each a ``list[str]``)

    Modes (chosen automatically per conv):
      - Embedded: if ``conv['conv_text']`` contains ``[IMGk]``, ``k`` is a
        1-based index into that conv's own image list. Each placeholder is
        rewritten to a global ``[IMGn]`` so indices stay unique across convs.
        A placeholder whose ``k`` does not resolve to one of the conv's images
        is removed (with a warning). If it were left in place it would be
        re-read as a *global* index by the final split and could pull in an
        image that belongs to a different conversation.
      - Non-embedded: if ``conv_text`` has no placeholder, the text is used
        unchanged and the conv's images are NOT attached (appending
        ``<image>[IMGn]</image>`` nodes is currently disabled).

    Args:
        convs: List of conversation dictionaries.

    Returns:
        List[Dict]: content items as produced by
        ``text_with_placeholders_to_content``.
    """
    logger.debug(f"Building multimodal XML input for {len(convs)} convs")
    placeholder_re = re.compile(r'\[IMG(\d+)\]')
    xml_convs = etree.Element("convs")

    # Global, ordered list of image paths; [IMGn] refers to image_paths_flat[n - 1].
    image_paths_flat: List[Union[str, None]] = []

    def _extract_image_paths(conv: Dict) -> List[str]:
        # Collect every non-empty string path from the supported keys.
        paths: List[str] = []
        # `or` lets a None/empty 'img_path' fall through to 'img'.
        single = conv.get('img_path') or conv.get('img')
        if isinstance(single, str) and single:
            paths.append(single)
        for key in ("image_paths", "img_paths", "imgs", "images"):
            candidates = conv.get(key)
            if isinstance(candidates, list):
                paths.extend(p for p in candidates if isinstance(p, str) and p)
        return paths

    for conv in convs:
        alias2age_map = conv['alias2age_map']
        alias2age_text = ", ".join([f'user_{k} is {v}' for k, v in alias2age_map.items()])
        xml_conv = etree.SubElement(xml_convs, "conv")
        xml_conv_id = etree.SubElement(xml_conv, "conversation_id")
        xml_conv_id.text = str(conv["conversation_id"])
        xml_conv_text = etree.SubElement(xml_conv, "conv_text")
        conv_text = conv["conv_text"]
        xml_conv_ageinfo = etree.SubElement(xml_conv, "conv_ageinfo")
        xml_conv_ageinfo.text = alias2age_text
        xml_conv_region = etree.SubElement(xml_conv, "region")
        xml_conv_region.text = conv["store_region"]
        if conv.get("lang_fasttext", None) is not None:
            xml_conv_language = etree.SubElement(xml_conv, "language")
            xml_conv_language.text = conv["lang_fasttext"]

        if not (isinstance(conv_text, str) and placeholder_re.search(conv_text)):
            # Non-embedded: keep the text unchanged; images are not attached.
            xml_conv_text.text = conv_text
            continue

        img_paths = _extract_image_paths(conv)
        local_to_global: Dict[int, int] = {}

        def _replace(m: "re.Match") -> str:
            local_idx = int(m.group(1)) - 1  # local placeholders are 1-based
            if not (0 <= local_idx < len(img_paths)):
                # Never leave an unresolved local id in the text: downstream
                # would treat it as a global id and may insert a wrong image.
                logger.warning(
                    f"Dropping unresolved image placeholder {m.group(0)} "
                    f"in conversation {conv.get('conversation_id')}"
                )
                return ''
            if local_idx not in local_to_global:
                image_paths_flat.append(img_paths[local_idx])
                # Global ids are 1-based positions in image_paths_flat.
                local_to_global[local_idx] = len(image_paths_flat)
            return f'[IMG{local_to_global[local_idx]}]'

        xml_conv_text.text = placeholder_re.sub(_replace, conv_text)

    xml_str = etree.tostring(xml_convs, pretty_print=True, encoding="utf-8").decode("utf-8")
    logger.debug(f"Generated multimodal XML: \n{xml_str}")
    return text_with_placeholders_to_content(xml_str, image_paths_flat)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def parse_unit_judgment(text_info: dict) -> Tuple[int, Dict[str, List[str]]]:
    """Parse the pairwise judgment and flagged messages from an LLM response.

    The response is expected to contain a ``<result>`` block with a nested
    ``<judgment>`` such as ``Comment A > Comment B``, and optionally a
    ``<violative_messages>`` block with per-item tags like ``<A>m1,m2</A>``.

    Args:
        text_info: dict
            - text: Response text from the LLM
            - unit: Unit name used in the judgment (defaults to "comment")

    Returns:
        Tuple[int, Dict[str, List[str]]]: ``(label, violative_messages)``.
        ``label`` is 1 if Unit A > Unit B, -1 if Unit A < Unit B and 0 if
        they are equal. ``violative_messages`` maps a single-letter tag to
        the comma-split list found inside it, and is empty when the response
        has no ``<violative_messages>`` block.

    Raises:
        ValueError: If the text is missing or the judgment cannot be parsed.
    """
    text = text_info.get("text")
    unit = text_info.get("unit", "comment")
    if not isinstance(text, str):
        # re.search would raise TypeError here; keep to the documented contract.
        logger.error("LLM response text is missing or not a string")
        raise ValueError("No result block found")

    m_res = re.search(r'<\s*result\s*>(.*?)<\s*/\s*result\s*>',
                      text, re.IGNORECASE | re.DOTALL)
    if not m_res:
        logger.error("No result block found in LLM response")
        raise ValueError("No result block found")
    result_body = m_res.group(1)

    m_j = re.search(r'<\s*judgment\s*>(.*?)<\s*/\s*judgment\s*>',
                    result_body, re.IGNORECASE | re.DOTALL)
    if not m_j:
        logger.error("No judgment found in result block")
        raise ValueError("No judgment found")

    # Models may HTML-escape the operator ("&gt;") and add a "Final Judgment" label.
    judgment_text = html.unescape(m_j.group(1).strip())
    judgment_text = re.sub(r'\bFinal\s+Judgment\b', '',
                           judgment_text, flags=re.IGNORECASE).strip()

    # Normalize full-width punctuation to ASCII so the operator regex matches.
    # Written as escapes so the mapping cannot silently become a no-op.
    fullwidth_to_ascii = str.maketrans({
        "\uff1a": ":",
        "\uff1e": ">",
        "\uff1c": "<",
        "\uff1d": "=",
    })
    judgment_text = judgment_text.translate(fullwidth_to_ascii)

    # Match "<unit> a|b  <operator>  <unit> a|b", case-insensitively.
    pattern = rf'(?i){re.escape(unit)}\s*([ab])\s*([<>]=?|==?)\s*{re.escape(unit)}\s*([ab])'
    m_cmp = re.search(pattern, judgment_text)
    if not m_cmp:
        logger.error(f"Invalid judgment format: {judgment_text!r}")
        raise ValueError(f"Invalid judgment format: {judgment_text!r}")

    left, op, _right = m_cmp.groups()
    left = left.upper()

    # Extract the violative messages, e.g. <A>m1,m2</A>.
    violative_msg_dict: Dict[str, List[str]] = {}
    violative_content = re.search(r'<violative_messages>(.*?)</violative_messages>', text, re.DOTALL)
    if violative_content:
        content = violative_content.group(1).strip()
        matches = re.findall(r'<([A-Z])>(.*?)</\1>', content)
        violative_msg_dict = {tag: val.strip().strip("<>[]\"").split(",") for tag, val in matches}
    else:
        logger.warning("No violative_messages found.")

    # The label is always expressed from Unit A's point of view.
    if op == '>':
        label = 1 if left == 'A' else -1
    elif op == '<':
        label = -1 if left == 'A' else 1
    elif op in ('=', '=='):
        label = 0
    else:
        # The regex also admits ">=" / "<=", which have no defined label.
        logger.error(f"Unsupported operator in judgment: {op!r}")
        raise ValueError(f"Unsupported operator: {op!r}")

    return label, violative_msg_dict
|
| 296 |
+
|
| 297 |
+
class PairwiseComparison:
|
| 298 |
+
"""
|
| 299 |
+
Handles pairwise comparisons between items using LLM.
|
| 300 |
+
Manages comparison results, counts, and statistics.
|
| 301 |
+
"""
|
| 302 |
+
|
| 303 |
+
def __init__(self,
             llm: BaseLLM,
             prompt_templates: Dict[str, str],
             format_items: Callable[[List[Dict]], Any],
             parse_judgment: Callable[[str], int],
             data_dir: str,
             es_index: str,
             es_psm: str = "byte.es.ranking_moderation_cmt.service.my",
             business: str = "comment",
             max_comparisons_per_pair: int = 3,
             max_workers: int = 4,
             max_retries: int = 8,
             initial_retry_delay: float = 2.0,
             max_backups: int = 3,
             detect_msg_violations: bool = True,
             log_gpt_io: bool = False,
             log_sample_rate: float = 0.01,
             local_cache_path: Union[str, Path, None] = None,
             local_cache_enabled: bool = False):
    """Set up the pairwise comparison handler.

    Args:
        llm: Language model used for pairwise comparisons.
        prompt_templates: Dictionary of prompt templates.
        format_items: Function turning items into LLM input (str or List[Dict]).
        parse_judgment: Function parsing an LLM response into a judgment.
        data_dir: Directory for state files; created if missing.
        es_index: Elasticsearch index name. The string ``'None'`` is treated
            as ``None``; with no index, no ES client is created.
        es_psm: Service name passed to ``bytedes.make_client``.
        business: ``"comment"`` (unit "comment") or ``"dm"`` / ``"dm_mm"``
            (unit "conversation"), case-insensitive.
        max_comparisons_per_pair: Maximum number of comparisons for each pair.
        max_workers: Maximum number of workers (stored; consumed elsewhere).
        max_retries: Maximum number of retry attempts for failed requests.
        initial_retry_delay: Initial delay in seconds before the first retry.
        max_backups: Maximum number of backup files to keep.
        detect_msg_violations: Whether records carry ``violative_msg_map_str``.
        log_gpt_io: Stored on the instance; usage is outside this constructor.
        log_sample_rate: Stored on the instance; usage is outside this constructor.
        local_cache_path: JSONL file for the local cache. The string
            ``'None'`` is treated as ``None``; a non-empty value also enables
            the local cache.
        local_cache_enabled: Enable the local cache; without an explicit path
            it lives at ``<data_dir>/pairwise_comparisons.jsonl``.

    Raises:
        NotImplementedError: If ``business`` is not a supported value.
    """
    # Collaborators and tuning knobs are stored as given.
    self.llm = llm
    self.prompt_templates = prompt_templates
    self.format_items = format_items
    self.parse_judgment = parse_judgment
    self.max_comparisons_per_pair = max_comparisons_per_pair
    self.max_workers = max_workers
    self.max_retries = max_retries
    self.initial_retry_delay = initial_retry_delay
    self.max_backups = max_backups
    self.detect_msg_violations = detect_msg_violations
    self.log_gpt_io = log_gpt_io
    self.log_sample_rate = log_sample_rate

    # The literal string 'None' (e.g. from a CLI/config) means "not set".
    self.es_index = None if es_index == 'None' else es_index
    if local_cache_path == 'None':
        local_cache_path = None

    self.data_dir = Path(data_dir)
    self.data_dir.mkdir(parents=True, exist_ok=True)

    # Local JSONL cache: enabled by the flag or by supplying a path.
    self.local_cache_enabled = bool(local_cache_enabled or local_cache_path)
    self._local_cache_loaded = False
    if not self.local_cache_enabled:
        self.local_cache_path = None
    elif local_cache_path is None or local_cache_path is True:
        # Default location; its parent (data_dir) was created just above.
        self.local_cache_path = self.data_dir / "pairwise_comparisons.jsonl"
    else:
        self.local_cache_path = Path(local_cache_path)
        self.local_cache_path.parent.mkdir(parents=True, exist_ok=True)

    # An ES client is only needed when an index is configured.
    if self.es_index is None:
        self.client = None
    else:
        self.client = bytedes.make_client(psm=es_psm, cluster="data", scheme="https",
                                          verify_certs=False, use_ssl=True,
                                          ssl_show_warn=False, maxsize=50)

    unit_by_business = {"comment": "comment", "dm": "conversation", "dm_mm": "conversation"}
    process_unit = unit_by_business.get(business.lower())
    if process_unit is None:
        raise NotImplementedError(f"Ranking moderation for business {business} is not implemented!")
    self.process_unit = process_unit

    # In-memory cache of comparison records per pair key, guarded by a lock.
    self._cache_lock = Lock()
    self.comparison_results = defaultdict(list)
|
| 382 |
+
|
| 383 |
+
def get_state_file_path(self, filename: str) -> Path:
    """Resolve a state file name against the handler's data directory.

    Args:
        filename: Name of the state file

    Returns:
        Path: ``self.data_dir`` joined with ``filename``
    """
    state_path = self.data_dir.joinpath(filename)
    return state_path
|
| 393 |
+
|
| 394 |
+
def get_pair_key(self, item_id1: str, item_id2: str) -> str:
    """Build an order-independent key for a pair of items.

    Both IDs are converted to strings, sorted lexicographically and joined
    with an underscore, so ``(a, b)`` and ``(b, a)`` yield the same key.

    Args:
        item_id1: First item ID
        item_id2: Second item ID

    Returns:
        str: ``"<smaller>_<larger>"`` by string comparison
    """
    smaller, larger = sorted((str(item_id1), str(item_id2)))
    return "_".join((smaller, larger))
|
| 406 |
+
|
| 407 |
+
def get_ordered_pair(self, item_id1: str, item_id2: str) -> Tuple[str, str]:
    """Return the two item IDs in ascending order.

    The IDs are compared as given (no string conversion is applied here).

    Args:
        item_id1: First item ID
        item_id2: Second item ID

    Returns:
        Tuple[str, str]: ``(smaller, larger)``
    """
    # Mirrors min()/max(): on ties the first argument wins in both slots.
    smaller = item_id2 if item_id2 < item_id1 else item_id1
    larger = item_id2 if item_id2 > item_id1 else item_id1
    return smaller, larger
|
| 418 |
+
|
| 419 |
+
def get_compare_result_from_es(self, pair_key: str) -> List[Dict]:
    """Return the stored comparison records for one pair.

    Lookup order:
      1. the in-memory cache (``self.comparison_results``);
      2. with no ES index configured: the local JSONL cache, loaded lazily
         on first use, then the in-memory cache again;
      3. otherwise Elasticsearch, retried with exponential backoff. The
         result is stored in the in-memory cache.

    Args:
        pair_key: Pair key as produced by ``get_pair_key``.

    Returns:
        List[Dict]: comparison records for the pair (empty if none exist).

    Raises:
        Exception: The last error is re-raised when every ES attempt fails.
    """
    # read from cache
    with self._cache_lock:
        if pair_key in self.comparison_results:
            return self.comparison_results[pair_key]

    if self.es_index is None:
        if self.local_cache_enabled and not self._local_cache_loaded:
            # Load the whole file. The loader marks the cache as loaded no
            # matter which ids it was given, so a load filtered to this one
            # pair would hide every other pair from later lookups.
            self.load_data_to_cache_from_local(None, load_detail=True)
            with self._cache_lock:
                if pair_key in self.comparison_results:
                    return self.comparison_results[pair_key]
        return []

    # Always try at least once, so a non-positive max_retries cannot make
    # this method fall through and return None.
    attempts = max(1, self.max_retries)
    current_delay = self.initial_retry_delay
    query_body = {
        "query": {
            "term": {
                "pair_key": {
                    "value": pair_key
                }
            }
        }
    }

    for trial in range(1, attempts + 1):
        try:
            result = self.client.search(index=self.es_index, body=query_body, size=200)
            compare_result = []
            for hit in result['hits']['hits']:
                source = hit['_source']
                r = {
                    "judgment": source['judgment'],
                    "raw_response": source['raw_response'],
                    "timestamp": source["timestamp"],
                    "item_id_a": source['item_id_a'],
                    "item_id_b": source['item_id_b'],
                    'ordered_ids': self.get_ordered_pair(source['item_id_a'], source['item_id_b']),
                    'original_ids': (source['item_id_a'], source['item_id_b'])
                }
                if self.detect_msg_violations:
                    r["violative_msg_map_str"] = source['violative_msg_map_str']
                compare_result.append(r)

            with self._cache_lock:
                self.comparison_results[pair_key] = compare_result
            return compare_result

        except Exception as e:
            if trial == attempts:
                logger.error(f"Failed to get data from es after {attempts} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to get data from es, Request failed (attempt {trial}/{attempts}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            time.sleep(current_delay)
            current_delay *= 2  # exponential backoff
|
| 482 |
+
|
| 483 |
+
def load_data_to_cache_from_es(self, item_ids: List[str], load_detail=True):
    """Bulk-load comparison records touching ``item_ids`` into the cache.

    With no ES index configured this delegates to the local-file loader.
    Otherwise the ids are queried in batches of 500; every record whose
    ``item_id_a`` or ``item_id_b`` is in the batch is fetched through the
    scroll API and merged into ``self.comparison_results``. Within a pair,
    a record is treated as already present when its ``timestamp`` matches
    an existing record.

    Args:
        item_ids: Item IDs whose comparisons should be loaded.
        load_detail: When true, also fetch ``raw_response``.

    Returns:
        dict: pair key -> list of records seen during this call (de-duplicated
        by timestamp), independent of what was already cached.
    """
    if self.es_index is None:
        return self.load_data_to_cache_from_local(item_ids, load_detail=load_detail)

    def _already_present(records, record):
        # A record is identified within its pair by its timestamp.
        return any(record['timestamp'] == known['timestamp'] for known in records)

    loaded_now = {}
    scroll_ttl = "2m"
    ids_per_query = 500

    source_fields = ["pair_key", "judgment", "timestamp", "item_id_a", "item_id_b"]
    if self.detect_msg_violations:
        source_fields.append("violative_msg_map_str")
    if load_detail:
        source_fields.append("raw_response")

    for offset in range(0, len(item_ids), ids_per_query):
        id_batch = item_ids[offset:offset + ids_per_query]
        query = {
            "_source": source_fields,
            "query": {
                "bool": {
                    "should": [
                        {"terms": {"item_id_a.keyword": id_batch}},
                        {"terms": {"item_id_b.keyword": id_batch}}
                    ],
                    "minimum_should_match": 1
                }
            }
        }
        # NOTE(review): the scroll context is never explicitly cleared; it is
        # left to expire after the scroll TTL.
        page = self.client.search(index=self.es_index, body=query, scroll=scroll_ttl, size=5000)
        scroll_id = page["_scroll_id"]
        hits = page["hits"]["hits"]

        while len(hits) > 0:
            for hit in hits:
                source = hit['_source']
                pair_key = source['pair_key']
                record = {
                    "judgment": source['judgment'],
                    "timestamp": source["timestamp"],
                    "item_id_a": source['item_id_a'],
                    "item_id_b": source['item_id_b'],
                    'ordered_ids': self.get_ordered_pair(source['item_id_a'], source['item_id_b']),
                    'original_ids': (source['item_id_a'], source['item_id_b'])
                }
                if self.detect_msg_violations:
                    record["violative_msg_map_str"] = source['violative_msg_map_str']
                if 'raw_response' in source:
                    record['raw_response'] = source['raw_response']

                # Merge into the shared in-memory cache.
                with self._cache_lock:
                    if pair_key not in self.comparison_results:
                        self.comparison_results[pair_key] = [record]
                    elif not _already_present(self.comparison_results[pair_key], record):
                        self.comparison_results[pair_key].append(record)

                # Track what this call itself has loaded.
                seen_for_pair = loaded_now.setdefault(pair_key, [])
                if not _already_present(seen_for_pair, record):
                    seen_for_pair.append(record)

            page = self.client.scroll(scroll_id=scroll_id, scroll=scroll_ttl)
            scroll_id = page["_scroll_id"]
            hits = page["hits"]["hits"]

    return loaded_now
|
| 562 |
+
|
| 563 |
+
def load_data_to_cache_from_local(self, item_ids: Union[List[str], None], load_detail=True):
    """Load comparison records from the local JSONL cache file into memory.

    Each non-blank line of ``self.local_cache_path`` is parsed as one JSON
    document. Lines that are not valid JSON, that do not decode to a JSON
    object, or that lack ``item_id_a`` / ``item_id_b`` are skipped so a
    single damaged line cannot abort the whole load.

    Args:
        item_ids: When not None, only records where at least one of the two
            item ids (compared as strings) is in this collection are loaded.
            None loads every record.
        load_detail: When True, ``raw_response`` is copied onto the record
            if the document has it; when False it is left out.

    Returns:
        Dict mapping pair key to the list of records read by this call,
        de-duplicated by timestamp. Returns ``{}`` when the local cache is
        disabled, has no path, or the file does not exist.

    Side effects:
        Records whose timestamp is not already present for their pair key
        are appended to ``self.comparison_results``, and
        ``self._local_cache_loaded`` is set to True.
    """
    if not self.local_cache_enabled or self.local_cache_path is None:
        return {}
    if not self.local_cache_path.exists():
        self._local_cache_loaded = True
        return {}

    wanted_ids = None if item_ids is None else {str(x) for x in item_ids}

    def _append_if_new(store, key, record):
        # A record is a duplicate when its pair already holds one with the
        # same timestamp.
        bucket = store.setdefault(key, [])
        if any(record["timestamp"] == seen.get("timestamp") for seen in bucket):
            return
        bucket.append(record)

    loaded = {}
    with self._cache_lock:
        # NOTE(review): the flag is raised even for a filtered (item_ids)
        # load, and get_comparison_result_by_id skips its full load once the
        # flag is set -- confirm that a partial load should count as loaded.
        self._local_cache_loaded = True

    with open(self.local_cache_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                doc = json.loads(line)
            except ValueError:
                # json.JSONDecodeError is a ValueError: malformed line, skip it.
                continue
            if not isinstance(doc, dict):
                # Valid JSON that is not an object (null, number, list, ...).
                continue
            if "item_id_a" not in doc or "item_id_b" not in doc:
                continue

            item_id_a = str(doc["item_id_a"])
            item_id_b = str(doc["item_id_b"])
            if (
                wanted_ids is not None
                and item_id_a not in wanted_ids
                and item_id_b not in wanted_ids
            ):
                continue

            if "pair_key" in doc:
                pair_key = doc["pair_key"]
            else:
                pair_key = self.get_pair_key(item_id_a, item_id_b)

            r = {
                "judgment": doc.get("judgment"),
                "timestamp": doc.get("timestamp"),
                "item_id_a": item_id_a,
                "item_id_b": item_id_b,
                "ordered_ids": self.get_ordered_pair(item_id_a, item_id_b),
                "original_ids": (item_id_a, item_id_b),
            }
            if self.detect_msg_violations and "violative_msg_map_str" in doc:
                r["violative_msg_map_str"] = doc["violative_msg_map_str"]
            if load_detail and "raw_response" in doc:
                r["raw_response"] = doc["raw_response"]

            # The shared cache is guarded by the lock; the per-call result
            # dict is local to this call and needs no locking.
            with self._cache_lock:
                _append_if_new(self.comparison_results, pair_key, r)
            _append_if_new(loaded, pair_key, r)

    return loaded
|
| 625 |
+
|
| 626 |
+
def write_compare_result_to_es(self, pair_key: str, comparison_result):
    """Record one comparison result in memory, on disk and in Elasticsearch.

    The result is appended to ``self.comparison_results[pair_key]`` and, when
    the local cache is enabled, appended as one JSON line to
    ``self.local_cache_path``. Both happen exactly once per call. Only the
    Elasticsearch index request is retried (with a delay that doubles after
    every failure), so a flaky ES write can no longer duplicate the record
    in the in-memory cache or in the local file.

    Args:
        pair_key: Key of the compared pair.
        comparison_result: Dict with ``judgment``, ``raw_response``,
            ``timestamp`` and ``original_ids`` (a 2-sequence), plus
            ``violative_msg_map_str`` when ``self.detect_msg_violations``
            is set.

    Returns:
        The response of ``self.client.index`` when ``self.es_index`` is set,
        otherwise None.

    Raises:
        KeyError: If ``comparison_result`` lacks a required field.
        Exception: The last error of the ES index call once
            ``self.max_retries`` attempts have failed.
    """
    doc = {
        'pair_key': pair_key,
        'judgment': comparison_result['judgment'],
        'raw_response': comparison_result['raw_response'],
        'timestamp': comparison_result['timestamp'],
        'item_id_a': comparison_result['original_ids'][0],
        'item_id_b': comparison_result['original_ids'][1],
    }
    if self.detect_msg_violations:
        doc['violative_msg_map_str'] = comparison_result['violative_msg_map_str']

    with self._cache_lock:
        self.comparison_results.setdefault(pair_key, []).append(comparison_result)
        # The file append stays under the same lock so concurrent writers
        # cannot interleave partial lines.
        if self.local_cache_enabled and self.local_cache_path is not None:
            with open(self.local_cache_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(doc, ensure_ascii=False) + "\n")

    if self.es_index is None:
        return None

    trial = 0
    current_delay = self.initial_retry_delay
    while trial < self.max_retries:
        try:
            return self.client.index(index=self.es_index, body=doc)
        except Exception as e:
            trial += 1
            if trial == self.max_retries:
                logger.error(f"Failed to write data to es after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to write data to es, Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            time.sleep(current_delay)
            current_delay *= 2
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def compare_single_pair(self, pair: Tuple[Dict, Dict], random_swap: bool = True, use_cache: bool = True) -> Union[Dict[str, Any], None]:
    """Compare a single pair of items and store the result.

    Args:
        pair: Tuple of (item1, item2); each item must carry an ``item_id``.
        random_swap: Whether to randomly swap the order of items.
                    If False, keep the original order.
                    Default is True for backward compatibility.
        use_cache: Whether to use cached comparison results if available.
                  If False, always perform new comparison.
                  Default is True for backward compatibility.

    Returns:
        The comparison record (``judgment``, ``raw_response``, ``timestamp``,
        ``ordered_ids``, ``original_ids``, ``item_id_a``, ``item_id_b`` and,
        when message-violation detection is on, ``violative_msg_map_str``).
        When enough cached results exist, a randomly chosen cached record is
        returned instead. None is returned when the second failed attempt
        fails with "No result block found", or when no attempt is made at
        all.

    Raises:
        Exception: The last LLM/parsing error after ``self.max_retries``
            failed attempts, or any error raised while persisting the result.
    """
    item1, item2 = pair
    # Only swap items if random_swap is True
    if random_swap and random.random() < 0.5:
        item1, item2 = item2, item1

    item_id1 = str(item1['item_id'])
    item_id2 = str(item2['item_id'])
    pair_key = self.get_pair_key(item_id1, item_id2)
    ordered_id1, ordered_id2 = self.get_ordered_pair(item_id1, item_id2)

    # Reuse stored results once the pair has been compared often enough.
    if use_cache:
        compare_result = self.get_compare_result_from_es(pair_key)
        if len(compare_result) >= self.max_comparisons_per_pair:
            return random.choice(compare_result)

    # Format items into LLM input
    formatted_input = self.format_items([item1, item2])

    messages = [
        {"role": "system", "content": self.prompt_templates['system_prompt']},
        {"role": "user", "content": formatted_input}
    ]

    trial = 0
    current_delay = self.initial_retry_delay
    comparison_result = None

    while trial < self.max_retries:
        response = None
        try:
            response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
            result = response['choices'][0]['message']['content']
            result_info = {
                "text": result,
                "unit": self.process_unit
            }
            judgment, violative_msg_dict = self.parse_judgment(result_info)

            comparison_result = {
                'judgment': judgment,
                'raw_response': result,
                'timestamp': time.time(),
                'ordered_ids': (ordered_id1, ordered_id2),
                'original_ids': (item_id1, item_id2),
                "item_id_a": str(item_id1),
                "item_id_b": str(item_id2)
            }
            if self.detect_msg_violations:
                # e.g. {"a": [m0,m1], "b": [m0]}
                comparison_result['violative_msg_map_str'] = json.dumps(violative_msg_dict)

            if self.detect_msg_violations and self.log_gpt_io and random.random() < self.log_sample_rate:
                if isinstance(formatted_input, str):
                    formatted_input_for_regex = formatted_input
                else:
                    # multimodal content: concatenate all text blocks
                    formatted_input_for_regex = "".join(
                        c.get("text", "")
                        for c in formatted_input
                        if isinstance(c, dict) and c.get("type") == "text"
                    )
                _convs = re.findall(r'<conv_text>(.*?)</conv_text>', formatted_input_for_regex.strip(), re.DOTALL)
                _convs = [html.unescape(c).strip() for c in _convs]
                _age_info = re.findall(r'<conv_ageinfo>(.*?)</conv_ageinfo>', formatted_input_for_regex, re.DOTALL)
                _data = {
                    'timestamp': time.time(),
                    "conversations": [{"conv_text": c, "conv_ageinfo": a} for c, a in zip(_convs, _age_info)],
                    "judgment": judgment,
                    "violative_msg_map_str": violative_msg_dict
                }
                with self._cache_lock:
                    # The payload is dumped with ensure_ascii=False, so the
                    # file must be opened as UTF-8 rather than with the
                    # platform default encoding.
                    with open(self.data_dir / "tmp_gpt_io.jsonl", "a", encoding="utf-8") as f:
                        f.write(json.dumps(_data, ensure_ascii=False) + "\n")
            break
        except Exception as e:
            trial += 1
            if str(e) == "No result block found" and trial == 2:
                logger.error("parse error, only try 2 times")
                return None
            if trial == self.max_retries:
                logger.error(f"Failed to compare pair after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            if response is not None:
                logger.warning(f"Response: {response}")
            time.sleep(current_delay)
            current_delay *= 2
    else:
        # The loop never produced a result (no attempt was allowed).
        return None

    # Persist outside the LLM retry loop: write_compare_result_to_es has its
    # own retries, and a persistence failure must not trigger another LLM
    # call and another persisted copy of the same comparison.
    self.write_compare_result_to_es(pair_key, comparison_result)
    return comparison_result
|
| 769 |
+
|
| 770 |
+
def compare_pairs(self, pairs: List[Tuple[Dict, Dict]], random_swap: bool = True, use_cache: bool = True, use_tqdm: bool = True) -> List[Tuple[Dict, Dict]]:
    """Run ``compare_single_pair`` over many pairs on a thread pool.

    Args:
        pairs: (item1, item2) tuples to compare.
        random_swap: Forwarded to ``compare_single_pair``; when True each
            pair may be presented in swapped order.
        use_cache: Forwarded to ``compare_single_pair``; when False a fresh
            comparison is always performed.
        use_tqdm: When False the progress bar is disabled.

    Returns:
        The input pairs, in submission order, whose comparison returned a
        truthy result. Pairs whose comparison raised are logged and left out.
    """
    succeeded = []
    # A fresh thread pool is created for every batch of comparisons.
    with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
        submitted = [
            (
                pool.submit(
                    self.compare_single_pair,
                    candidate,
                    random_swap=random_swap,
                    use_cache=use_cache,
                ),
                candidate,
            )
            for candidate in pairs
        ]

        progress = tqdm(
            submitted,
            desc="Comparing pairs",
            total=len(submitted),
            leave=False,
            disable=not use_tqdm,
        )
        for task, candidate in progress:
            try:
                outcome = task.result()
            except Exception as e:
                logger.error(f"Error comparing pair: {str(e)}")
                continue
            # A falsy outcome (e.g. None) means the comparison did not succeed.
            if outcome:
                succeeded.append(candidate)

    return succeeded
|
| 803 |
+
|
| 804 |
+
def understand_by_pairs(self, item: Dict, compare_results: Dict[str, List[Dict]]):
    """Ask the LLM to explain a target item using a few of its comparisons.

    Only the first stored result of every pair is considered. Results are
    split into two groups from the position of the target id (``item_id_a``
    or ``item_id_b``) and the sign of ``judgment``; up to two randomly chosen
    results of each group are embedded, together with the formatted target
    item, in an XML document that is sent as the user message.
    (presumably the judgment sign encodes which side was judged more severe
    -- confirm against parse_judgment.)

    Args:
        item: The target item; must carry ``item_id``.
        compare_results: Mapping of pair key to that pair's list of
            comparison records. Pairs with an empty list and records without
            a ``raw_response`` are ignored.

    Returns:
        The text content of the LLM reply, or None when no attempt is made.

    Raises:
        Exception: The last LLM error after ``self.max_retries`` failed
            attempts; XML errors from formatting the target item propagate
            unchanged.
    """
    from lxml import etree

    rs_high, rs_low = [], []
    target_id = str(item["item_id"])
    for pair_results in compare_results.values():
        if not pair_results:
            # Nothing recorded for this pair.
            continue
        r = pair_results[0]
        if r.get("raw_response") is None:
            # Records cached without details cannot serve as an example.
            continue
        a, b, j = r["item_id_a"], r["item_id_b"], r["judgment"]
        if (target_id == a and j >= 0) or (target_id == b and j < 0):
            rs_low.append(r)
        elif (target_id == a and j < 0) or (target_id == b and j >= 0):
            rs_high.append(r)
    random.shuffle(rs_high)
    random.shuffle(rs_low)
    examples = rs_high[:2] + rs_low[:2]

    root = etree.Element("analysis_input")

    target_block = etree.SubElement(root, "target_conversation")
    xml_target = etree.fromstring(
        self.format_items([item]).encode("utf-8")
    )
    target_block.append(xml_target)

    comps_block = etree.SubElement(root, "comparisons")
    for ex in examples:
        # The raw response is embedded verbatim as CDATA, not parsed as XML.
        txt = etree.SubElement(comps_block, "raw_response")
        txt.text = etree.CDATA(ex["raw_response"])

    xml_input = etree.tostring(root, pretty_print=True,
                               encoding="utf-8").decode("utf-8")

    # Backslashes before "[" are written as "\\[" so the literal carries the
    # same text without relying on an invalid escape sequence.
    system_prompt = """You are an expert in content safety and moderation.

You will receive a set of pairwise comparison results between a **target conversation** and multiple other conversations. Your primary goal is to clearly inform users about the **exact meaning and severity** of the target conversation based solely on these comparisons. Avoid any external assumptions or information.

When crafting your analysis, follow these guidelines:

* Clearly interpret the **exact meaning** of the target conversation, thoroughly and explicitly derived from the provided comparisons. Provide sufficient details so users clearly grasp the context, intention, and potential implications of the conversation.
* Identify explicitly which conversations the target conversation is **more severe than** and which ones it is **less severe than**. Clearly quote or paraphrase each comparison conversation, and provide explicit reasons derived from comparison data.
* Ensure all explanations rely **only** on the provided comparison data—**do not speculate or infer beyond what is explicitly given**.
* Structure your response clearly and flexibly so users immediately understand:

* The precise and detailed meaning of the target conversation.
* The explicit reasons behind the conversation's classification as severe or non-severe.
* Specific examples of comparisons supporting your severity assessments, clearly quoting or paraphrasing each messages in the conversation.

Your output should be in **Markdown format**, providing clear bilingual summaries in **English and Chinese** to ensure consistency for bilingual users.

*This detailed analysis aims to provide users with a clear and explicit understanding of the exact meaning of the conversation, its severity, and concrete evidence from comparative conversation or messages.* make sure your response can achieve this


---

**Detailed Output Example:**

## 📝 Conversation Analysis

### 🔹 English

**Direct translation:**
Translate this conversation into English. If the conversation is in English, simply copy the conversation as the output.
- Maintain all special tags message ids e.g. <m1> and user_ids e.g. user_0
- Do not remove existing new line characters
- Explain any cultural references / slang that could be unsafe or harmful in line e.g. user_0: Show me your "taye" [Note: "taye" means butt/ass in Algerian slang]
- Provide cultural context (within brackets inline - e.g. [Note: xxx]) when the translation in English sounds unnatural/incoherent
- Keep the explanations brief and concise (Keep a single note to a maximum of 10 words).

**Interpreted Meaning:**
The conversation explicitly expresses [detailed interpretation explicitly based on provided comparisons, clearly explaining context, intent, and implications].

**Severity Assessment:**

* **More Severe Than:**

* "Exact or paraphrased conversation A": because \\[explicit reason from comparison clearly indicating severity difference].
* "Exact or paraphrased conversation B": because \\[explicit reason from comparison clearly indicating severity difference].

* **Less Severe Than:**

* "Exact or paraphrased conversation C": because \\[explicit reason from comparison clearly indicating severity difference].

---

### 🔸 中文

**直接翻译:**
将此对话翻译成中文。如果对话是中文,只需复制对话作为输出即可。
- 保留所有特殊标签消息 ID,例如 <m1> 和用户 ID,例如 user_0
- 不要删除现有的换行符
- 解释任何可能不安全或有害的文化指涉/俚语,例如 user_0:给我看看你的“taye”[注:“taye”在阿尔及利亚俚语中是屁股的意思]
- 如果英语翻译听起来不自然/语无伦次,请提供文化背景(在括号内,例如 [注:xxx])。
- 解释请简洁明了(每条注释最多 10 个字)。

**解释性含义**:
对话明确表达了\\[基于提供的比较结果的详细解释,清晰地解释了上下文、意图和含义]。

**严重程度评估**:

* **严重程度高于**:

* “对话 A 完全一致或转述”:因为\\[比较结果的明确原因清楚地表明了严重程度的差异]。
* “对话 B 完全一致或转述”:因为\\[比较结果的明确原因清楚地表明了严重程度的差异]。

* **严重程度低于**:

* “对话 C 完全一致或转述”:因为\\[比较结果的明确原因清楚地表明了严重程度的差异]。

"""
    logger.debug(f"system_prompt: {system_prompt}")
    logger.debug(f"user_prompt: {xml_input}")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": xml_input}
    ]
    trial = 0
    current_delay = self.initial_retry_delay
    response = None
    while trial < self.max_retries:
        try:
            response = self.llm.chat_completion(messages=messages, max_tokens=4000, temperature=0.0)
            llm_understanding = response['choices'][0]['message']['content']
            logger.debug(f"llm response: {llm_understanding}")
            return llm_understanding
        except Exception as e:
            trial += 1
            if trial == self.max_retries:
                logger.error(f"Failed to get data by llm api, {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to get data by llm api, Request failed (attempt {trial}/{self.max_retries}). "
                           f"Retrying in {current_delay} seconds. Error: {str(e)}")
            if response is not None:
                logger.warning(f"Response={response}")
            time.sleep(current_delay)
            current_delay *= 2
|
| 949 |
+
|
| 950 |
+
def get_comparison_result_by_id(self, item_id: str) -> List[Dict]:
    """Return every stored comparison record that involves ``item_id``.

    Without an ES index (``self.es_index is None``) the in-memory cache is
    scanned, after loading the local cache file once if it is enabled and not
    yet loaded. Otherwise Elasticsearch is queried for documents whose
    ``item_id_a`` or ``item_id_b`` equals ``item_id`` (at most 200 hits), with
    retries and a delay that doubles after each failure.

    Args:
        item_id: Id of the item to look up.

    Returns:
        List of comparison records; empty when nothing matches.

    Raises:
        Exception: The last error once ``self.max_retries`` ES attempts failed.
    """
    if self.es_index is None:
        needs_local_load = self.local_cache_enabled and not self._local_cache_loaded
        if needs_local_load:
            self.load_data_to_cache_from_local(None, load_detail=True)
        wanted = str(item_id)
        with self._cache_lock:
            return [
                record
                for bucket in self.comparison_results.values()
                for record in bucket
                if wanted in (str(record.get("item_id_a")), str(record.get("item_id_b")))
            ]

    attempt = 0
    wait = self.initial_retry_delay

    while attempt < self.max_retries:
        try:
            # Match documents where the item appears on either side.
            either_side = [
                {"term": {"item_id_a.keyword": item_id}},
                {"term": {"item_id_b.keyword": item_id}},
            ]
            query_body = {
                "query": {
                    "bool": {
                        "should": either_side,
                        "minimum_should_match": 1,
                    }
                }
            }
            found = self.client.search(index=self.es_index, body=query_body, size=200)

            records = []
            for hit in found['hits']['hits']:
                source = hit['_source']
                record = {
                    "judgment": source['judgment'],
                    "raw_response": source['raw_response'],
                    "timestamp": source["timestamp"],
                    "item_id_a": source['item_id_a'],
                    "item_id_b": source['item_id_b'],
                    'ordered_ids': self.get_ordered_pair(source['item_id_a'], source['item_id_b']),
                    'original_ids': (source['item_id_a'], source['item_id_b']),
                }
                if self.detect_msg_violations:
                    record["violative_msg_map_str"] = source['violative_msg_map_str']
                records.append(record)
            return records

        except Exception as e:
            attempt += 1
            if attempt == self.max_retries:
                logger.error(f"Failed to get data from es after {self.max_retries} attempts: {str(e)}")
                raise
            logger.warning(f"Failed to get data from es, Request failed (attempt {attempt}/{self.max_retries}). "
                           f"Retrying in {wait} seconds. Error: {str(e)}")
            time.sleep(wait)
            wait *= 2
|
| 1007 |
+
|
| 1008 |
+
|
| 1009 |
+
def get_comparison_count(self, pair_key: str) -> int:
    """Count the comparisons recorded for a pair.

    Args:
        pair_key: Key identifying the pair.

    Returns:
        int: Length of the list returned by ``get_comparison_results``.
    """
    recorded = self.get_comparison_results(pair_key)
    return len(recorded)
|
| 1019 |
+
|
| 1020 |
+
def get_comparison_results(self, pair_key: str) -> List[Dict]:
    """Fetch the comparison records stored for a pair.

    Thin wrapper that delegates to ``get_compare_result_from_es``.

    Args:
        pair_key: Key identifying the pair.

    Returns:
        List[Dict]: Whatever ``get_compare_result_from_es`` returns for the key.
    """
    fetch = self.get_compare_result_from_es
    return fetch(pair_key)
|
| 1030 |
+
|
| 1031 |
+
def get_compare_information(self, item_ids: List[str], use_tqdm: bool = True) -> Dict[str, Any]:
    """Get comprehensive comparison information between specified items.

    Scans the in-memory cache (``self.comparison_results``) under the cache
    lock and keeps the pairs whose two item ids are both in ``item_ids``.

    The two ids of a pair are normally taken by splitting the pair key on
    ``'_'``. When that does not give exactly two parts (an item id itself
    contains an underscore), the ids are read from ``item_id_a`` /
    ``item_id_b`` of the pair's first record; a pair that cannot be resolved
    either way is skipped instead of aborting the scan.

    Args:
        item_ids: List of item IDs to get comparison information for
        use_tqdm: When False the progress bar is disabled.

    Returns:
        Dict containing:
            - comparison_results: Dict mapping pair keys to their comparison results
            - comparison_counts: Dict mapping item IDs to their total comparison counts
            - pair_comparison_counts: Dict mapping pair keys to their comparison counts
    """
    comparison_results = {}
    comparison_counts = {}
    pair_comparison_counts = {}
    item_id_set = set(item_ids)
    with self._cache_lock:
        for pair_key, pair_results in tqdm(self.comparison_results.items(), desc="Scan Comparing pairs", total=len(self.comparison_results), leave=False, disable=not use_tqdm):
            parts = pair_key.split('_')
            if len(parts) == 2:
                id1, id2 = parts
            else:
                # The key is ambiguous; fall back to the ids on the record.
                first = pair_results[0] if pair_results else {}
                raw_a, raw_b = first.get("item_id_a"), first.get("item_id_b")
                if raw_a is None or raw_b is None:
                    continue
                id1, id2 = str(raw_a), str(raw_b)

            if id1 in item_id_set and id2 in item_id_set:
                n_results = len(pair_results)
                # Copy the list so callers cannot mutate the shared cache entry.
                comparison_results[pair_key] = list(pair_results)
                comparison_counts[id1] = comparison_counts.get(id1, 0) + n_results
                comparison_counts[id2] = comparison_counts.get(id2, 0) + n_results
                pair_comparison_counts[pair_key] = n_results
    return {
        'comparison_results': comparison_results,
        'comparison_counts': comparison_counts,
        'pair_comparison_counts': pair_comparison_counts
    }
|