Spaces:
Running
Running
| """Short-term memory for episode-scoped data storage.""" | |
| from __future__ import annotations | |
| import asyncio | |
| from collections import OrderedDict | |
| from datetime import datetime, timezone | |
| from typing import Any, Generic, TypeVar | |
| from pydantic import BaseModel, Field | |
| T = TypeVar("T") | |
def _utc_now() -> datetime:
    """Produce a timezone-aware timestamp for the present moment in UTC."""
    utc = timezone.utc
    return datetime.now(tz=utc)
class MemoryEntry(BaseModel, Generic[T]):
    """A single memory entry with metadata.

    Attributes are documented inline. Both timestamps default to the
    timezone-aware current UTC time, so a default-constructed entry never
    mixes naive and aware datetimes.
    """

    # Lookup key of the entry.
    key: str
    # Stored payload; intentionally unconstrained.
    value: Any
    # Creation time (timezone-aware UTC by default).
    created_at: datetime = Field(default_factory=_utc_now)
    # Last modification time. Uses the same aware factory as ``created_at``:
    # a naive ``datetime.utcnow`` default cannot be compared with or
    # subtracted from an aware timestamp (TypeError) and is deprecated.
    updated_at: datetime = Field(default_factory=_utc_now)
    # Number of recorded reads; starts at zero.
    access_count: int = 0
    # Free-form labels; each entry gets its own fresh list.
    tags: list[str] = Field(default_factory=list)

    model_config = {"arbitrary_types_allowed": True}
class ShortTermMemory:
    """
    Episode-scoped memory using dictionary-based storage.

    This memory layer is designed for transient data that should persist
    only within a single episode. The store is cleared whenever
    ``set_episode`` is called with an ID different from the current one.

    Entries are kept in insertion/update order. When a new key is added to
    a full store, the entry at the front of that order (the one written
    least recently) is evicted. Reads do not change the order.

    Attributes:
        max_size: Maximum number of entries allowed.
        _store: Internal ordered storage dictionary.
        _episode_id: Current episode identifier.
        _lock: asyncio lock held by every coroutine method while it
            touches ``_store`` or ``_episode_id``.
    """

    def __init__(self, max_size: int = 100) -> None:
        """
        Initialize short-term memory.

        Args:
            max_size: Maximum number of entries to store. Defaults to 100.
        """
        self.max_size = max_size
        self._store: OrderedDict[str, MemoryEntry] = OrderedDict()
        self._episode_id: str | None = None
        self._lock = asyncio.Lock()

    def episode_id(self) -> str | None:
        """Get the current episode ID (``None`` until one has been set)."""
        return self._episode_id

    def size(self) -> int:
        """Get the current number of entries."""
        return len(self._store)

    async def set_episode(self, episode_id: str) -> None:
        """
        Set the current episode ID and clear existing memory.

        Nothing happens when ``episode_id`` equals the current ID, so
        re-announcing the same episode does not wipe its data.

        Args:
            episode_id: Unique identifier for the new episode.
        """
        async with self._lock:
            if self._episode_id != episode_id:
                self._store.clear()
                self._episode_id = episode_id

    async def set(
        self,
        key: str,
        value: Any,
        tags: list[str] | None = None,
    ) -> MemoryEntry:
        """
        Store a value in short-term memory.

        An existing key is updated in place and moved to the most-recent
        end of the store; its tags are replaced only when ``tags`` is not
        ``None``. A new key evicts the oldest entry first if the store is
        already at ``max_size``.

        Args:
            key: Unique key for the entry.
            value: Value to store.
            tags: Optional tags for categorization.

        Returns:
            The created or updated memory entry.

        Raises:
            ValueError: If a new key is added while ``max_size`` is not
                positive, i.e. the store cannot hold any entry.
        """
        async with self._lock:
            now = datetime.now(timezone.utc)
            if key in self._store:
                entry = self._store[key]
                entry.value = value
                entry.updated_at = now
                if tags is not None:
                    # Copy so later mutation of the caller's list does not
                    # silently change the stored tags.
                    entry.tags = list(tags)
                # Move to end (most recent)
                self._store.move_to_end(key)
                return entry

            if self.max_size <= 0:
                # Evicting from an empty store would raise an opaque
                # KeyError; report the real problem instead.
                raise ValueError(
                    f"cannot store {key!r}: max_size must be positive, "
                    f"got {self.max_size}"
                )
            if len(self._store) >= self.max_size:
                # Remove oldest entry to make room for the new key.
                self._store.popitem(last=False)
            entry = MemoryEntry(
                key=key,
                value=value,
                created_at=now,
                updated_at=now,
                tags=tags or [],
            )
            self._store[key] = entry
            return entry

    async def get(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a value from short-term memory.

        A successful lookup increments the entry's ``access_count``.

        Args:
            key: Key to look up.
            default: Default value if key not found.

        Returns:
            The stored value or default.
        """
        async with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return default
            entry.access_count += 1
            return entry.value

    async def get_entry(self, key: str) -> MemoryEntry | None:
        """
        Retrieve a full memory entry with metadata.

        A successful lookup increments the entry's ``access_count``. The
        returned object is the stored entry itself, not a copy.

        Args:
            key: Key to look up.

        Returns:
            The memory entry or None if not found.
        """
        async with self._lock:
            entry = self._store.get(key)
            # Explicit None check: presence, not truthiness, is what matters.
            if entry is not None:
                entry.access_count += 1
            return entry

    async def delete(self, key: str) -> bool:
        """
        Delete an entry from memory.

        Args:
            key: Key to delete.

        Returns:
            True if the key was found and deleted, False otherwise.
        """
        async with self._lock:
            if key in self._store:
                del self._store[key]
                return True
            return False

    async def clear(self) -> int:
        """
        Clear all entries from memory.

        The current episode ID is left unchanged.

        Returns:
            Number of entries that were cleared.
        """
        async with self._lock:
            count = len(self._store)
            self._store.clear()
            return count

    async def list_keys(self, tag: str | None = None) -> list[str]:
        """
        List all keys in memory, optionally filtered by tag.

        Keys are returned in store order (oldest write first).

        Args:
            tag: Optional tag to filter by.

        Returns:
            List of matching keys.
        """
        async with self._lock:
            if tag is None:
                return list(self._store)
            return [k for k, v in self._store.items() if tag in v.tags]

    async def get_by_tag(self, tag: str) -> dict[str, Any]:
        """
        Retrieve all entries with a specific tag.

        Does not change any entry's ``access_count``.

        Args:
            tag: Tag to filter by.

        Returns:
            Dictionary of key-value pairs matching the tag.
        """
        async with self._lock:
            return {
                k: v.value for k, v in self._store.items() if tag in v.tags
            }

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists in memory.

        Args:
            key: Key to check.

        Returns:
            True if key exists, False otherwise.
        """
        async with self._lock:
            return key in self._store

    async def get_stats(self) -> dict[str, Any]:
        """
        Get statistics about the memory store.

        Returns:
            Dictionary with the keys ``size``, ``max_size``, ``episode_id``,
            ``keys`` and ``utilization`` (``size / max_size`` as a float,
            or ``0.0`` when ``max_size`` is not positive).
        """
        async with self._lock:
            size = len(self._store)
            # Guard the division; a non-positive capacity has no meaningful
            # utilization.
            utilization = size / self.max_size if self.max_size > 0 else 0.0
            return {
                "size": size,
                "max_size": self.max_size,
                "episode_id": self._episode_id,
                "keys": list(self._store),
                "utilization": utilization,
            }