| |
| |
| """ |
| Qwen3-Reranker 推理测试代码 |
| 使用 RKLLM API 进行文本重排序推理 |
| """ |
|
|
| import faulthandler |
| faulthandler.enable() |
| import os |
| os.environ["RKLLM_LOG_LEVEL"] = "1" |
| import numpy as np |
| import time |
| import re |
| from typing import List, Dict, Any, Tuple |
| from rkllm_binding import * |
|
|
|
|
class Qwen3RerankerTester:
    """Run a Qwen3-Reranker model through the RKLLM runtime and score documents.

    A (query, document) pair is scored by comparing the logits of the "yes"
    and "no" tokens at the final prompt position, i.e. the position where the
    model would begin its answer.
    """

    def __init__(self, model_path, library_path="./librkllmrt.so"):
        """Store configuration; the native runtime is created by ``init_model``.

        Args:
            model_path: Path to the converted ``.rkllm`` model file.
            library_path: Path to the RKLLM runtime shared library.
        """
        self.model_path = model_path
        self.library_path = library_path
        self.runtime = None
        # Written by the inference callback while ``runtime.run`` executes.
        self.current_result = None

        self.system_prompt = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."

        # ``init_model`` clears the RKLLM chat template, so the Qwen3 chat markup
        # is added to every prompt here instead. The suffix opens the assistant
        # turn with an empty think block; without it, the last-token logits would
        # not sit at the position where the yes/no answer is generated.
        self.prompt_prefix = (
            f"<|im_start|>system\n{self.system_prompt}<|im_end|>\n<|im_start|>user\n"
        )
        self.prompt_suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

        # Candidate token ids for "yes" / "no". NOTE(review): presumed to match
        # tokenizer.convert_tokens_to_ids("yes"/"no") of the Qwen3 tokenizer.
        self.yes_token_candidates = [9693]
        self.no_token_candidates = [2152]

    def callback_function(self, result_ptr, userdata_ptr, state_enum):
        """Inference callback registered with the RKLLM runtime.

        Args:
            result_ptr: ctypes pointer to the runtime's result structure.
            userdata_ptr: User data pointer (unused).
            state_enum: Raw call-state value, converted to ``LLMCallState``.
        """
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_NORMAL:
            result = result_ptr.contents
            print(f"result: {result}")
            self._store_last_token_logits(result.logits)
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("推理过程发生错误")

    def _store_last_token_logits(self, logits_info):
        """Copy the final token's logits from ``logits_info`` into ``current_result``.

        ``logits_info`` must expose ``logits`` (an indexable ctypes pointer or
        array holding ``num_tokens * vocab_size`` floats), ``vocab_size`` and
        ``num_tokens``.
        """
        if not (logits_info.logits and logits_info.vocab_size > 0):
            print("警告: 未能获取到 logits")
            return

        vocab_size = logits_info.vocab_size
        num_tokens = logits_info.num_tokens
        print(f"获取到 logits:vocab_size={vocab_size}, num_tokens={num_tokens}")

        if num_tokens <= 0:
            print("警告: 未能获取到 logits")
            return

        start_idx = (num_tokens - 1) * vocab_size
        # A single ctypes slice copies the values out of the runtime-owned buffer
        # in C, instead of ~vocab_size Python-level index operations.
        last_token_logits = np.array(
            logits_info.logits[start_idx:start_idx + vocab_size], dtype=np.float64
        )

        self.current_result = {
            'logits': last_token_logits,
            'vocab_size': vocab_size,
            'num_tokens': num_tokens,
        }

        print(f"最后一个 token 的 logits 范围: [{last_token_logits.min():.4f}, {last_token_logits.max():.4f}]")

    @staticmethod
    def _best_candidate(candidates, logits):
        """Return ``(token_id, logit)`` of the highest in-vocabulary candidate, or ``(None, -inf)``."""
        vocab_size = len(logits)
        best_id, best_logit = None, float('-inf')
        for token_id in candidates:
            if token_id < vocab_size and logits[token_id] > best_logit:
                best_id, best_logit = token_id, logits[token_id]
        return best_id, best_logit

    def find_best_yes_no_tokens(self, logits):
        """Pick the "yes" and "no" token ids to compare.

        Uses the configured candidate ids when they fall inside the vocabulary.
        Otherwise it falls back to an arbitrary heuristic: the top-1 token stands
        in for "yes" and the 11th-ranked token for "no".

        Args:
            logits: 1-D sequence of logits over the vocabulary.

        Returns:
            ``(yes_token_id, no_token_id, yes_logit, no_logit)`` with plain int ids.
        """
        best_yes_id, best_yes_logit = self._best_candidate(self.yes_token_candidates, logits)
        best_no_id, best_no_logit = self._best_candidate(self.no_token_candidates, logits)

        if best_yes_id is None or best_no_id is None:
            print("警告: 使用启发式方法寻找 yes/no tokens")
            top_tokens = np.argsort(logits)[::-1][:20]

            if best_yes_id is None:
                best_yes_id = int(top_tokens[0])
                best_yes_logit = logits[best_yes_id]

            if best_no_id is None:
                best_no_id = int(top_tokens[min(10, len(top_tokens) - 1)])
                best_no_logit = logits[best_no_id]

        return best_yes_id, best_no_id, best_yes_logit, best_no_logit

    def calculate_reranker_score(self, logits):
        """Score relevance as P("yes") under a softmax over the yes/no logits.

        Args:
            logits: 1-D sequence of logits over the vocabulary.

        Returns:
            Relevance score in [0, 1]; higher means more relevant. Falls back to
            ``fallback_score_calculation`` if the yes/no lookup raises.
        """
        try:
            yes_id, no_id, yes_logit, no_logit = self.find_best_yes_no_tokens(logits)

            print(f"Yes token ID: {yes_id}, logit: {yes_logit:.4f}")
            print(f"No token ID: {no_id}, logit: {no_logit:.4f}")

            # Two-way softmax, shifted by the max for numerical stability.
            max_logit = max(yes_logit, no_logit)
            yes_exp = np.exp(yes_logit - max_logit)
            no_exp = np.exp(no_logit - max_logit)

            sum_exp = yes_exp + no_exp
            yes_prob = yes_exp / sum_exp
            no_prob = no_exp / sum_exp

            print(f"Yes 概率: {yes_prob:.4f}, No 概率: {no_prob:.4f}")

            return float(yes_prob)

        except Exception as e:
            print(f"计算 reranker 分数时发生错误: {e}")
            return self.fallback_score_calculation(logits)

    def fallback_score_calculation(self, logits):
        """Heuristic score used when the yes/no tokens cannot be evaluated.

        Blends the inverse normalized entropy of the full softmax (weight 0.7)
        with a z-score of the maximum logit scaled by 1/10 (weight 0.3).

        Args:
            logits: 1-D sequence of logits over the vocabulary.

        Returns:
            Score in [0, 1]; 0.0 when ``logits`` is empty.
        """
        print("使用备用分数计算方法")

        logits_array = np.asarray(logits, dtype=np.float64)
        if logits_array.size == 0:
            # Nothing to score; np.max would raise on an empty array.
            return 0.0

        softmax_probs = np.exp(logits_array - np.max(logits_array))
        softmax_probs = softmax_probs / np.sum(softmax_probs)

        entropy = -np.sum(softmax_probs * np.log(softmax_probs + 1e-10))
        max_entropy = np.log(logits_array.size)
        # A single-element distribution has zero maximum entropy; avoid 0/0.
        normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0.0

        confidence_score = 1.0 - normalized_entropy

        max_logit_score = (np.max(logits_array) - np.mean(logits_array)) / (np.std(logits_array) + 1e-8)
        max_logit_score = max(0, min(1, max_logit_score / 10))

        final_score = 0.7 * confidence_score + 0.3 * max_logit_score
        final_score = max(0.0, min(1.0, final_score))

        print(f"备用计算 - 熵分数: {confidence_score:.4f}, 最大logit分数: {max_logit_score:.4f}, 最终分数: {final_score:.4f}")

        return float(final_score)

    def init_model(self):
        """Create the RKLLM runtime and load the model for single-step logits inference.

        Raises:
            Exception: Whatever the runtime raises; it is logged and re-raised.
        """
        try:
            print(f"初始化 RKLLM 运行时,库路径: {self.library_path}")
            self.runtime = RKLLMRuntime(self.library_path)

            print("创建默认参数...")
            params = self.runtime.create_default_param()

            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = 1024
            params.max_new_tokens = 1
            params.temperature = 0.0
            params.top_k = 1
            params.top_p = 1.0

            params.extend_param.base_domain_id = 1
            params.extend_param.embed_flash = 0
            params.extend_param.enabled_cpus_num = 4
            params.extend_param.enabled_cpus_mask = 0x0F

            print(f"初始化模型: {self.model_path}")
            self.runtime.init(params, self.callback_function)

            # Empty template: every prompt is already wrapped with
            # prompt_prefix / prompt_suffix before it is sent to the runtime.
            self.runtime.set_chat_template(
                "",
                "",
                ""
            )

            print("模型初始化成功!")

        except Exception as e:
            print(f"模型初始化失败: {e}")
            raise

    def format_rerank_input(self, instruction, query, document):
        """Format the user-turn body in the official Qwen3-Reranker layout.

        Args:
            instruction: Task instruction; ``None`` selects the default web-search task.
            query: Query text.
            document: Document text.

        Returns:
            ``"<Instruct>: ...\\n<Query>: ...\\n<Document>: ..."``
        """
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'

        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"

    def _wrap_prompt(self, input_text):
        """Surround a formatted user turn with the Qwen3 system/assistant chat markup."""
        return f"{self.prompt_prefix}{input_text}{self.prompt_suffix}"

    def get_reranker_score(self, instruction, query, document):
        """Run one logits-only inference and score the (query, document) pair.

        Args:
            instruction: Task instruction, or ``None`` for the default.
            query: Query text.
            document: Document text.

        Returns:
            Relevance score in [0, 1]; 0.0 if inference fails or yields no logits.
        """
        try:
            input_text = self.format_rerank_input(instruction, query, document)
            print(f"\n重排序输入: {input_text[:200]}{'...' if len(input_text) > 200 else ''}")

            rk_input = RKLLMInput()
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            # Keep the encoded bytes referenced by a local until run() returns.
            c_prompt = self._wrap_prompt(input_text).encode('utf-8')
            rk_input._union_data.prompt_input = c_prompt

            infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LOGITS
            infer_params.keep_history = 0

            # Score every pair independently of previous ones.
            self.current_result = None
            self.runtime.clear_kv_cache(False)

            start_time = time.perf_counter()
            self.runtime.run(rk_input, infer_params)
            elapsed = time.perf_counter() - start_time

            print(f"\n推理耗时: {elapsed:.3f}秒")

            if self.current_result and 'logits' in self.current_result:
                score = self.calculate_reranker_score(self.current_result['logits'])
                print(f"计算得分: {score:.4f}")
                return score

            print("警告: 未能获取到有效的 logits,返回默认分数")
            return 0.0

        except Exception as e:
            print(f"重排序评分时发生错误: {e}")
            import traceback
            traceback.print_exc()
            return 0.0

    def rerank_documents(self, query, documents, instruction=None):
        """Score every document against ``query`` and sort by relevance.

        Args:
            query: Query text.
            documents: List of document strings.
            instruction: Optional task instruction.

        Returns:
            List of ``(document, score)`` tuples sorted by descending score.
        """
        print(f"\n对 {len(documents)} 个文档进行重排序")
        print(f"查询: {query}")

        if instruction:
            print(f"指令: {instruction}")

        scored_docs = []
        for i, doc in enumerate(documents):
            print(f"\n--- 处理文档 {i+1}/{len(documents)} ---")
            print(f"文档: {doc[:100]}{'...' if len(doc) > 100 else ''}")

            score = self.get_reranker_score(instruction, query, doc)
            scored_docs.append((doc, score))
            print(f"得分: {score:.4f}")

        scored_docs.sort(key=lambda x: x[1], reverse=True)
        return scored_docs

    @staticmethod
    def _print_banner(title):
        """Print ``title`` between two 60-character rules."""
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)

    @staticmethod
    def _print_ranking(label, query, ranked_docs):
        """Print a ranked ``(document, score)`` list under a ``label`` heading."""
        print(f"\n{label}(查询: {query}):")
        print("-" * 80)
        for rank, (doc, score) in enumerate(ranked_docs, start=1):
            print(f"排名 {rank}: 分数 {score:.4f}")
            print(f"文档: {doc}")
            print()

    def test_basic_reranking(self):
        """Rerank English documents for a simple factual query; return the ranking."""
        self._print_banner("测试基础重排序功能")

        query = "What is the capital of China?"
        documents = [
            "Beijing is the capital city of China, located in northern China.",
            "The Great Wall of China is an ancient fortification built to protect Chinese states.",
            "Python is a high-level programming language used for software development.",
            "China's capital Beijing is home to over 21 million people.",
            "Machine learning is a subset of artificial intelligence that uses algorithms."
        ]
        instruction = "Given a web search query, retrieve relevant passages that answer the query"

        ranked_docs = self.rerank_documents(query, documents, instruction)
        self._print_ranking("重排序结果", query, ranked_docs)
        return ranked_docs

    def test_multilingual_reranking(self):
        """Rerank mixed Chinese/English documents for a Chinese query; return the ranking."""
        self._print_banner("测试多语言重排序功能")

        query = "中国的首都是什么?"
        documents = [
            "北京是中华人民共和国的首都,位于中国北部。",
            "上海是中国的经济中心,人口超过2400万。",
            "Python 是一种高级编程语言。",
            "The capital of China is Beijing.",
            "长城是中国古代的军事防御工程。"
        ]
        instruction = "Given a web search query, retrieve relevant passages that answer the query"

        ranked_docs = self.rerank_documents(query, documents, instruction)
        self._print_ranking("多语言重排序结果", query, ranked_docs)
        return ranked_docs

    def test_domain_specific_reranking(self):
        """Rerank technical documents with a domain-specific instruction; return the ranking."""
        self._print_banner("测试领域特定重排序(技术文档)")

        query = "How to implement a neural network in Python?"
        documents = [
            "PyTorch is a deep learning framework that provides tensor computations with GPU acceleration.",
            "TensorFlow is an open-source machine learning library developed by Google.",
            "Neural networks are computing systems inspired by biological neural networks.",
            "Python is a programming language with simple syntax and powerful libraries.",
            "To implement a neural network in Python, you can use libraries like PyTorch or TensorFlow to define layers, loss functions, and optimization algorithms.",
            "Cooking recipes often require precise measurements and cooking times.",
            "Backpropagation is the algorithm used to train neural networks by computing gradients."
        ]
        instruction = "Given a technical query and a document, determine if the document provides practical information for implementing the requested technical solution"

        ranked_docs = self.rerank_documents(query, documents, instruction)
        self._print_ranking("技术文档重排序结果", query, ranked_docs)
        return ranked_docs

    def test_comparison_with_official_example(self):
        """Score the query/document pairs from the official example.

        Returns:
            List of scores, one per pair, in input order.
        """
        self._print_banner("测试与官方示例的对比")

        task = 'Given a web search query, retrieve relevant passages that answer the query'
        queries = [
            "What is the capital of China?",
            "Explain gravity",
        ]
        documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
        ]

        print("测试官方示例的查询-文档对:")
        scores = []
        for i, (query, doc) in enumerate(zip(queries, documents)):
            print(f"\n=== 查询-文档对 {i+1} ===")
            print(f"查询: {query}")
            print(f"文档: {doc}")

            score = self.get_reranker_score(task, query, doc)
            print(f"相关性分数: {score:.4f}")
            scores.append(score)
        return scores

    def cleanup(self):
        """Destroy the native runtime, if any. Safe to call more than once."""
        # Detach first so a repeated call cannot destroy the same handle twice.
        runtime, self.runtime = self.runtime, None
        if not runtime:
            return
        try:
            runtime.destroy()
            print("模型资源已清理")
        except Exception as e:
            print(f"清理资源时发生错误: {e}")
|
|
|
|
def main():
    """Command-line entry point: validate paths, load the model and run all rerank tests.

    Returns:
        Process exit status: 0 on success, 1 when the model or library file is
        missing or a test raises, 130 when interrupted with Ctrl-C.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Qwen3-Reranker-0.6B 推理测试')
    parser.add_argument('model_path', help='模型文件路径(.rkllm格式)')
    parser.add_argument('--library_path', default="./librkllmrt.so", help='RKLLM库文件路径(默认为./librkllmrt.so)')
    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        print(f"错误: 模型文件不存在: {args.model_path}")
        print("请确保:")
        print("1. 已下载 Qwen3-Reranker-0.6B 模型")
        print("2. 已使用 rkllm-convert.py 将模型转换为 .rkllm 格式")
        return 1

    if not os.path.exists(args.library_path):
        print(f"错误: RKLLM 库文件不存在: {args.library_path}")
        print("请确保 librkllmrt.so 在当前目录或 LD_LIBRARY_PATH 中")
        return 1

    print("Qwen3-Reranker-0.6B 推理测试")
    print("=" * 60)
    print("基于官方 README 的正确实现")
    print("=" * 60)

    tester = Qwen3RerankerTester(args.model_path, args.library_path)

    try:
        tester.init_model()

        print("\n开始运行重排序测试...")

        tester.test_comparison_with_official_example()
        tester.test_basic_reranking()
        tester.test_multilingual_reranking()
        tester.test_domain_specific_reranking()

        print("\n" + "="*60)
        print("所有重排序测试完成!")
        print("="*60)

    except KeyboardInterrupt:
        print("\n测试被用户中断")
        return 130
    except Exception as e:
        print(f"\n测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
        return 1
    finally:
        # Runs on every exit path above, including the early returns.
        tester.cleanup()

    return 0
|
|
|
|
if __name__ == "__main__":
    import sys

    # Propagate main()'s return value as the process exit status (None -> 0).
    sys.exit(main())