| | import json |
| | import hashlib |
| | import asyncio |
| | from uuid import uuid4 |
| | from typing import Union, List, Dict, Optional, Tuple |
| |
|
| | from pydantic import Field |
| |
|
| | from .memory import BaseMemory |
| | from evoagentx.rag import RAGConfig, RAGEngine |
| | from evoagentx.rag.schema import Corpus, Chunk, ChunkMetadata, Query, RagResult |
| | from evoagentx.storages.base import StorageHandler |
| | from evoagentx.core.message import Message |
| | from evoagentx.core.logging import logger |
| |
|
| | class LongTermMemory(BaseMemory): |
| | """ |
| | Manages long-term storage and retrieval of memories, integrating with RAGEngine for indexing |
| | and StorageHandler for persistence. |
| | """ |
| | storage_handler: StorageHandler = Field(..., description="Handler for persistent storage") |
| | rag_config: RAGConfig = Field(..., description="Configuration for RAG engine") |
| | rag_engine: RAGEngine = Field(default=None, description="RAG engine for indexing and retrieval") |
| | memory_table: str = Field(default="memory", description="Database table for storing memories") |
| | default_corpus_id: Optional[str] = Field(default=None, description="Default corpus ID for memory indexing") |
| |
|
| | def init_module(self): |
| | """Initialize the RAG engine and memory indices.""" |
| | super().init_module() |
| | if self.rag_engine is None: |
| | self.rag_engine = RAGEngine(config=self.rag_config, storage_handler=self.storage_handler) |
| | if self.default_corpus_id is None: |
| | self.default_corpus_id = str(uuid4()) |
| | logger.info(f"Initialized LongTermMemory with corpus_id {self.default_corpus_id}") |
| |
|
| | def _create_memory_chunk(self, message: Message, memory_id: str) -> Chunk: |
| | """Convert a Message to a Chunk for RAG indexing.""" |
| | metadata = ChunkMetadata( |
| | corpus_id=self.default_corpus_id, |
| | memory_id=memory_id, |
| | timestamp=message.timestamp, |
| | action=message.action, |
| | wf_goal=message.wf_goal, |
| | agent=message.agent, |
| | msg_type=message.msg_type.value if message.msg_type else None, |
| | prompt=message.prompt, |
| | next_actions=message.next_actions, |
| | wf_task=message.wf_task, |
| | wf_task_desc=message.wf_task_desc, |
| | message_id=message.message_id, |
| | content=json.dumps(message.content), |
| | ) |
| | return Chunk( |
| | chunk_id=memory_id, |
| | text=str(message.content), |
| | metadata=metadata, |
| | start_char_idx=0, |
| | end_char_idx=len(str(message.content)), |
| | ) |
| |
|
| | def _chunk_to_message(self, chunk: Chunk) -> Message: |
| | """Convert a Chunk to a Message object.""" |
| | return Message( |
| | content=chunk.metadata.content, |
| | action=chunk.metadata.action, |
| | wf_goal=chunk.metadata.wf_goal, |
| | timestamp=chunk.metadata.timestamp, |
| | agent=chunk.metadata.agent, |
| | msg_type=chunk.metadata.msg_type, |
| | prompt=chunk.metadata.prompt, |
| | next_actions=chunk.metadata.next_actions, |
| | wf_task=chunk.metadata.wf_task, |
| | wf_task_desc=chunk.metadata.wf_task_desc, |
| | message_id=chunk.metadata.message_id, |
| | ) |
| |
|
| | def add(self, messages: Union[Message, str, List[Union[Message, str]]]) -> List[str]: |
| | """Store messages in memory and index them in RAGEngine, returning memory_ids.""" |
| | if not isinstance(messages, list): |
| | messages = [messages] |
| | messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages] |
| | messages = [msg for msg in messages if msg.content] |
| |
|
| | if not messages: |
| | logger.warning("No valid messages to add") |
| | return [] |
| |
|
| | |
| | existing_hashes = { |
| | record["content_hash"] |
| | for record in self.storage_handler.load(tables=[self.memory_table]).get(self.memory_table, []) |
| | if "content_hash" in record |
| | } |
| | memory_ids = [str(uuid4()) for _ in messages] |
| | final_messages = [] |
| | final_memory_ids = [] |
| | final_chunks = [] |
| |
|
| | for msg, memory_id in zip(messages, memory_ids): |
| | content_hash = hashlib.sha256(str(msg.content).encode()).hexdigest() |
| | if content_hash in existing_hashes: |
| | logger.info(f"Duplicate message found (hash): {msg.content[:50]}...") |
| | existing_id = next( |
| | (r["memory_id"] for r in self.storage_handler.load(tables=[self.memory_table]).get(self.memory_table, []) |
| | if r.get("content_hash") == content_hash), None |
| | ) |
| | if existing_id: |
| | final_memory_ids.append(existing_id) |
| | continue |
| | final_messages.append(msg) |
| | final_memory_ids.append(memory_id) |
| | chunk = self._create_memory_chunk(msg, memory_id) |
| | chunk.metadata.content_hash = content_hash |
| | final_chunks.append(chunk) |
| |
|
| | if not final_chunks: |
| | logger.info("No messages added after deduplication") |
| | return final_memory_ids |
| |
|
| | |
| | for msg in final_messages: |
| | super().add_message(msg) |
| |
|
| | |
| | corpus = Corpus(chunks=final_chunks, corpus_id=self.default_corpus_id) |
| | chunk_ids = self.rag_engine.add(index_type=self.rag_config.index.index_type, nodes=corpus, corpus_id=self.default_corpus_id) |
| | if not chunk_ids: |
| | logger.error("Failed to index memories") |
| | return final_memory_ids |
| |
|
| | return final_memory_ids |
| |
|
| | async def get(self, memory_ids: Union[str, List[str]], return_chunk: bool = True) -> List[Tuple[Union[Chunk, Message], str]]: |
| | """Retrieve memories by memory_ids, returning (Message/Chunk, memory_id) tuples.""" |
| | if not isinstance(memory_ids, list): |
| | memory_ids = [memory_ids] |
| |
|
| | if not memory_ids: |
| | logger.warning("No memory_ids provided for get") |
| | return [] |
| |
|
| | try: |
| | chunks = await self.rag_engine.aget( |
| | corpus_id=self.default_corpus_id, |
| | index_type=self.rag_config.index.index_type, |
| | node_ids=memory_ids |
| | ) |
| | results = [(self._chunk_to_message(chunk), chunk.metadata.memory_id) if not return_chunk else (chunk, chunk.metadata.memory_id) |
| | for chunk in chunks if chunk] |
| | logger.info(f"Retrieved {len(results)} memories for memory_ids: {memory_ids}") |
| | return results |
| | except Exception as e: |
| | logger.error(f"Failed to get memories: {str(e)}") |
| | return [] |
| |
|
| | def delete(self, memory_ids: Union[str, List[str]]) -> List[bool]: |
| | """Delete memories by memory_ids, returning success status for each.""" |
| | if not isinstance(memory_ids, list): |
| | memory_ids = [memory_ids] |
| |
|
| | if not memory_ids: |
| | logger.warning("No memory_ids provided for deletion") |
| | return [] |
| |
|
| | successes = [False] * len(memory_ids) |
| | valid_memory_ids = [] |
| |
|
| | existing_chunks = asyncio.run(self.get(memory_ids, return_chunk=True)) |
| | for idx, (chunk, mid) in enumerate(existing_chunks): |
| | if chunk: |
| | valid_memory_ids.append(mid) |
| | super().remove_message(self._chunk_to_message(chunk)) |
| | successes[idx] = True |
| |
|
| | if not valid_memory_ids: |
| | logger.info("No memories found for deletion") |
| | return successes |
| |
|
| | |
| | self.rag_engine.delete( |
| | corpus_id=self.default_corpus_id, |
| | index_type=self.rag_config.index.index_type, |
| | node_ids=valid_memory_ids |
| | ) |
| |
|
| | return successes |
| |
|
| | def update(self, updates: Union[Tuple[str, Union[Message, str]], List[Tuple[str, Union[Message, str]]]]) -> List[bool]: |
| | """Update memories with new content, returning success status for each.""" |
| | if not isinstance(updates, list): |
| | updates = [updates] |
| | updates = [(mid, Message(content=msg) if isinstance(msg, str) else msg) for mid, msg in updates] |
| | updates_dict = {mid: msg for mid, msg in updates if msg.content} |
| |
|
| | if not updates_dict: |
| | logger.warning("No valid updates provided") |
| | return [] |
| |
|
| | memory_ids = list(updates_dict.keys()) |
| | existing_memories = asyncio.run(self.get(memory_ids, return_chunk=False)) |
| | existing_dict = {mid: msg for msg, mid in existing_memories} |
| |
|
| | successes = [False] * len(updates) |
| | final_updates = [] |
| | final_memory_ids = [] |
| |
|
| | for mid, msg in updates_dict.items(): |
| | if mid not in existing_dict: |
| | logger.warning(f"No memory found with memory_id {mid}") |
| | continue |
| | final_updates.append((mid, msg)) |
| | final_memory_ids.append(mid) |
| | successes[memory_ids.index(mid)] = True |
| | super().remove_message(existing_dict[mid]) |
| |
|
| | if not final_updates: |
| | logger.info("No memories updated") |
| | return successes |
| |
|
| | chunks = [self._create_memory_chunk(msg, mid) for mid, msg in final_updates] |
| | for msg in [msg for _, msg in final_updates]: |
| | super().add_message(msg) |
| |
|
| | corpus = Corpus(chunks=chunks, corpus_id=self.default_corpus_id) |
| | chunk_ids = self.rag_engine.add(index_type=self.rag_config.index.index_type, nodes=corpus, corpus_id=self.default_corpus_id) |
| | if not chunk_ids: |
| | logger.error(f"Failed to update memories in RAG index: {final_memory_ids}") |
| | return [False] * len(updates) |
| |
|
| | return successes |
| |
|
| | async def search_async(self, query: Union[str, Query], n: Optional[int] = None, |
| | metadata_filters: Optional[Dict] = None, return_chunk=False) -> List[Tuple[Message, str]]: |
| | """Retrieve messages from RAG index asynchronously based on a query, returning messages and memory_ids.""" |
| | if isinstance(query, str): |
| | query_obj = Query( |
| | query_str=query, |
| | top_k=n or self.rag_config.retrieval.top_k, |
| | metadata_filters=metadata_filters or {} |
| | ) |
| | else: |
| | query_obj = query |
| | query_obj.top_k = n or self.rag_config.retrieval.top_k |
| | if metadata_filters: |
| | query_obj.metadata_filters = {**query_obj.metadata_filters, **metadata_filters} if query_obj.metadata_filters else metadata_filters |
| |
|
| | try: |
| | result: RagResult = await self.rag_engine.query_async(query_obj, corpus_id=self.default_corpus_id) |
| | if return_chunk: |
| | return [(chunk, chunk.metadata.memory_id) for chunk in result.corpus.chunks] |
| | else: |
| | messages = [(self._chunk_to_message(chunk), chunk.metadata.memory_id) for chunk in result.corpus.chunks] |
| | logger.info(f"Retrieved {len(messages)} memories for query: {query_obj.query_str}") |
| | return messages[:n] if n else messages |
| | except Exception as e: |
| | logger.error(f"Failed to search memories: {str(e)}") |
| | return [] |
| |
|
| | def search(self, query: Union[str, Query], n: Optional[int] = None, |
| | metadata_filters: Optional[Dict] = None) -> List[Tuple[Message, str]]: |
| | """Synchronous wrapper for searching memories.""" |
| | return asyncio.run(self.search_async(query, n, metadata_filters)) |
| |
|
| | def clear(self) -> None: |
| | """Clear all messages and indices.""" |
| | super().clear() |
| | self.rag_engine.clear(corpus_id=self.default_corpus_id) |
| | logger.info(f"Cleared LongTermMemory with corpus_id {self.default_corpus_id}") |
| |
|
| | def save(self, save_path: Optional[str] = None) -> None: |
| | """Save all indices and memory data to database.""" |
| | self.rag_engine.save(output_path=save_path, corpus_id=self.default_corpus_id, table=self.memory_table) |
| |
|
| | def load(self, save_path: Optional[str] = None) -> List[str]: |
| | """Load memory data from database and reconstruct indices, returning memory_ids.""" |
| | return self.rag_engine.load(source=save_path, corpus_id=self.default_corpus_id, table=self.memory_table) |