| | import json |
| | from chromadb import Client, Settings, EmbeddingFunction |
| | from pprint import pprint |
| | import random |
| | import os |
| | from sentence_transformers import SentenceTransformer |
| | import torch |
| |
|
| | |
# Directory of the persisted Chroma database (relative to the working directory).
CHROMA_URI = "./Data/database"
# Sentence-transformers model used to embed both stored documents and queries.
EMBEDDING_MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
# NOTE(review): presumably the embedding width of the model above; it is not
# referenced by the code visible in this file — confirm before relying on it.
VECTOR_DIM = 768
| |
|
class BioEmbeddingFunction(EmbeddingFunction):
    """Embedding function that encodes text with a sentence-transformers model.

    The model named by ``EMBEDDING_MODEL_NAME`` is loaded once at construction
    time and moved to CUDA when it is available, otherwise to the CPU.
    """

    def __init__(self):
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        encoder = SentenceTransformer(EMBEDDING_MODEL_NAME)
        encoder.to(self.device)
        self.model = encoder

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Return one normalized embedding (as a list of floats) per text.

        The parameter name ``input`` shadows the builtin but is part of this
        callable's interface, so it is kept as-is.
        """
        vectors = self.model.encode(
            input, normalize_embeddings=True, convert_to_numpy=True
        )
        return vectors.tolist()
| |
|
def _directory_size_bytes(path):
    """Return the combined size in bytes of all files found under *path*.

    ``os.path.getsize`` on a directory reports the size of the directory entry
    itself rather than of its contents, so the tree is walked and the sizes of
    the individual files are summed. A path that is a single file yields that
    file's size; a path that does not exist yields 0.
    """
    if os.path.isfile(path):
        return os.path.getsize(path)

    total = 0
    for root, _dirs, files in os.walk(path):
        for filename in files:
            try:
                total += os.path.getsize(os.path.join(root, filename))
            except OSError:
                # The file disappeared or cannot be stat'ed (e.g. a broken
                # symlink); leave it out of the total instead of aborting.
                continue
    return total


def test_database():
    """Print a smoke-test report for the persisted ``healthcare_qa`` collection.

    The report has three sections:

    1. Basic information: database location, on-disk size, item count and the
       embedding model name.
    2. Up to two randomly chosen stored items (document and metadata).
    3. The top result of a sample query for ``"diabetes"``.

    Sections 2 and 3 are skipped when the collection is empty. Nothing is
    returned; all output goes to stdout.
    """
    print("="*50)
    print("开始测试数据库")
    print("="*50)

    client = Client(Settings(
        persist_directory=CHROMA_URI,
        anonymized_telemetry=False,
        is_persistent=True,
    ))

    embedding_func = BioEmbeddingFunction()
    collection = client.get_or_create_collection(
        name="healthcare_qa",
        embedding_function=embedding_func,
    )

    # Count once and reuse the value for the report and the sampling below.
    total_items = collection.count()

    print("\n1. 数据库基本信息:")
    print(f"数据库位置: {os.path.abspath(CHROMA_URI)}")
    # CHROMA_URI is a directory, so its files are summed to get the real size.
    print(f"数据库大小: {_directory_size_bytes(CHROMA_URI) / 1024 / 1024:.2f} MB")
    print(f"总条目数: {total_items} 条")
    print(f"使用的嵌入模型: {EMBEDDING_MODEL_NAME}")

    if total_items == 0:
        # There is nothing to sample or query; stop before asking the
        # collection for an empty id list or for results it cannot provide.
        print("\n数据库为空，跳过样本展示和查询测试。")
        return

    print("\n2. 随机样本展示:")
    sample_size = min(2, total_items)
    random_indices = random.sample(range(total_items), sample_size)

    # NOTE(review): assumes items were stored with ids "0" .. str(count - 1);
    # confirm against the ingestion script.
    results = collection.get(
        ids=[str(i) for i in random_indices],
        include=["documents", "metadatas"],
    )

    if not results['documents']:
        print("未能按编号取到样本（集合中的 ID 可能不是连续数字）。")

    for i, (doc, metadata) in enumerate(zip(results['documents'], results['metadatas']), 1):
        print(f"\n样本 {i}:")
        print("-" * 40)
        print("文档内容:")
        print(doc)
        print("\n元数据:")
        pprint(metadata)
        print("-" * 40)

    print("\n3. 测试查询功能:")
    query = "diabetes"
    results = collection.query(
        query_texts=[query],
        n_results=1,
        include=["documents", "metadatas", "distances"],
    )

    print(f"\n使用查询词 '{query}' 的结果:")
    # Each result field holds one list per query text; index 0 is our only query.
    for i, (doc, metadata, distance) in enumerate(zip(
        results['documents'][0],
        results['metadatas'][0],
        results['distances'][0],
    ), 1):
        print(f"\n结果 {i}:")
        print("-" * 40)
        # NOTE(review): "1 - distance" is a similarity only if the collection
        # uses cosine distance; confirm the collection's configured metric.
        print(f"相似度得分: {1 - distance:.4f}")
        print("\n文档内容:")
        print(doc)
        print("\n元数据:")
        pprint(metadata)
        print("-" * 40)
| | |
| |
|
| |
|
| | |
| |
|
# Run the database smoke test only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    test_database()
| |
|