| from typing import List, Dict
|
| import requests
|
| import time
|
| from dotenv import load_dotenv
|
| import os
|
|
|
# Load variables from a local .env file (if one exists) into the process
# environment at import time, so the os.getenv() lookups in Reranker.__init__
# can pick up API_KEY and BASE_URL.
load_dotenv()
|
|
|
class Reranker:
    """Client for the SiliconFlow ``/rerank`` HTTP endpoint.

    The API key and the endpoint root are read from the ``API_KEY`` and
    ``BASE_URL`` environment variables when an instance is created.
    """

    # Model name sent with every rerank request; override on a subclass or
    # instance to use a different reranker.
    MODEL = "BAAI/bge-reranker-v2-m3"

    # Seconds to wait for the endpoint. ``requests`` applies no timeout by
    # default, so without this a stalled server would block the caller forever.
    REQUEST_TIMEOUT = 60

    def __init__(self):
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def rerank(self, query: str, documents: List[Dict], index_name: str, top_k: int = 5) -> List[Dict]:
        """Rerank ``documents`` against ``query`` using the rerank API.

        Args:
            query: Text the documents are scored against.
            documents: Candidate documents; each dict must have a
                ``'content'`` key holding the text sent to the API.
            index_name: Value stored under ``'index_name'`` on every
                returned document.
            top_k: Sent to the API as ``top_n``.

        Returns:
            Shallow copies of the selected input documents, in the order the
            API returned them, each extended with ``'rerank_score'`` and
            ``'index_name'``. The input dicts are not modified. An empty
            ``documents`` list yields ``[]`` without contacting the API.

        Raises:
            RuntimeError: If the API answers with a non-200 status code.
            KeyError: If a document has no ``'content'`` key, or the response
                lacks the expected ``results`` / ``index`` /
                ``relevance_score`` fields.
        """
        # Nothing to rank: skip the network round-trip entirely.
        if not documents:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.MODEL,
            "query": query,
            "documents": [doc["content"] for doc in documents],
            "top_n": top_k,
        }

        response = requests.post(
            f"{self.api_base}/rerank",
            headers=headers,
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            raise RuntimeError(f"Error in reranking: {response.text}")

        reranked_docs = []
        for result in response.json()["results"]:
            # ``index`` refers to the position in the ``documents`` list that
            # was sent; copy so the caller's dicts stay untouched.
            doc = documents[result["index"]].copy()
            doc["rerank_score"] = result["relevance_score"]
            doc["index_name"] = index_name
            reranked_docs.append(doc)

        return reranked_docs