scrapeRL / backend /app /memory /short_term.py
NeerajCodz's picture
fix: replace deprecated datetime.utcnow with timezone-aware
bfe0e24
"""Short-term memory for episode-scoped data storage."""
from __future__ import annotations
import asyncio
from collections import OrderedDict
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar
from pydantic import BaseModel, Field
# Type parameter that makes MemoryEntry subscriptable as a generic model.
T = TypeVar("T")
def _utc_now() -> datetime:
    """Build a timezone-aware datetime for the present moment in UTC."""
    return datetime.now(tz=timezone.utc)
class MemoryEntry(BaseModel, Generic[T]):
    """
    A single memory entry with metadata.

    Attributes:
        key: Key under which the entry is stored.
        value: The stored payload.
        created_at: Timezone-aware UTC time the entry was created.
        updated_at: Timezone-aware UTC time the entry was last written.
        access_count: Number of reads recorded for this entry.
        tags: Free-form labels used for filtering.
    """

    key: str
    value: Any
    created_at: datetime = Field(default_factory=_utc_now)
    # Use the same aware-UTC factory as created_at. datetime.utcnow is
    # deprecated and returns a naive datetime, which cannot be compared
    # with or subtracted from the aware created_at value.
    updated_at: datetime = Field(default_factory=_utc_now)
    access_count: int = 0
    tags: list[str] = Field(default_factory=list)

    # Lets `value` hold objects pydantic has no built-in schema for.
    model_config = {"arbitrary_types_allowed": True}
class ShortTermMemory:
    """
    Episode-scoped memory using ordered, dictionary-based storage.

    This memory layer is designed for transient data that should persist
    only within a single episode. The store is emptied whenever
    ``set_episode`` is called with a different episode ID. When the store
    is full, writing a new key evicts the least recently written entry.

    Attributes:
        max_size: Maximum number of entries allowed.
        _store: Internal storage, ordered from oldest to newest write.
        _episode_id: Current episode identifier.
        _lock: Lock serializing every coroutine that touches ``_store``.
    """

    def __init__(self, max_size: int = 100) -> None:
        """
        Initialize short-term memory.

        Args:
            max_size: Maximum number of entries to store. Defaults to 100.
                A value of 0 or less means no entry is ever retained.
        """
        self.max_size = max_size
        self._store: OrderedDict[str, MemoryEntry] = OrderedDict()
        self._episode_id: str | None = None
        self._lock = asyncio.Lock()

    @property
    def episode_id(self) -> str | None:
        """Get the current episode ID."""
        return self._episode_id

    @property
    def size(self) -> int:
        """Get the current number of entries."""
        return len(self._store)

    async def set_episode(self, episode_id: str) -> None:
        """
        Set the current episode ID, clearing memory if the episode changed.

        Calling this again with the ID that is already current leaves the
        stored entries untouched.

        Args:
            episode_id: Unique identifier for the new episode.
        """
        async with self._lock:
            if self._episode_id != episode_id:
                self._store.clear()
                self._episode_id = episode_id

    async def set(
        self,
        key: str,
        value: Any,
        tags: list[str] | None = None,
    ) -> MemoryEntry:
        """
        Store a value in short-term memory.

        Writing an existing key updates it in place and marks it as the
        most recent entry. Writing a new key while the store is full
        evicts the oldest entries until there is room.

        Args:
            key: Unique key for the entry.
            value: Value to store.
            tags: Optional tags for categorization. On update, ``None``
                keeps the entry's existing tags.

        Returns:
            The created or updated memory entry. When ``max_size`` is 0 or
            less the entry is returned but not retained.
        """
        async with self._lock:
            now = datetime.now(timezone.utc)

            if key in self._store:
                entry = self._store[key]
                entry.value = value
                entry.updated_at = now
                if tags is not None:
                    # Copy so later mutation of the caller's list cannot
                    # silently change the stored tags.
                    entry.tags = list(tags)
                # Move to end (most recent)
                self._store.move_to_end(key)
                return entry

            # Build the entry before evicting anything, so a failed
            # construction does not cost an existing entry.
            entry = MemoryEntry(
                key=key,
                value=value,
                created_at=now,
                updated_at=now,
                tags=list(tags) if tags else [],
            )

            # A non-positive capacity retains nothing; this also avoids
            # calling popitem() on an empty store.
            if self.max_size <= 0:
                return entry

            # Loop rather than a single pop so the store also shrinks
            # correctly if max_size was lowered after entries were added.
            while len(self._store) >= self.max_size:
                self._store.popitem(last=False)

            self._store[key] = entry
            return entry

    async def get(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a value from short-term memory.

        A successful lookup increments the entry's access count.

        Args:
            key: Key to look up.
            default: Default value if key not found.

        Returns:
            The stored value or default.
        """
        async with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return default
            entry.access_count += 1
            return entry.value

    async def get_entry(self, key: str) -> MemoryEntry | None:
        """
        Retrieve a full memory entry with metadata.

        A successful lookup increments the entry's access count.

        Args:
            key: Key to look up.

        Returns:
            The memory entry or None if not found.
        """
        async with self._lock:
            entry = self._store.get(key)
            if entry is not None:
                entry.access_count += 1
            return entry

    async def delete(self, key: str) -> bool:
        """
        Delete an entry from memory.

        Args:
            key: Key to delete.

        Returns:
            True if the key was found and deleted, False otherwise.
        """
        async with self._lock:
            if key in self._store:
                del self._store[key]
                return True
            return False

    async def clear(self) -> int:
        """
        Clear all entries from memory.

        The current episode ID is left unchanged.

        Returns:
            Number of entries that were cleared.
        """
        async with self._lock:
            count = len(self._store)
            self._store.clear()
            return count

    async def list_keys(self, tag: str | None = None) -> list[str]:
        """
        List all keys in memory, optionally filtered by tag.

        Args:
            tag: Optional tag to filter by.

        Returns:
            List of matching keys, ordered from oldest to newest write.
        """
        async with self._lock:
            if tag is None:
                return list(self._store)
            return [k for k, v in self._store.items() if tag in v.tags]

    async def get_by_tag(self, tag: str) -> dict[str, Any]:
        """
        Retrieve all entries with a specific tag.

        Access counts are not incremented by this lookup.

        Args:
            tag: Tag to filter by.

        Returns:
            Dictionary of key-value pairs matching the tag.
        """
        async with self._lock:
            return {
                k: v.value for k, v in self._store.items() if tag in v.tags
            }

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists in memory.

        Args:
            key: Key to check.

        Returns:
            True if key exists, False otherwise.
        """
        async with self._lock:
            return key in self._store

    async def get_stats(self) -> dict[str, Any]:
        """
        Get statistics about the memory store.

        Returns:
            Dictionary with ``size``, ``max_size``, ``episode_id``,
            ``keys`` and ``utilization`` (size divided by max_size, or 0
            when max_size is not positive).
        """
        async with self._lock:
            size = len(self._store)
            return {
                "size": size,
                "max_size": self.max_size,
                "episode_id": self._episode_id,
                "keys": list(self._store.keys()),
                "utilization": size / self.max_size if self.max_size > 0 else 0,
            }