| """ |
| Enhanced Memory System for GAIA-Ready AI Agent |
| |
| This module provides an advanced memory system for the AI agent, |
| including short-term, long-term, and working memory components, |
| as well as semantic retrieval capabilities. |
| """ |
|
|
| import os |
| import json |
| from typing import List, Dict, Any, Optional, Union |
| from datetime import datetime |
| import re |
| import numpy as np |
| from collections import defaultdict |
|
|
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    # Best-effort auto-install at import time.  Invoke pip via the current
    # interpreter (sys.executable -m pip) so the package lands in the same
    # environment this module runs in — a bare "pip" on PATH may belong to a
    # different Python installation.
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "sentence-transformers"])
    from sentence_transformers import SentenceTransformer
|
|
|
|
class EnhancedMemoryManager:
    """
    Advanced memory manager for the agent that maintains short-term, long-term,
    and working memory with semantic retrieval capabilities.

    Memory items are plain dicts with at least a ``content`` field; missing
    ``timestamp`` / ``type`` (and, for long-term items, ``importance``) fields
    are filled in on insert.  Short- and long-term memories are persisted to
    ``agent_memory.json`` after every mutation; working memory is transient
    key/value storage.

    When ``use_semantic_search`` is enabled and a sentence-transformers model
    loads successfully, an index of ``(embedding, list_index, store_name)``
    triples is kept in ``self.memory_embeddings`` for cosine-similarity
    retrieval; otherwise retrieval falls back to keyword matching.
    """

    def __init__(self, use_semantic_search: bool = True):
        """
        Args:
            use_semantic_search: Try to load a sentence-transformers model for
                embedding-based retrieval.  Falls back to keyword search if the
                model cannot be initialized.
        """
        self.short_term_memory: List[Dict[str, Any]] = []
        self.long_term_memory: List[Dict[str, Any]] = []
        self.working_memory: Dict[str, Any] = {}
        self.max_short_term_items = 15
        self.max_long_term_items = 100
        self.use_semantic_search = use_semantic_search
        # Always defined (even when semantic search is disabled or model init
        # fails) so no method can hit an AttributeError on it.
        self.memory_embeddings: List[Any] = []

        if self.use_semantic_search:
            try:
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            except Exception as e:
                print(f"Warning: Could not initialize semantic search: {str(e)}")
                self.use_semantic_search = False

        self.memory_file = "agent_memory.json"
        self.load_memories()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _normalize_item(self, item: Dict[str, Any]) -> None:
        """Validate a memory item and fill in default metadata in place.

        Raises:
            ValueError: If the item has no 'content' field.
        """
        if "content" not in item:
            raise ValueError("Memory item must have 'content' field")
        if "timestamp" not in item:
            item["timestamp"] = datetime.now().isoformat()
        if "type" not in item:
            item["type"] = "general"

    def _add_embedding(self, content: str, idx: int, mem_type: str) -> None:
        """Append one (embedding, index, store) triple to the embedding index."""
        if not self.use_semantic_search:
            return
        try:
            embedding = self.embedding_model.encode(content)
            self.memory_embeddings.append((embedding, idx, mem_type))
        except Exception as e:
            print(f"Warning: Could not create embedding for memory item: {str(e)}")

    def _rebuild_long_term_embeddings(self) -> None:
        """Re-encode every long-term memory; short-term entries are kept."""
        if not self.use_semantic_search:
            return
        self.memory_embeddings = [t for t in self.memory_embeddings if t[2] == "short_term"]
        for i, memory in enumerate(self.long_term_memory):
            self._add_embedding(memory.get("content", ""), i, "long_term")

    def _rebuild_all_embeddings(self) -> None:
        """Re-encode every stored memory from scratch."""
        if not self.use_semantic_search:
            return
        self.memory_embeddings = []
        for i, memory in enumerate(self.short_term_memory):
            self._add_embedding(memory.get("content", ""), i, "short_term")
        for i, memory in enumerate(self.long_term_memory):
            self._add_embedding(memory.get("content", ""), i, "long_term")

    # ------------------------------------------------------------------ #
    # Mutators
    # ------------------------------------------------------------------ #

    def add_to_short_term(self, item: Dict[str, Any]) -> None:
        """Add an item to short-term memory, maintaining size limit.

        Raises:
            ValueError: If the item has no 'content' field.
        """
        self._normalize_item(item)
        self.short_term_memory.append(item)
        self._add_embedding(item.get("content", ""), len(self.short_term_memory) - 1, "short_term")

        # Evict the oldest item (FIFO) once over capacity.
        if len(self.short_term_memory) > self.max_short_term_items:
            self.short_term_memory.pop(0)
            if self.use_semantic_search:
                # Drop the evicted item's embedding and shift the remaining
                # short-term indices down by one to match the shrunken list.
                self.memory_embeddings = [
                    (emb, idx - 1 if mem_type == "short_term" else idx, mem_type)
                    for emb, idx, mem_type in self.memory_embeddings
                    if not (mem_type == "short_term" and idx == 0)
                ]

        self.save_memories()

    def add_to_long_term(self, item: Dict[str, Any]) -> None:
        """Add an important item to long-term memory, maintaining size limit.

        If the item carries no explicit 'importance', one is derived from its
        content length scaled by a per-type weight, capped at 1.0.

        Raises:
            ValueError: If the item has no 'content' field.
        """
        self._normalize_item(item)

        if "importance" not in item:
            type_importance = {
                "final_answer": 0.9,
                "key_fact": 0.8,
                "reasoning": 0.7,
                "general": 0.5
            }
            content_length = len(item.get("content", ""))
            item["importance"] = min(1.0, (content_length / 1000) * type_importance.get(item["type"], 0.5))

        self.long_term_memory.append(item)

        # Keep the list ordered most-important-first so eviction is a pop().
        self.long_term_memory.sort(key=lambda x: x.get("importance", 0), reverse=True)

        if len(self.long_term_memory) > self.max_long_term_items:
            # Drop the least important item.
            self.long_term_memory.pop()

        # BUG FIX: the sort above reorders long-term memories, invalidating
        # the list indices stored in the embedding index.  The original code
        # only rebuilt the index on overflow, so semantic retrieval could
        # return the wrong memory (or IndexError).  Rebuild on every insert.
        self._rebuild_long_term_embeddings()

        self.save_memories()

    def store_in_working_memory(self, key: str, value: Any) -> None:
        """Store a value in working memory under the specified key"""
        self.working_memory[key] = value

    def get_from_working_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from working memory by key"""
        return self.working_memory.get(key)

    def clear_working_memory(self) -> None:
        """Clear the working memory"""
        self.working_memory = {}

    # ------------------------------------------------------------------ #
    # Retrieval
    # ------------------------------------------------------------------ #

    def get_relevant_memories(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Retrieve memories relevant to the current query.

        Uses cosine similarity over the embedding index when semantic search
        is enabled; otherwise (or on any failure) falls back to keyword search.

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            List of relevant memory items, each copy annotated with a
            'relevance_score' field.
        """
        if not self.use_semantic_search:
            return self._keyword_search(query, max_results)

        try:
            query_embedding = self.embedding_model.encode(query)
            query_norm = np.linalg.norm(query_embedding)

            similarities = []
            for embedding, idx, mem_type in self.memory_embeddings:
                denom = query_norm * np.linalg.norm(embedding)
                # Guard against zero-norm embeddings (e.g. empty content),
                # which would otherwise divide by zero.
                similarity = float(np.dot(query_embedding, embedding) / denom) if denom else 0.0
                similarities.append((similarity, idx, mem_type))

            # Sort by score only; sorting the raw tuples could end up
            # comparing the mem_type strings on exact score/index ties.
            similarities.sort(key=lambda t: t[0], reverse=True)

            relevant_memories = []
            for similarity, idx, mem_type in similarities[:max_results]:
                store = self.short_term_memory if mem_type == "short_term" else self.long_term_memory
                if idx >= len(store):
                    continue  # defensive: skip any stale index entry
                memory_with_score = store[idx].copy()
                memory_with_score["relevance_score"] = similarity
                relevant_memories.append(memory_with_score)

            return relevant_memories
        except Exception as e:
            print(f"Warning: Semantic search failed: {str(e)}. Falling back to keyword search.")
            return self._keyword_search(query, max_results)

    def _keyword_search(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Fallback keyword-based search for relevant memories.

        Score = (shared word count) * type boost * recency factor; items
        scoring zero are excluded.

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            List of relevant memory items annotated with 'relevance_score'.
        """
        query_keywords = set(re.findall(r'\b\w+\b', query.lower()))

        # Hoisted out of the scoring closure: constant per call.
        type_boost = {
            "final_answer": 2.0,
            "key_fact": 1.5,
            "reasoning": 1.2,
            "general": 1.0
        }

        def score_memory(memory: Dict[str, Any]) -> float:
            content_words = set(re.findall(r'\b\w+\b', memory.get("content", "").lower()))
            matches = len(query_keywords.intersection(content_words))

            # Newer memories score higher: full weight within the last day,
            # linearly decaying to a floor of 0.5.
            try:
                timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                hours_ago = (datetime.now() - timestamp).total_seconds() / 3600
                recency_factor = max(0.5, 1.0 - (hours_ago / 24))
            except (ValueError, TypeError):
                # Unparseable timestamp: treat as old but still searchable.
                recency_factor = 0.5

            return matches * type_boost.get(memory.get("type", "general"), 1.0) * recency_factor

        scored_memories = []
        for memory in self.long_term_memory + self.short_term_memory:
            score = score_memory(memory)
            if score > 0:
                memory_with_score = memory.copy()
                memory_with_score["relevance_score"] = score
                scored_memories.append((score, memory_with_score))

        scored_memories.sort(reverse=True, key=lambda x: x[0])
        return [memory for _, memory in scored_memories[:max_results]]

    def get_memory_summary(self) -> str:
        """Get a summary of the current memory state for the agent"""
        # Last five short-term entries, content truncated to 100 chars.
        recent_short_term = self.short_term_memory[-5:] if self.short_term_memory else []
        short_term_summary = "\n".join([f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..."
                                        for m in recent_short_term])

        # Top five long-term entries by importance.
        important_long_term = sorted(self.long_term_memory,
                                     key=lambda x: x.get("importance", 0),
                                     reverse=True)[:5] if self.long_term_memory else []
        long_term_summary = "\n".join([f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..."
                                       for m in important_long_term])

        # Working-memory values: long strings truncated to 50 chars.
        working_memory_summary = "\n".join([f"- {k}: {str(v)[:50]}..." if isinstance(v, str) and len(str(v)) > 50
                                            else f"- {k}: {v}" for k, v in self.working_memory.items()])

        return f"""
MEMORY SUMMARY:
--------------
Recent Short-Term Memory:
{short_term_summary if short_term_summary else "No recent short-term memories."}

Important Long-Term Memory:
{long_term_summary if long_term_summary else "No important long-term memories."}

Working Memory:
{working_memory_summary if working_memory_summary else "Working memory is empty."}
"""

    # ------------------------------------------------------------------ #
    # Persistence & maintenance
    # ------------------------------------------------------------------ #

    def save_memories(self) -> None:
        """Save memories to disk for persistence (best effort; warns on failure)."""
        try:
            memories = {
                "short_term": self.short_term_memory,
                "long_term": self.long_term_memory,
                "last_updated": datetime.now().isoformat()
            }
            with open(self.memory_file, 'w') as f:
                json.dump(memories, f, indent=2)
        except Exception as e:
            print(f"Warning: Could not save memories: {str(e)}")

    def load_memories(self) -> None:
        """Load memories from disk if available and rebuild the embedding index."""
        try:
            if os.path.exists(self.memory_file):
                with open(self.memory_file, 'r') as f:
                    memories = json.load(f)

                self.short_term_memory = memories.get("short_term", [])
                self.long_term_memory = memories.get("long_term", [])

                # Embeddings are not persisted; re-encode everything.
                self._rebuild_all_embeddings()

                print(f"Loaded {len(self.short_term_memory)} short-term and {len(self.long_term_memory)} long-term memories.")
        except Exception as e:
            print(f"Warning: Could not load memories: {str(e)}")

    def forget_old_memories(self, days_threshold: int = 30) -> None:
        """
        Remove memories older than the specified threshold.

        Long-term memories get an importance-scaled grace period: an item of
        importance 1.0 is kept up to twice as long as the base threshold.

        Args:
            days_threshold: Age threshold in days
        """
        try:
            now = datetime.now()
            threshold = days_threshold * 24 * 60 * 60  # threshold in seconds

            def age_seconds(memory: Dict[str, Any]) -> Optional[float]:
                """Age of a memory in seconds, or None if its timestamp is unparseable."""
                try:
                    timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                    return (now - timestamp).total_seconds()
                except (ValueError, TypeError):
                    return None

            new_short_term = []
            for memory in self.short_term_memory:
                age = age_seconds(memory)
                # Keep items with bad timestamps rather than silently losing data.
                if age is None or age < threshold:
                    new_short_term.append(memory)

            new_long_term = []
            for memory in self.long_term_memory:
                age = age_seconds(memory)
                importance = memory.get("importance", 0.5)
                adjusted_threshold = threshold * (1 + importance)
                if age is None or age < adjusted_threshold:
                    new_long_term.append(memory)

            removed_short_term = len(self.short_term_memory) - len(new_short_term)
            removed_long_term = len(self.long_term_memory) - len(new_long_term)

            self.short_term_memory = new_short_term
            self.long_term_memory = new_long_term

            # Indices changed wholesale; rebuild the embedding index.
            self._rebuild_all_embeddings()

            self.save_memories()

            print(f"Forgot {removed_short_term} short-term and {removed_long_term} long-term memories older than {days_threshold} days.")
        except Exception as e:
            print(f"Warning: Could not forget old memories: {str(e)}")
|
|
|
|
| |
| if __name__ == "__main__": |
| |
| memory_manager = EnhancedMemoryManager(use_semantic_search=True) |
| |
| |
| memory_manager.add_to_short_term({ |
| "type": "query", |
| "content": "What is the capital of France?", |
| "timestamp": datetime.now().isoformat() |
| }) |
| |
| memory_manager.add_to_long_term({ |
| "type": "key_fact", |
| "content": "Paris is the capital of France with a population of about 2.2 million people.", |
| "timestamp": datetime.now().isoformat() |
| }) |
| |
| memory_manager.store_in_working_memory("current_task", "Finding information about France") |
| |
| |
| relevant_memories = memory_manager.get_relevant_memories("What is the population of Paris?") |
| print("\nRelevant memories for 'What is the population of Paris?':") |
| for memory in relevant_memories: |
| print(f"- Score: {memory.get('relevance_score', 0):.2f}, Content: {memory.get('content', '')}") |
| |
| |
| print("\nMemory Summary:") |
| print(memory_manager.get_memory_summary()) |
|
|