| | """Deprecated as of LangChain v0.3.4 and will be removed in LangChain v1.0.0.""" |
| |
|
| | import logging |
| | from abc import ABC, abstractmethod |
| | from itertools import islice |
| | from typing import Any, Dict, Iterable, List, Optional |
| |
|
| | from langchain_core._api import deprecated |
| | from langchain_core.language_models import BaseLanguageModel |
| | from langchain_core.messages import BaseMessage, get_buffer_string |
| | from langchain_core.prompts import BasePromptTemplate |
| | from pydantic import BaseModel, ConfigDict, Field |
| |
|
| | from langchain.chains.llm import LLMChain |
| | from langchain.memory.chat_memory import BaseChatMemory |
| | from langchain.memory.prompt import ( |
| | ENTITY_EXTRACTION_PROMPT, |
| | ENTITY_SUMMARIZATION_PROMPT, |
| | ) |
| | from langchain.memory.utils import get_prompt_input_key |
| |
|
# Module-wide logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
@deprecated(
    since="0.3.1",
    removal="1.0.0",
    message=(
        "Please see the migration guide at: "
        "https://python.langchain.com/docs/versions/migrating_memory/"
    ),
)
class BaseEntityStore(BaseModel, ABC):
    """Abstract base class for Entity store.

    Subclasses provide a key/value mapping from an entity name to an optional
    text value (the entity's summary).
    """

    @abstractmethod
    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Get entity value from store.

        Implementations are expected to fall back to ``default`` when ``key``
        has no stored value.
        """
        ...

    @abstractmethod
    def set(self, key: str, value: Optional[str]) -> None:
        """Set entity value in store, under ``key``."""
        ...

    @abstractmethod
    def delete(self, key: str) -> None:
        """Delete entity value stored under ``key`` from store."""
        ...

    @abstractmethod
    def exists(self, key: str) -> bool:
        """Check if an entity named ``key`` exists in store."""
        ...

    @abstractmethod
    def clear(self) -> None:
        """Delete all entities from store."""
        ...
| |
|
| |
|
@deprecated(
    since="0.3.1",
    removal="1.0.0",
    message=(
        "Please see the migration guide at: "
        "https://python.langchain.com/docs/versions/migrating_memory/"
    ),
)
class InMemoryEntityStore(BaseEntityStore):
    """In-memory Entity store.

    Entities are kept in a plain ``dict`` on the instance, so they are lost
    when the instance is discarded.
    """

    # Explicit factory so every instance gets its own dict.
    store: Dict[str, Optional[str]] = Field(default_factory=dict)

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Return the value stored under ``key``, or ``default`` if absent."""
        return self.store.get(key, default)

    def set(self, key: str, value: Optional[str]) -> None:
        """Store ``value`` under ``key``, replacing any previous value."""
        self.store[key] = value

    def delete(self, key: str) -> None:
        """Remove ``key`` from the store; unknown keys are ignored.

        This matches the Redis and SQLite stores, whose delete is a no-op
        for keys that are not present.
        """
        self.store.pop(key, None)

    def exists(self, key: str) -> bool:
        """Return whether ``key`` has an entry in the store."""
        return key in self.store

    def clear(self) -> None:
        """Remove every entity from the store."""
        self.store.clear()
| |
|
| |
|
@deprecated(
    since="0.3.1",
    removal="1.0.0",
    message=(
        "Please see the migration guide at: "
        "https://python.langchain.com/docs/versions/migrating_memory/"
    ),
)
class UpstashRedisEntityStore(BaseEntityStore):
    """Upstash Redis backed Entity store.

    Entities are written with a TTL of ``ttl`` seconds (1 day by default).
    Every read resets the TTL to ``recall_ttl`` seconds (3 days by default).
    """

    # Declared as model fields because BaseEntityStore is a pydantic model,
    # which rejects assignment to undeclared attributes in ``__init__``.
    redis_client: Any = None
    session_id: str = "default"
    key_prefix: str = "memory_store"
    ttl: Optional[int] = 60 * 60 * 24
    recall_ttl: Optional[int] = 60 * 60 * 24 * 3

    def __init__(
        self,
        session_id: str = "default",
        url: str = "",
        token: str = "",
        key_prefix: str = "memory_store",
        ttl: Optional[int] = 60 * 60 * 24,
        recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
        *args: Any,
        **kwargs: Any,
    ):
        """Create the store and its Upstash client.

        Args:
            session_id: Namespace segment appended to ``key_prefix``.
            url: Upstash REST URL.
            token: Upstash REST token.
            key_prefix: Leading segment of every Redis key.
            ttl: Expiry (seconds) applied on ``set``.
            recall_ttl: Expiry (seconds) applied on ``get``; falls back to
                ``ttl`` when falsy.

        Raises:
            ImportError: If ``upstash_redis`` is not installed.
        """
        try:
            from upstash_redis import Redis
        except ImportError as e:
            raise ImportError(
                "Could not import upstash_redis python package. "
                "Please install it with `pip install upstash_redis`."
            ) from e

        super().__init__(*args, **kwargs)

        try:
            self.redis_client = Redis(url=url, token=token)
        except Exception:
            # Best-effort: construction still succeeds, but ``redis_client``
            # keeps its previous value and later operations will fail.
            logger.exception("Upstash Redis instance could not be initiated.")

        self.session_id = session_id
        self.key_prefix = key_prefix
        self.ttl = ttl
        self.recall_ttl = recall_ttl or ttl

    @property
    def full_key_prefix(self) -> str:
        """Prefix shared by all keys of this session: ``key_prefix:session_id``."""
        return f"{self.key_prefix}:{self.session_id}"

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Read ``key`` and reset its TTL to ``recall_ttl``.

        Returns ``default`` when the key is missing or empty, and ``""`` when
        ``default`` is also falsy.
        """
        res = (
            self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
            or default
            or ""
        )
        logger.debug(
            "Upstash Redis MEM get '%s:%s': '%s'", self.full_key_prefix, key, res
        )
        return res

    def set(self, key: str, value: Optional[str]) -> None:
        """Write ``value`` with a TTL of ``ttl``; a falsy value deletes the key."""
        if not value:
            return self.delete(key)
        self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
        logger.debug(
            "Redis MEM set '%s:%s': '%s' EX %s",
            self.full_key_prefix,
            key,
            value,
            self.ttl,
        )

    def delete(self, key: str) -> None:
        """Delete ``key``; Redis DEL ignores missing keys."""
        self.redis_client.delete(f"{self.full_key_prefix}:{key}")

    def exists(self, key: str) -> bool:
        """Return whether ``key`` exists for this session."""
        return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1

    def clear(self) -> None:
        """Delete every key of this session, iterating SCAN until cursor 0."""
        pattern = f"{self.full_key_prefix}:*"
        cursor = 0
        while True:
            cursor, keys_to_delete = self.redis_client.scan(cursor, pattern)
            # A SCAN page may be empty; DEL with no keys is an error.
            if keys_to_delete:
                self.redis_client.delete(*keys_to_delete)
            # The cursor must be fed back into SCAN; 0 means the scan is done.
            if int(cursor) == 0:
                break
| |
|
| |
|
@deprecated(
    since="0.3.1",
    removal="1.0.0",
    message=(
        "Please see the migration guide at: "
        "https://python.langchain.com/docs/versions/migrating_memory/"
    ),
)
class RedisEntityStore(BaseEntityStore):
    """Redis-backed Entity store.

    Entities are written with a TTL of ``ttl`` seconds (1 day by default).
    Every read resets the TTL to ``recall_ttl`` seconds (3 days by default).
    """

    # ``None`` default: the client is only assigned after ``super().__init__``
    # runs, and pydantic v2 treats an ``Any`` field without a default as
    # required.
    redis_client: Any = None
    session_id: str = "default"
    key_prefix: str = "memory_store"
    ttl: Optional[int] = 60 * 60 * 24
    recall_ttl: Optional[int] = 60 * 60 * 24 * 3

    def __init__(
        self,
        session_id: str = "default",
        url: str = "redis://localhost:6379/0",
        key_prefix: str = "memory_store",
        ttl: Optional[int] = 60 * 60 * 24,
        recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
        *args: Any,
        **kwargs: Any,
    ):
        """Create the store and its Redis client.

        Args:
            session_id: Namespace segment appended to ``key_prefix``.
            url: Redis connection URL passed to ``get_client``.
            key_prefix: Leading segment of every Redis key.
            ttl: Expiry (seconds) applied on ``set``.
            recall_ttl: Expiry (seconds) applied on ``get``; falls back to
                ``ttl`` when falsy.

        Raises:
            ImportError: If ``redis`` or ``langchain_community`` is missing.
        """
        try:
            import redis
        except ImportError as e:
            raise ImportError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            ) from e

        super().__init__(*args, **kwargs)

        try:
            from langchain_community.utilities.redis import get_client
        except ImportError as e:
            raise ImportError(
                "Could not import langchain_community.utilities.redis.get_client. "
                "Please install it with `pip install langchain-community`."
            ) from e

        try:
            self.redis_client = get_client(redis_url=url, decode_responses=True)
        except redis.exceptions.ConnectionError as error:
            # Best-effort: construction still succeeds; ``redis_client`` keeps
            # its previous value (``None`` unless passed in via kwargs).
            logger.error(error)

        self.session_id = session_id
        self.key_prefix = key_prefix
        self.ttl = ttl
        self.recall_ttl = recall_ttl or ttl

    @property
    def full_key_prefix(self) -> str:
        """Prefix shared by all keys of this session: ``key_prefix:session_id``."""
        return f"{self.key_prefix}:{self.session_id}"

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Read ``key`` and reset its TTL to ``recall_ttl``.

        Returns ``default`` when the key is missing or empty, and ``""`` when
        ``default`` is also falsy.
        """
        res = (
            self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
            or default
            or ""
        )
        logger.debug("REDIS MEM get '%s:%s': '%s'", self.full_key_prefix, key, res)
        return res

    def set(self, key: str, value: Optional[str]) -> None:
        """Write ``value`` with a TTL of ``ttl``; a falsy value deletes the key."""
        if not value:
            return self.delete(key)
        self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
        logger.debug(
            "REDIS MEM set '%s:%s': '%s' EX %s",
            self.full_key_prefix,
            key,
            value,
            self.ttl,
        )

    def delete(self, key: str) -> None:
        """Delete ``key``; Redis DEL ignores missing keys."""
        self.redis_client.delete(f"{self.full_key_prefix}:{key}")

    def exists(self, key: str) -> bool:
        """Return whether ``key`` exists for this session."""
        return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1

    def clear(self) -> None:
        """Delete every key of this session, in DEL batches of 500."""

        def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
            # Yield consecutive lists of at most ``batch_size`` items.
            iterator = iter(iterable)
            while batch := list(islice(iterator, batch_size)):
                yield batch

        for keybatch in batched(
            self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
        ):
            self.redis_client.delete(*keybatch)
| |
|
| |
|
@deprecated(
    since="0.3.1",
    removal="1.0.0",
    message=(
        "Please see the migration guide at: "
        "https://python.langchain.com/docs/versions/migrating_memory/"
    ),
)
class SQLiteEntityStore(BaseEntityStore):
    """SQLite-backed Entity store.

    Each session uses its own table, named ``{table_name}_{session_id}``,
    with a ``key TEXT PRIMARY KEY`` / ``value TEXT`` schema.
    """

    session_id: str = "default"
    table_name: str = "memory_store"
    conn: Any = None

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def __init__(
        self,
        session_id: str = "default",
        db_file: str = "entities.db",
        table_name: str = "memory_store",
        *args: Any,
        **kwargs: Any,
    ):
        """Open ``db_file`` and create this session's table if needed.

        Raises:
            ImportError: If the interpreter was built without ``sqlite3``.
        """
        try:
            import sqlite3
        except ImportError as e:
            raise ImportError(
                "Could not import sqlite3 python package. "
                "Please install it with `pip install sqlite3`."
            ) from e
        super().__init__(*args, **kwargs)

        self.conn = sqlite3.connect(db_file)
        self.session_id = session_id
        self.table_name = table_name
        self._create_table_if_not_exists()

    @property
    def full_table_name(self) -> str:
        """Unquoted per-session table name: ``{table_name}_{session_id}``."""
        return f"{self.table_name}_{self.session_id}"

    def _quoted_table_name(self) -> str:
        """Return ``full_table_name`` as a double-quoted SQLite identifier.

        Table names cannot be bound as ``?`` parameters. Quoting them, with
        embedded double quotes doubled, keeps session ids such as UUIDs
        (which contain ``-``) valid and prevents SQL injection.
        """
        return '"' + self.full_table_name.replace('"', '""') + '"'

    def _create_table_if_not_exists(self) -> None:
        """Create the per-session key/value table if it is missing."""
        create_table_query = f"""
            CREATE TABLE IF NOT EXISTS {self._quoted_table_name()} (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """
        with self.conn:
            self.conn.execute(create_table_query)

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Return the value stored under ``key``, or ``default`` if no row exists."""
        query = f"""
            SELECT value
            FROM {self._quoted_table_name()}
            WHERE key = ?
        """
        cursor = self.conn.execute(query, (key,))
        result = cursor.fetchone()
        if result is not None:
            value = result[0]
            return value
        return default

    def set(self, key: str, value: Optional[str]) -> None:
        """Upsert ``value`` under ``key``; a falsy value deletes the row instead."""
        if not value:
            return self.delete(key)
        query = f"""
            INSERT OR REPLACE INTO {self._quoted_table_name()} (key, value)
            VALUES (?, ?)
        """
        with self.conn:
            self.conn.execute(query, (key, value))

    def delete(self, key: str) -> None:
        """Delete the row for ``key``; a missing key is a no-op."""
        query = f"""
            DELETE FROM {self._quoted_table_name()}
            WHERE key = ?
        """
        with self.conn:
            self.conn.execute(query, (key,))

    def exists(self, key: str) -> bool:
        """Return whether a row for ``key`` exists."""
        query = f"""
            SELECT 1
            FROM {self._quoted_table_name()}
            WHERE key = ?
            LIMIT 1
        """
        cursor = self.conn.execute(query, (key,))
        result = cursor.fetchone()
        return result is not None

    def clear(self) -> None:
        """Delete all rows from this session's table."""
        query = f"""
            DELETE FROM {self._quoted_table_name()}
        """
        with self.conn:
            self.conn.execute(query)
| |
|
| |
|
@deprecated(
    since="0.3.1",
    removal="1.0.0",
    message=(
        "Please see the migration guide at: "
        "https://python.langchain.com/docs/versions/migrating_memory/"
    ),
)
class ConversationEntityMemory(BaseChatMemory):
    """Entity extractor & summarizer memory.

    Extracts named entities from the recent chat history and generates
    summaries for them. The entity store is swappable, so entities can persist
    across conversations. It defaults to an in-memory entity store and can be
    replaced with a Redis, SQLite, or other entity store.
    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    llm: BaseLanguageModel
    entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
    entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT

    # Entity names found by the most recent ``load_memory_variables`` call;
    # ``save_context`` regenerates the summary of each of them.
    entity_cache: List[str] = Field(default_factory=list)

    # Size of the context window: prompts see the last ``k * 2`` messages.
    k: int = 3

    # Key under which the chat history is returned from memory.
    chat_history_key: str = "history"

    # Where entity summaries are read from and written to.
    entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)

    @property
    def buffer(self) -> List[BaseMessage]:
        """Access chat memory messages."""
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return ["entities", self.chat_history_key]

    def _resolve_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Return the key in ``inputs`` that holds the user's input."""
        if self.input_key is None:
            return get_prompt_input_key(inputs, self.memory_variables)
        return self.input_key

    def _recent_messages(self) -> List[BaseMessage]:
        """Return the last ``k * 2`` messages, or ``[]`` when ``k <= 0``."""
        if self.k <= 0:
            # ``self.buffer[-0:]`` would be the whole history, not an empty window.
            return []
        return self.buffer[-self.k * 2 :]

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Render ``messages`` as a transcript using the configured prefixes."""
        return get_buffer_string(
            messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

    @staticmethod
    def _parse_entities(output: str) -> List[str]:
        """Split the extraction model's comma-separated reply into entity names.

        A reply of exactly ``NONE`` (ignoring surrounding whitespace) means no
        entities. Blank items, e.g. from a trailing comma, are dropped.
        """
        if output.strip() == "NONE":
            return []
        names = (item.strip() for item in output.split(","))
        return [name for name in names if name]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return recent chat history and the summaries of mentioned entities.

        The LLM extracts entity names from the recent history plus the new
        input. Those names replace ``entity_cache``, and each name is mapped to
        its current summary from ``entity_store``. An entity with no stored
        summary maps to ``""``.

        Returns:
            A dict with ``chat_history_key`` (the recent messages if
            ``return_messages`` is set, otherwise a transcript string) and
            ``"entities"`` (entity name -> summary).
        """
        chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
        prompt_input_key = self._resolve_prompt_input_key(inputs)

        recent_messages = self._recent_messages()
        buffer_string = self._format_messages(recent_messages)

        output = chain.predict(
            history=buffer_string,
            input=inputs[prompt_input_key],
        )
        entities = self._parse_entities(output)

        entity_summaries = {
            entity: self.entity_store.get(entity, "") for entity in entities
        }

        # Remember which entities to summarize in the next ``save_context``.
        self.entity_cache = entities

        buffer: Any = recent_messages if self.return_messages else buffer_string

        return {
            self.chat_history_key: buffer,
            "entities": entity_summaries,
        }

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the exchange, then refresh the summary of each cached entity.

        The exchange is appended to chat history first, so the summarization
        prompt sees it in ``history``. Each entity in ``entity_cache`` gets a
        new summary built from its existing one, and the result is written
        back to ``entity_store``.
        """
        super().save_context(inputs, outputs)

        prompt_input_key = self._resolve_prompt_input_key(inputs)
        buffer_string = self._format_messages(self._recent_messages())
        input_data = inputs[prompt_input_key]

        chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)

        for entity in self.entity_cache:
            existing_summary = self.entity_store.get(entity, "")
            output = chain.predict(
                summary=existing_summary,
                entity=entity,
                history=buffer_string,
                input=input_data,
            )
            # Entity stores treat a falsy value as a delete.
            self.entity_store.set(entity, output.strip())

    def clear(self) -> None:
        """Clear memory contents."""
        self.chat_memory.clear()
        self.entity_cache.clear()
        self.entity_store.clear()
| |
|