| | import numpy as np |
| | from chromadb import Client, Settings |
| | from sklearn.decomposition import PCA |
| | import joblib |
| | import os |
| | from datetime import datetime |
| | import warnings |
| | import cupy as cp |
| | from cuml.cluster import KMeans as cuKMeans |
| | from tqdm import tqdm |
| | |
| | warnings.filterwarnings('ignore', category=FutureWarning) |
| | warnings.filterwarnings('ignore', category=UserWarning) |
| |
|
class TopicClusterer:
    """Cluster QA-document embeddings stored in a ChromaDB collection.

    Pipeline: load embeddings -> PCA dimensionality reduction ->
    GPU-accelerated KMeans clustering -> write cluster labels back
    to the database as document metadata.
    """

    # Default on-disk cache of raw embeddings. Kept identical to the original
    # hard-coded absolute path for backward compatibility.
    DEFAULT_EMBEDDINGS_CACHE = (
        '/home/dyvm6xra/dyvm6xrauser11/workspace/projects/HKU/Chatbot'
        '/Data/Embeddings/embeddings_703df19c43bd6565563071b97e7172ce.npy'
    )

    def __init__(self, chroma_uri: str = "./Data/database"):
        """Initialize the clusterer.

        Args:
            chroma_uri: Path to the persistent ChromaDB database.
        """
        self.chroma_uri = chroma_uri
        self.client = Client(Settings(
            persist_directory=chroma_uri,
            anonymized_telemetry=False,
            is_persistent=True
        ))

        # Expected embedding dimensionality (informational; not enforced).
        self.vector_dim = 768

        # BUGFIX: the original swallowed the exception detail (f-string with
        # no placeholder) and left self.collection undefined, producing a
        # confusing AttributeError later. Report the cause and set None so
        # failures surface predictably.
        self.collection = None
        try:
            self.collection = self.client.get_collection("healthcare_qa")
        except Exception as e:
            print(f"集合不存在: {e}")

        # Pipeline state, filled in by the methods below.
        self.embeddings = None          # raw (n_docs, dim) embedding matrix
        self.reduced_embeddings = None  # PCA-reduced matrix
        self.labels = None              # cluster label per document
        self.document_ids = None        # ChromaDB ids aligned with embeddings

    def load_embeddings(self, use_cache: bool = False) -> np.ndarray:
        """Load embeddings from a .npy cache file or from the database.

        Args:
            use_cache: When True and the cache file exists, load from it.
                Defaults to False, preserving the original behavior where the
                cache branch was disabled with a literal ``and 0``.

        Returns:
            Array of shape (n_documents, dim).
        """
        embeddings_cache_file = self.DEFAULT_EMBEDDINGS_CACHE

        if use_cache and os.path.exists(embeddings_cache_file):
            print("发现缓存的embeddings,正在加载...")
            try:
                self.embeddings = np.load(embeddings_cache_file)
                # NOTE(review): synthetic ids assume the cached row order
                # matches the database ids — unverified; this may be why the
                # cache path was disabled in the original.
                self.document_ids = [str(i) for i in range(len(self.embeddings))]
                print(f"从缓存加载完成,数据形状: {self.embeddings.shape}")
                return self.embeddings
            except Exception as e:
                # BUGFIX: the original printed this message but never actually
                # fell through to the database reload; now it does.
                print(f"加载缓存失败: {e},将从数据库重新加载")

        print("正在加载embeddings...")
        print(self.collection.count())
        result = self.collection.get(include=["embeddings"])
        self.embeddings = np.array(result["embeddings"])
        self.document_ids = result["ids"]

        print(f"加载完成,数据形状: {self.embeddings.shape}")
        return self.embeddings

    def reduce_dimensions(self, n_components: int = 2) -> np.ndarray:
        """Reduce embedding dimensionality with PCA and cache the result.

        Args:
            n_components: Target dimensionality after reduction.

        Returns:
            Array of shape (n_documents, n_components).
        """
        if self.embeddings is None:
            self.load_embeddings()

        print("使用PCA进行降维...")

        reducer = PCA(
            n_components=n_components,
            random_state=42,
            svd_solver='randomized'
        )
        self.reduced_embeddings = reducer.fit_transform(self.embeddings)
        cumulative_variance = np.cumsum(reducer.explained_variance_ratio_)
        print(f"PCA累积解释方差比: {cumulative_variance[-1]:.4f}")

        print(f"降维完成,降维后形状: {self.reduced_embeddings.shape}")

        # Cache next to the data directory.
        # NOTE(review): for the default "./Data/database" this resolves to
        # "./Embeddings", not "./Data/Embeddings" — confirm intended layout.
        cache_dir = os.path.dirname(os.path.dirname(self.chroma_uri)) + '/Embeddings'
        os.makedirs(cache_dir, exist_ok=True)
        cache_file = os.path.join(cache_dir, f'pca_reduced_{n_components}d.npy')
        np.save(cache_file, self.reduced_embeddings)
        print(f"降维结果已缓存到: {cache_file}")

        return self.reduced_embeddings

    def cluster_kmeans(self, n_clusters: int = 4) -> np.ndarray:
        """Cluster the reduced embeddings with GPU-accelerated KMeans.

        Args:
            n_clusters: Number of clusters to fit.

        Returns:
            Integer label array of shape (n_documents,).
        """
        print("使用GPU加速的KMeans进行聚类...")

        if self.reduced_embeddings is None:
            self.reduce_dimensions()

        # Move data to the GPU for cuML.
        data_gpu = cp.array(self.reduced_embeddings)

        kmeans = cuKMeans(
            n_clusters=n_clusters,
            random_state=42,
            n_init=10,
            max_iter=300,
            verbose=1
        )
        kmeans.fit(data_gpu)
        self.labels = cp.asnumpy(kmeans.labels_)

        # Report the per-cluster distribution.
        unique_labels = np.unique(self.labels)
        n_clusters = len(unique_labels)

        print(f"发现 {n_clusters} 个聚类")
        for label in unique_labels:
            count = np.sum(self.labels == label)
            percentage = count / len(self.labels) * 100
            print(f"簇 {label}: {count} 样本 ({percentage:.2f}%)")

        return self.labels

    def update_database(self) -> None:
        """Write the cluster labels back to the database as metadata.

        Raises:
            ValueError: If clustering has not been performed yet.
        """
        if self.labels is None or self.document_ids is None:
            raise ValueError("请先进行聚类")

        print("正在更新数据库...")

        label_strings = [f"cluster_{label}" for label in self.labels]

        batch_size = 500
        total_docs = len(self.document_ids)

        for i in tqdm(range(0, total_docs, batch_size), desc="批量更新数据库"):
            batch_end = min(i + batch_size, total_docs)
            batch_ids = self.document_ids[i:batch_end]
            batch_labels = label_strings[i:batch_end]

            # BUGFIX: the original loop body was a bare `continue`, so no
            # update was ever written. Attach each document's cluster label
            # as metadata via ChromaDB's batch update.
            self.collection.update(
                ids=batch_ids,
                metadatas=[{"cluster": lbl} for lbl in batch_labels],
            )

        print("数据库更新完成")
| |
|
def main():
    """Run the full topic-clustering pipeline end to end."""
    clusterer = TopicClusterer()

    # 1. Pull the embeddings out of ChromaDB.
    clusterer.load_embeddings()

    # 2. Project them down to two dimensions with PCA.
    clusterer.reduce_dimensions(n_components=2)

    # 3. Partition the reduced vectors into four topics via GPU KMeans.
    clusterer.cluster_kmeans(n_clusters=4)

    # 4. Persist the resulting labels back into the collection.
    clusterer.update_database()
| |
|
# Script entry point: run the pipeline only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
| |
|