import hashlib
import os
import re
from typing import Dict, List, Optional

import chromadb
from chromadb.config import Settings

from config import Config
from core.openai_client import OpenAIClient


class MemoryManager:
    """Vector memory manager that stores and retrieves text snippets for a character.

    Each character gets its own Chroma collection inside the database rooted
    at ``Config.VECTOR_DB_PATH``. Vector-store failures in the public methods
    are reported with ``print`` and degrade to "no memory" (empty results)
    instead of raising.
    """

    # Conservative subset of Chroma's collection-name rules: 3-63 characters,
    # ASCII alphanumerics plus ``.``, ``_`` and ``-``, starting and ending with
    # an alphanumeric character.
    _MAX_COLLECTION_NAME_LEN = 63
    _VALID_COLLECTION_NAME = re.compile(
        r"[A-Za-z0-9][A-Za-z0-9._-]{1,61}[A-Za-z0-9]"
    )
    # Number of documents sent to the collection per ``add`` call.
    _ADD_BATCH_SIZE = 100

    def __init__(self, character_name: str):
        """Open (or create) the vector collection for ``character_name``.

        Args:
            character_name: Display name of the character; also stored in the
                collection metadata under the ``character`` key.

        Raises:
            Exception: Whatever the Chroma client raises if the collection
                cannot be created under either the primary or the fallback
                name.
        """
        self.character_name = character_name
        self.client = OpenAIClient.get_client()

        os.makedirs(Config.VECTOR_DB_PATH, exist_ok=True)
        self.chroma_client = self._create_chroma_client()

        collection_name = self._build_collection_name(character_name)
        try:
            self.collection = self._open_collection(collection_name)
        except Exception as e:
            print(f"创建集合时出错: {e}")
            # Retry with a name derived only from a digest of the character
            # name; it is always syntactically valid and stable across runs.
            self.collection = self._open_collection(
                self._stable_collection_name(character_name)
            )

    @staticmethod
    def _create_chroma_client():
        """Create the Chroma client, preferring on-disk persistence.

        ``chromadb.PersistentClient`` is used when the installed chromadb
        provides it, because on those releases a plain ``chromadb.Client``
        is documented as an in-memory client. Older releases, or a failed
        persistent start-up, fall back to the legacy ``Client(Settings(...))``
        call.
        """
        make_persistent = getattr(chromadb, "PersistentClient", None)
        if make_persistent is not None:
            try:
                return make_persistent(
                    path=Config.VECTOR_DB_PATH,
                    settings=Settings(anonymized_telemetry=False),
                )
            except Exception as e:
                print(f"持久化向量库初始化失败,将尝试旧版客户端: {e}")
        return chromadb.Client(Settings(
            persist_directory=Config.VECTOR_DB_PATH,
            anonymized_telemetry=False,
        ))

    @staticmethod
    def _stable_collection_name(character_name: str) -> str:
        """Return a valid collection name derived from a digest of the name.

        The built-in ``hash()`` is salted per process for strings, so it
        cannot identify the same collection on the next run; a SHA-256
        prefix can.
        """
        digest = hashlib.sha256(character_name.encode("utf-8")).hexdigest()
        return f"char_{digest[:16]}"

    @classmethod
    def _build_collection_name(cls, character_name: str) -> str:
        """Return the collection name to use for ``character_name``.

        The readable form ``char_<lowercased name with underscores>`` is kept
        when it satisfies the naming rules after truncation; otherwise (for
        example non-ASCII names, or a truncation ending in ``_``) the
        digest-based name is returned.
        """
        readable = f"char_{character_name.replace(' ', '_').lower()}"
        readable = readable[:cls._MAX_COLLECTION_NAME_LEN]
        if cls._VALID_COLLECTION_NAME.fullmatch(readable) and ".." not in readable:
            return readable
        return cls._stable_collection_name(character_name)

    def _open_collection(self, name: str):
        """Get or create the collection ``name`` tagged with this character."""
        return self.chroma_client.get_or_create_collection(
            name=name,
            metadata={"character": self.character_name}
        )

    def add_text_chunks(self, chunks: List[Dict], character_chunks: List[int]):
        """Add the text chunks in which the character appears.

        Ids that are out of range (including negative ids) or repeated are
        skipped: duplicate ids within one ``add`` call are rejected by the
        store, and a negative id would silently index from the end of
        ``chunks``.

        Args:
            chunks: All text chunks; each entry is read via its ``'text'``
                and ``'start'`` keys.
            character_chunks: Indices into ``chunks`` where the character
                appears.
        """
        documents = []
        metadatas = []
        ids = []
        seen = set()

        for chunk_id in character_chunks:
            if chunk_id in seen or not 0 <= chunk_id < len(chunks):
                continue
            seen.add(chunk_id)
            chunk = chunks[chunk_id]
            documents.append(chunk['text'])
            metadatas.append({
                'chunk_id': chunk_id,
                'position': chunk['start']
            })
            ids.append(f"chunk_{chunk_id}")

        if not documents:
            return

        try:
            # Insert in batches to keep each request to the store small.
            for start in range(0, len(documents), self._ADD_BATCH_SIZE):
                stop = start + self._ADD_BATCH_SIZE
                self.collection.add(
                    documents=documents[start:stop],
                    metadatas=metadatas[start:stop],
                    ids=ids[start:stop]
                )

            print(f"已为 {self.character_name} 添加 {len(documents)} 个文本块到向量库")
        except Exception as e:
            print(f"添加文本块到向量库失败: {e}")
            print("将继续运行,但不使用记忆功能")

    def search_relevant_context(self, query: str, n_results: Optional[int] = None) -> List[str]:
        """Retrieve stored snippets relevant to ``query``.

        Args:
            query: Query text.
            n_results: Maximum number of snippets to return; falls back to
                ``Config.MAX_MEMORY_RETRIEVAL`` when omitted (or 0).

        Returns:
            The matching text snippets, or an empty list when the collection
            is empty or the lookup fails.
        """
        n_results = n_results or Config.MAX_MEMORY_RETRIEVAL

        try:
            collection_count = self.collection.count()
            if collection_count == 0:
                return []

            # Never request more results than the collection holds.
            actual_n_results = min(n_results, collection_count)

            results = self.collection.query(
                query_texts=[query],
                n_results=actual_n_results
            )

            if results and results['documents']:
                # One query text was sent, so the first entry holds its hits.
                return results['documents'][0]
            return []

        except Exception as e:
            print(f"检索失败: {e}")
            return []

    def get_embedding(self, text: str) -> List[float]:
        """Fetch the embedding vector for ``text`` from the embeddings API.

        Args:
            text: Input text.

        Returns:
            The embedding vector, or an empty list if the request fails.
        """
        try:
            response = self.client.embeddings.create(
                model=Config.EMBEDDING_MODEL,
                input=text
            )
            return response.data[0].embedding
        except Exception as e:
            print(f"获取嵌入失败: {e}")
            return []

    def get_statistics(self) -> Dict:
        """Return statistics about this character's memory store.

        Returns:
            A dict with the keys ``character``, ``chunk_count`` and
            ``collection_name``. If the store cannot be queried,
            ``chunk_count`` is 0 and ``collection_name`` is ``'unknown'``.
        """
        try:
            count = self.collection.count()
            return {
                'character': self.character_name,
                'chunk_count': count,
                'collection_name': self.collection.name
            }
        except Exception:
            return {
                'character': self.character_name,
                'chunk_count': 0,
                'collection_name': 'unknown'
            }

    def clear(self):
        """Empty the memory store while keeping this manager usable.

        The collection is deleted and immediately recreated under the same
        name; otherwise ``self.collection`` would keep pointing at a deleted
        collection and every later add/search/count would fail.
        """
        try:
            name = self.collection.name
            self.chroma_client.delete_collection(name)
            self.collection = self._open_collection(name)
            print(f"已清空 {self.character_name} 的记忆库")
        except Exception as e:
            print(f"清空记忆库失败: {e}")