| |
| |
| """ |
| Qwen3-Embedding-0.6B 推理测试代码 |
| 使用 RKLLM API 进行文本嵌入推理 |
| """ |
| import faulthandler |
| faulthandler.enable() |
| import os |
| os.environ["RKLLM_LOG_LEVEL"] = "1" |
| import numpy as np |
| import time |
| from typing import List, Dict, Any |
| from rkllm_binding import * |
|
|
|
|
class Qwen3EmbeddingTester:
    """Test harness that produces text embeddings through the RKLLM runtime.

    The harness loads a ``.rkllm`` model, runs inference in
    "get last hidden layer" mode, keeps the hidden state of the last token
    as the text embedding, and prints similarity tables for a few sample
    inputs.
    """

    def __init__(self, model_path: str, library_path: str = "./librkllmrt.so"):
        """
        Store the configuration; the model itself is loaded by ``init_model``.

        Args:
            model_path: Path to the model file (``.rkllm`` format).
            library_path: Path to the RKLLM runtime shared library.
        """
        self.model_path = model_path
        self.library_path = library_path
        self.runtime = None
        # Not used by any method of this class; kept for compatibility.
        self.embeddings_buffer = []
        # Filled by callback_function during a run, reset by encode_text.
        self.current_result = None

    def callback_function(self, result_ptr, userdata_ptr, state_enum):
        """
        Inference callback registered with the runtime in ``init_model``.

        On a normal result that carries hidden states, the last token's
        vector is copied into ``self.current_result``. On an error state a
        message is printed. Any other state is ignored.

        Args:
            result_ptr: Pointer to the result structure.
            userdata_ptr: User-data pointer (unused here).
            state_enum: Raw value convertible to ``LLMCallState``.
        """
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_ERROR:
            print("推理过程发生错误")
            return
        if state != LLMCallState.RKLLM_RUN_NORMAL:
            return

        result = result_ptr.contents
        print(f"result: {result}")

        layer = result.last_hidden_layer
        if not (layer.hidden_states and layer.embd_size > 0):
            return

        embd_size = layer.embd_size
        num_tokens = layer.num_tokens
        print(f"获取到嵌入向量:维度={embd_size}, 令牌数={num_tokens}")

        if num_tokens <= 0:
            return

        # hidden_states is indexed as a flat buffer of num_tokens rows with
        # embd_size values each; the last row is used as the embedding.
        # Slicing yields a Python list and np.array copies it, so the stored
        # vector does not reference the buffer after the callback returns.
        # NOTE(review): assumes hidden_states is a ctypes pointer/array that
        # supports slice indexing — confirm against rkllm_binding.
        start = (num_tokens - 1) * embd_size
        last_token_embedding = np.array(
            layer.hidden_states[start:start + embd_size], dtype=np.float64
        )

        self.current_result = {
            'embedding': last_token_embedding,
            'embd_size': embd_size,
            'num_tokens': num_tokens
        }

        print(f"嵌入向量范数: {np.linalg.norm(last_token_embedding):.4f}")
        print(f"嵌入向量前10维: {last_token_embedding[:10]}")

    def init_model(self):
        """
        Create the RKLLM runtime and load the model.

        Raises:
            Exception: Whatever the binding raises; it is printed and
                re-raised unchanged.
        """
        try:
            print(f"初始化 RKLLM 运行时,库路径: {self.library_path}")
            self.runtime = RKLLMRuntime(self.library_path)

            print("创建默认参数...")
            params = self.runtime.create_default_param()

            # Basic parameters.
            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = 1024
            params.max_new_tokens = 1
            params.temperature = 1.0
            params.top_k = 1
            params.top_p = 1.0

            # Extended parameters.
            params.extend_param.base_domain_id = 1
            params.extend_param.embed_flash = 0
            params.extend_param.enabled_cpus_num = 4
            # 0x0F has bits 0-3 set; presumably selects CPU cores 0-3 —
            # confirm against the RKLLM documentation.
            params.extend_param.enabled_cpus_mask = 0x0F

            print(f"初始化模型: {self.model_path}")
            self.runtime.init(params, self.callback_function)
            # All three chat-template parts are set to empty strings.
            self.runtime.set_chat_template("", "", "")
            print("模型初始化成功!")

        except Exception as e:
            print(f"模型初始化失败: {e}")
            raise

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """
        Build the instruction prompt for a query.

        Args:
            task_description: Description of the task.
            query: Query text.

        Returns:
            The formatted prompt ``"Instruct: <task>\\nQuery: <query>"``.
        """
        return f'Instruct: {task_description}\nQuery: {query}'

    def encode_text(self, text: str, task_description: str = None) -> np.ndarray:
        """
        Encode a text into an L2-normalized embedding vector.

        Args:
            text: Text to encode.
            task_description: Optional task description; when given (and
                non-empty) the text is wrapped with ``get_detailed_instruct``.

        Returns:
            The embedding divided by its Euclidean norm.

        Raises:
            RuntimeError: If the model has not been initialized, if the
                callback delivered no embedding, or if the embedding's norm
                is zero or non-finite and therefore cannot be normalized.
        """
        if self.runtime is None:
            raise RuntimeError("模型尚未初始化,请先调用 init_model()")

        try:
            if task_description:
                input_text = self.get_detailed_instruct(task_description, text)
            else:
                input_text = text

            print(f"编码文本: {input_text[:100]}{'...' if len(input_text) > 100 else ''}")

            # Prepare the input structure. c_prompt stays bound to a local
            # name so the encoded bytes remain alive for the whole run() call.
            rk_input = RKLLMInput()
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            c_prompt = input_text.encode('utf-8')
            rk_input._union_data.prompt_input = c_prompt

            # Ask for the last hidden layer, with keep_history disabled.
            infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
            infer_params.keep_history = 0

            # Drop any previous result so a stale embedding is never returned.
            self.current_result = None
            # NOTE(review): the meaning of the False argument is defined by
            # the binding — confirm.
            self.runtime.clear_kv_cache(False)

            start_time = time.time()
            self.runtime.run(rk_input, infer_params)
            end_time = time.time()

            print(f"推理耗时: {end_time - start_time:.3f}秒")

            if not (self.current_result and 'embedding' in self.current_result):
                raise RuntimeError("未能获取到有效的嵌入向量")

            embedding = self.current_result['embedding']
            norm = np.linalg.norm(embedding)
            # Dividing by a zero or non-finite norm would silently yield
            # NaN/inf vectors and poison every similarity computed from them.
            if not np.isfinite(norm) or norm == 0:
                raise RuntimeError("嵌入向量范数无效,无法归一化")
            return embedding / norm

        except Exception as e:
            print(f"编码文本时发生错误: {e}")
            raise

    def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
        """
        Compute the cosine similarity of two embedding vectors.

        The dot product is divided by the product of the norms, so the
        result is a true cosine similarity even when the inputs are not
        unit-length. For unit vectors this equals the plain dot product.

        Args:
            emb1: First embedding vector.
            emb2: Second embedding vector.

        Returns:
            The cosine similarity, or ``0.0`` if either vector has zero norm.
        """
        vec1 = np.asarray(emb1, dtype=np.float64)
        vec2 = np.asarray(emb2, dtype=np.float64)
        denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denominator == 0.0:
            return 0.0
        return float(np.dot(vec1, vec2) / denominator)

    def _print_pairwise_similarities(self, embeddings: Dict[str, np.ndarray]) -> None:
        """Print the similarity of every unordered pair (self-pairs included)."""
        names = list(embeddings)
        for index, first in enumerate(names):
            for second in names[index:]:
                sim = self.compute_similarity(embeddings[first], embeddings[second])
                print(f"{first} vs {second}: {sim:.4f}")

    def test_embedding_similarity(self) -> List[List[float]]:
        """
        Encode sample queries and documents and print their similarities.

        Queries are encoded with a task description; documents are encoded
        as plain text.

        Returns:
            A matrix where ``result[i][j]`` is the similarity of query ``i``
            and document ``j``.
        """
        print("\n" + "="*50)
        print("测试嵌入相似度计算")
        print("="*50)

        task_description = "Given a web search query, retrieve relevant passages that answer the query"

        queries = [
            "What is the capital of China?",
            "Explain gravity"
        ]

        documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
        ]

        print("\n编码查询文本:")
        query_embeddings = []
        for i, query in enumerate(queries):
            print(f"\n查询 {i+1}: {query}")
            query_embeddings.append(self.encode_text(query, task_description))

        print("\n编码文档文本:")
        doc_embeddings = []
        for i, doc in enumerate(documents):
            print(f"\n文档 {i+1}: {doc}")
            doc_embeddings.append(self.encode_text(doc))

        print("\n计算相似度矩阵:")
        print("查询 vs 文档相似度:")
        print("-" * 30)

        similarities = []
        for i, q_emb in enumerate(query_embeddings):
            row_similarities = []
            for j, d_emb in enumerate(doc_embeddings):
                sim = self.compute_similarity(q_emb, d_emb)
                row_similarities.append(sim)
                print(f"查询{i+1} vs 文档{j+1}: {sim:.4f}")
            similarities.append(row_similarities)
            print()

        return similarities

    def test_multilingual_embedding(self):
        """Encode one greeting per language and print cross-language similarities."""
        print("\n" + "="*50)
        print("测试多语言嵌入能力")
        print("="*50)

        texts = {
            "英语": "Hello, how are you?",
            "中文": "你好,你好吗?",
            "法语": "Bonjour, comment allez-vous?",
            "西班牙语": "Hola, ¿cómo estás?",
            "日语": "こんにちは、元気ですか?"
        }

        embeddings = {}
        print("\n编码多语言文本:")
        for lang, text in texts.items():
            print(f"\n{lang}: {text}")
            embeddings[lang] = self.encode_text(text)

        print("\n跨语言相似度:")
        print("-" * 30)
        self._print_pairwise_similarities(embeddings)

    def test_code_embedding(self):
        """Encode a few code snippets and print their pairwise similarities."""
        print("\n" + "="*50)
        print("测试代码嵌入能力")
        print("="*50)

        codes = {
            "Python函数": """
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)
""",
            "JavaScript函数": """
function fibonacci(n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "C++函数": """
int fibonacci(int n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "数组排序": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
"""
        }

        embeddings = {}
        print("\n编码代码文本:")
        for name, code in codes.items():
            print(f"\n{name}:")
            # Show at most the first 100 characters of each snippet.
            print((code[:100] + "...") if len(code) > 100 else code)
            embeddings[name] = self.encode_text(code)

        print("\n代码相似度:")
        print("-" * 30)
        self._print_pairwise_similarities(embeddings)

    def cleanup(self):
        """
        Destroy the runtime if one was created.

        Errors raised by ``destroy()`` are printed rather than propagated.
        The runtime reference is cleared afterwards so that calling
        ``cleanup`` again is a no-op instead of a second ``destroy()``.
        """
        if not self.runtime:
            return
        try:
            self.runtime.destroy()
            print("模型资源已清理")
        except Exception as e:
            print(f"清理资源时发生错误: {e}")
        finally:
            self.runtime = None
|
|
def main(argv=None):
    """
    Command-line entry point: load the model and run all embedding tests.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None``, argparse reads ``sys.argv[1:]``.

    Returns:
        Process exit code: ``0`` when all tests completed, ``1`` when the
        model or library file is missing or a test raised, and ``130`` when
        the run was interrupted with Ctrl-C.
    """
    import argparse
    import traceback

    parser = argparse.ArgumentParser(description='Qwen3-Embedding-0.6B 推理测试')
    parser.add_argument('model_path', help='模型文件路径(.rkllm格式)')
    parser.add_argument('--library_path', default="./librkllmrt.so", help='RKLLM库文件路径(默认为./librkllmrt.so)')
    args = parser.parse_args(argv)

    # Fail early, with a non-zero exit code, when a required file is missing.
    if not os.path.exists(args.model_path):
        print(f"错误: 模型文件不存在: {args.model_path}")
        print("请确保:")
        print("1. 已下载 Qwen3-Embedding-0.6B 模型")
        print("2. 已使用 rkllm-convert.py 将模型转换为 .rkllm 格式")
        return 1

    if not os.path.exists(args.library_path):
        print(f"错误: RKLLM 库文件不存在: {args.library_path}")
        print("请确保 librkllmrt.so 在当前目录或 LD_LIBRARY_PATH 中")
        return 1

    print("Qwen3-Embedding-0.6B 推理测试")
    print("=" * 50)

    tester = Qwen3EmbeddingTester(args.model_path, args.library_path)
    exit_code = 0

    try:
        tester.init_model()

        print("\n开始运行嵌入测试...")

        tester.test_embedding_similarity()
        tester.test_multilingual_embedding()
        tester.test_code_embedding()

        print("\n" + "="*50)
        print("所有测试完成!")
        print("="*50)

    except KeyboardInterrupt:
        print("\n测试被用户中断")
        # 128 + SIGINT(2): the conventional shell status for Ctrl-C.
        exit_code = 130
    except Exception as e:
        print(f"\n测试过程中发生错误: {e}")
        traceback.print_exc()
        exit_code = 1
    finally:
        # Always release the runtime, whether the tests passed or not.
        tester.cleanup()

    return exit_code


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect failures.
    raise SystemExit(main())
|
|