| | from typing import List, Dict
|
| | import requests
|
| | import time
|
| | from dotenv import load_dotenv
|
| | import os
|
| |
|
# Populate the process environment from a .env file (python-dotenv) at import
# time, so that Reranker's os.getenv("API_KEY") / os.getenv("BASE_URL") calls
# can pick the values up.
load_dotenv()
|
| |
|
class Reranker:
    """Client for a ``/rerank`` HTTP endpoint (used with SiliconFlow's rerank API).

    The API key and base URL are read from the ``API_KEY`` and ``BASE_URL``
    environment variables when the instance is created.
    """

    # Reranking model sent with every request unless overridden in __init__.
    DEFAULT_MODEL = "BAAI/bge-reranker-v2-m3"

    def __init__(self, model: str = DEFAULT_MODEL, timeout: float = 60.0, session=None):
        """Read credentials from the environment and store request settings.

        Args:
            model: Reranking model name sent in the request payload.
            timeout: Seconds to wait for the HTTP request, passed to ``post``
                as ``timeout``. Without a timeout, ``requests`` can block
                indefinitely. Pass ``None`` to disable it explicitly.
            session: Optional object exposing
                ``post(url, headers=..., json=..., timeout=...)``, e.g. a
                ``requests.Session`` for connection reuse across calls. When
                ``None``, the module-level ``requests.post`` is used.
        """
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")
        self.model = model
        self.timeout = timeout
        self.session = session

    def rerank(self, query: str, documents: List[Dict], index_name: str, top_k: int = 5) -> List[Dict]:
        """Rerank documents against a query using SiliconFlow's rerank API.

        Args:
            query: Text the documents are scored against.
            documents: Candidate documents. Each must have a ``'content'`` key,
                whose value is the text sent to the API.
            index_name: Value stored under ``'index_name'`` on every returned
                document.
            top_k: Sent to the API as ``top_n`` (number of results requested).

        Returns:
            Shallow copies of the input documents referenced by the API's
            results, in the order the API returns them. Each copy gains
            ``'rerank_score'`` (the result's ``relevance_score``) and
            ``'index_name'``. The input dicts are not modified. Returns ``[]``
            without calling the API when ``documents`` is empty.

        Raises:
            ValueError: If ``BASE_URL`` was not set.
            RuntimeError: If the API responds with a non-200 status.
            Network errors and timeouts raised by ``post`` propagate unchanged.
        """
        if not documents:
            # Nothing to rank; skip the network round-trip entirely.
            return []
        if not self.api_base:
            # Without this check the URL silently becomes "None/rerank".
            raise ValueError("BASE_URL is not set; cannot call the rerank API")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "query": query,
            "documents": [doc["content"] for doc in documents],
            "top_n": top_k,
        }
        # Strip a trailing slash so BASE_URL=".../v1/" doesn't yield ".../v1//rerank".
        url = f"{self.api_base.rstrip('/')}/rerank"
        http = self.session if self.session is not None else requests
        response = http.post(url, headers=headers, json=payload, timeout=self.timeout)

        if response.status_code != 200:
            raise RuntimeError(f"Error in reranking: {response.text}")

        reranked_docs = []
        for result in response.json()["results"]:
            # "index" is the position in the `documents` list sent above.
            doc = documents[result["index"]].copy()
            doc["rerank_score"] = result["relevance_score"]
            doc["index_name"] = index_name
            reranked_docs.append(doc)
        return reranked_docs