| | from collections import defaultdict |
| | from typing import Union, Optional, List, Dict |
| | from collections import deque |
| | from pydantic import Field, PositiveInt, field_validator |
| |
|
| | from ..core.module import BaseModule |
| | from ..core.module_utils import generate_id, get_timestamp |
| | from ..core.message import Message |
| | from ..utils.utils import safe_remove |
| |
|
class BaseMemory(BaseModule):
    """Base class for memory implementations in the EvoAgentX framework.

    BaseMemory provides core functionality for storing, retrieving, and
    filtering messages. It maintains a chronological list of messages while
    also providing indices for efficient retrieval by action or workflow goal.

    Attributes:
        messages: List of stored Message objects, in insertion order.
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp of this memory instance.
        capacity: Maximum number of messages that can be stored, or None for unlimited.
            NOTE(review): nothing in BaseMemory reads this value, so the limit is
            not enforced here; presumably subclasses enforce it — confirm.
    """

    messages: List[Message] = Field(default_factory=list)
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)
    capacity: Optional[PositiveInt] = Field(default=None, description="maximum of messages, None means there is no limit to the message number")

    def init_module(self):
        """Initialize the secondary message indices.

        Creates ``_by_action`` and ``_by_wf_goal``, which map an action name /
        workflow goal to the stored messages carrying that value, and indexes
        any messages that are already present in ``messages``.
        """
        self._by_action = defaultdict(list)
        self._by_wf_goal = defaultdict(list)
        # Messages passed to the constructor never go through add_message, so
        # index them here to keep the lookup indices consistent with `messages`.
        # NOTE(review): assumes BaseModule calls init_module after field
        # validation has populated `messages` — confirm.
        for message in self.messages:
            self._index_message(message)

    @property
    def size(self) -> int:
        """Returns the current number of messages in memory.

        Returns:
            int: Number of messages currently stored.
        """
        return len(self.messages)

    def clear(self):
        """Clear all messages from memory.

        Removes all messages and resets all indices.
        """
        self.messages.clear()
        self._by_action.clear()
        self._by_wf_goal.clear()

    def _index_message(self, message: Message):
        """Register ``message`` in the action / workflow-goal indices."""
        # Only messages that carry a non-empty action / goal are indexed; a
        # falsy value is not a meaningful lookup key.
        if message.action:
            self._by_action[message.action].append(message)
        if message.wf_goal:
            self._by_wf_goal[message.wf_goal].append(message)

    def _unindex_message(self, message: Message):
        """Drop ``message`` from the action / workflow-goal indices, if present."""
        # The membership checks keep the defaultdicts from growing empty
        # entries just because we looked a key up.
        if message.action and message.action in self._by_action:
            safe_remove(self._by_action[message.action], message)
        if message.wf_goal and message.wf_goal in self._by_wf_goal:
            safe_remove(self._by_wf_goal[message.wf_goal], message)

    def remove_message(self, message: Message):
        """Remove a single message from memory.

        Removes the specified message from the main message list and all indices.
        If the message is falsy or not found in memory, no action is taken.

        Args:
            message: The message to be removed. The message will be removed from
                self.messages, self._by_action, and self._by_wf_goal.
        """
        if not message or message not in self.messages:
            return
        safe_remove(self.messages, message)
        self._unindex_message(message)

    def add_message(self, message: Message):
        """Store a single message in memory.

        Appends the message to the main list and to the action / workflow-goal
        indices. Falsy messages and messages already stored are ignored.

        Args:
            message (Message): the message to be stored.
        """
        if not message or message in self.messages:
            return
        self.messages.append(message)
        self._index_message(message)

    def add_messages(self, messages: Union[Message, List[Message]], **kwargs):
        """
        store (a) message(s) to the memory.

        Args:
            messages (Union[Message, List[Message]]): the input messages can be a single message or a list of message.
            **kwargs (Any): Additional parameters (unused in base implementation).
        """
        if not isinstance(messages, list):
            messages = [messages]
        for message in messages:
            self.add_message(message)

    @staticmethod
    def _last_n(items: List[Message], n: Optional[int]) -> List[Message]:
        """Return a new list holding the last ``n`` items (all of them when ``n`` is None).

        Raises:
            AssertionError: If n is negative.
        """
        assert n is None or n >= 0, "n must be None or a positive int"
        if n is None:
            return list(items)
        # `items[-0:]` is `items[0:]` (the whole list), so n == 0 needs its own branch.
        return items[-n:] if n > 0 else []

    def get(self, n: Optional[int] = None, **kwargs) -> List[Message]:
        """Retrieve recent messages from memory.

        Returns the most recent messages, up to the specified limit. The result
        is a new list; mutating it does not affect the memory.

        Args:
            n: The maximum number of messages to return. If None, returns all messages;
                if 0, returns an empty list.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            A list of Message objects, ordered from oldest to newest.

        Raises:
            AssertionError: If n is negative.
        """
        return self._last_n(self.messages, n)

    def get_by_type(self, data: Dict[str, list], key: str, n: Optional[int] = None, **kwargs) -> List[Message]:
        """
        Retrieve a list of Message objects from a given data dictionary `data` based on a specified type `key`.

        Looks up the list stored under `key` in `data` and returns its last `n`
        entries (all of them when `n` is None, none when `n` is 0). The result is
        a new list, so callers cannot mutate the underlying index through it.

        Args:
            data (Dict[str, list]): A dictionary where keys are type strings and values are lists of messages.
            key (str): The key in `data` identifying the specific list of messages to retrieve.
            n (int, optional): The maximum number of messages to return. If not provided, all messages under the given `key` are returned.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            List[Message]: The messages stored under `key`, truncated to the last `n`.
            An empty list if `data` is empty or does not contain `key`.

        Raises:
            AssertionError: If n is negative and `key` is present in `data`.
        """
        if not data or key not in data:
            return []
        return self._last_n(data[key], n)

    def get_by_action(self, actions: Union[str, List[str]], n: Optional[int] = None, **kwargs) -> List[Message]:
        """
        return messages triggered by `actions` in the memory.

        Args:
            actions: A single action name or list of action names to filter by.
            n: Maximum number of messages to return per action. If None, returns all matching messages.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            A list of Message objects, sorted by timestamp.
        """
        if isinstance(actions, str):
            actions = [actions]
        messages = []
        for action in actions:
            messages.extend(self.get_by_type(self._by_action, key=action, n=n, **kwargs))
        return Message.sort_by_timestamp(messages)

    def get_by_wf_goal(self, wf_goals: Union[str, List[str]], n: Optional[int] = None, **kwargs) -> List[Message]:
        """
        return messages related to `wf_goals` in the memory.

        Args:
            wf_goals: A single workflow goal or list of workflow goals to filter by.
            n: Maximum number of messages to return per workflow goal. If None, returns all matching messages.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            A list of Message objects, sorted by timestamp.
        """
        if isinstance(wf_goals, str):
            wf_goals = [wf_goals]
        messages = []
        for wf_goal in wf_goals:
            # extend (not append): each lookup returns a list of messages, and the
            # result must stay a flat list before sorting.
            messages.extend(self.get_by_type(self._by_wf_goal, key=wf_goal, n=n, **kwargs))
        return Message.sort_by_timestamp(messages)
| |
|
| |
|
class ShortTermMemory(BaseModule):
    """
    Short-term memory implementation.

    Stores only the most recent ``max_size`` messages (a sliding window): once
    the window is full, adding a message silently evicts the oldest one.
    Unlike BaseMemory/LongTermMemory, this is purely in-memory cache
    and does not persist to storage_handler or vector DB.

    Attributes:
        buffer: Internal deque holding Message objects, capped at max_size.
            Declared as a list for validation, converted to a bounded deque in
            ``model_post_init``, and excluded from serialization.
        max_size: Maximum number of messages to retain.
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp.
    """

    buffer: List[Message] = Field(default_factory=list, exclude=True)
    max_size: PositiveInt = Field(default=5, description="Maximum number of messages to keep in short-term memory")
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)

    @field_validator("buffer", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Ensure that the buffer is always a list, even if it is null in the JSON."""
        if v is None:
            return []
        return v

    def model_post_init(self, __context=None):
        """
        Pydantic V2 hook after model initialization.
        Chain to the parent hook, then convert buffer list → deque, enforcing
        max_size (only the newest ``max_size`` initial messages are kept).
        """
        # Chain to the parent so any post-init logic defined by BaseModule /
        # pydantic still runs for this subclass.
        super().model_post_init(__context)
        # NOTE(review): if the model config enables validate_assignment, pydantic
        # may coerce this deque back into a plain list (losing maxlen) — confirm.
        # Note also that changing max_size later does not resize the deque.
        self.buffer = deque(self.buffer, maxlen=self.max_size)

    @property
    def size(self) -> int:
        """Return current number of messages stored."""
        return len(self.buffer)

    def clear(self):
        """Clear all short-term memory."""
        self.buffer.clear()

    def add_message(self, message: Message):
        """Add a single message to short-term memory (falsy messages are ignored).

        Unlike BaseMemory, duplicates are not filtered; the oldest message is
        evicted when the buffer is full.
        """
        if not message:
            return
        self.buffer.append(message)

    def add_messages(self, messages: Union[Message, List[Message]], **kwargs):
        """Add one or multiple messages.

        Args:
            messages: A single message or a list of messages.
            **kwargs (Any): Accepted for signature parity with BaseMemory; unused.
        """
        if not isinstance(messages, list):
            messages = [messages]
        for msg in messages:
            self.add_message(msg)

    def get(self, n: Optional[int] = None, **kwargs) -> List[Message]:
        """
        Retrieve the most recent n messages (default: all).

        Args:
            n: Number of messages to return. If None, return all; if 0, return none.
            **kwargs (Any): Accepted for signature parity with BaseMemory; unused.

        Returns:
            List of Message objects, oldest → newest (a new list).

        Raises:
            AssertionError: If n is negative.
        """
        assert n is None or n >= 0, "n must be None or a positive int"
        messages = list(self.buffer)
        if n is None:
            return messages
        # `messages[-0:]` would be the whole list, so n == 0 needs its own branch.
        return messages[-n:] if n > 0 else []

    def get_last(self) -> Optional[Message]:
        """Return the latest message, or None if empty."""
        return self.buffer[-1] if self.buffer else None