| from abc import ABC, abstractmethod |
| from typing import Callable, List, Tuple |
|
|
| from langchain_core.language_models import BaseLanguageModel |
| from langchain_core.language_models.chat_models import BaseChatModel |
| from langchain_core.language_models.llms import BaseLLM |
| from langchain_core.prompts import BasePromptTemplate |
| from pydantic import BaseModel, Field |
|
|
|
|
class BasePromptSelector(BaseModel, ABC):
    """Base class for prompt selectors."""

    # This base only declares the selection contract; concrete subclasses
    # must supply the actual choice of prompt.
    @abstractmethod
    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        """Return the prompt template to use with ``llm``.

        Args:
            llm: Language model the prompt is being chosen for.

        Returns:
            The prompt template selected for ``llm``.
        """
|
|
|
|
class ConditionalPromptSelector(BasePromptSelector):
    """Prompt collection that goes through conditionals."""

    # Fallback returned by ``get_prompt`` when no predicate accepts the model.
    default_prompt: BasePromptTemplate
    """Default prompt to use if no conditionals match."""
    # Ordered ``(predicate, prompt)`` pairs; ``get_prompt`` scans them in
    # list order, so earlier entries take precedence over later ones.
    conditionals: List[
        Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
    ] = Field(default_factory=list)
    """List of conditionals and prompts to use if the conditionals match."""

    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        """Select the prompt for ``llm`` by testing each conditional in order.

        Args:
            llm: Language model to choose a prompt for.

        Returns:
            The prompt paired with the first predicate that evaluates truthy
            for ``llm``; ``default_prompt`` when none of them does.
        """
        for predicate, candidate in self.conditionals:
            if not predicate(llm):
                # This predicate rejected the model; try the next pair.
                continue
            return candidate
        # No predicate matched (or the list is empty): use the fallback.
        return self.default_prompt
|
|
|
|
def is_llm(llm: BaseLanguageModel) -> bool:
    """Report whether ``llm`` is an instance of ``BaseLLM``.

    Args:
        llm: Language model to inspect.

    Returns:
        True when ``llm`` is a ``BaseLLM`` instance, False otherwise.
    """
    llm_detected = isinstance(llm, BaseLLM)
    return llm_detected
|
|
|
|
def is_chat_model(llm: BaseLanguageModel) -> bool:
    """Report whether ``llm`` is an instance of ``BaseChatModel``.

    Args:
        llm: Language model to inspect.

    Returns:
        True when ``llm`` is a ``BaseChatModel`` instance, False otherwise.
    """
    chat_model_detected = isinstance(llm, BaseChatModel)
    return chat_model_detected
|
|