| import json |
| import logging |
| import uuid |
| from collections.abc import Mapping, Sequence |
| from datetime import datetime, timezone |
| from typing import Optional, Union, cast |
|
|
| from core.agent.entities import AgentEntity, AgentToolEntity |
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager |
| from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig |
| from core.app.apps.base_app_queue_manager import AppQueueManager |
| from core.app.apps.base_app_runner import AppRunner |
| from core.app.entities.app_invoke_entities import ( |
| AgentChatAppGenerateEntity, |
| ModelConfigWithCredentialsEntity, |
| ) |
| from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler |
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler |
| from core.file import file_manager |
| from core.memory.token_buffer_memory import TokenBufferMemory |
| from core.model_manager import ModelInstance |
| from core.model_runtime.entities import ( |
| AssistantPromptMessage, |
| LLMUsage, |
| PromptMessage, |
| PromptMessageContent, |
| PromptMessageTool, |
| SystemPromptMessage, |
| TextPromptMessageContent, |
| ToolPromptMessage, |
| UserPromptMessage, |
| ) |
| from core.model_runtime.entities.model_entities import ModelFeature |
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel |
| from core.model_runtime.utils.encoders import jsonable_encoder |
| from core.prompt.utils.extract_thread_messages import extract_thread_messages |
| from core.tools.entities.tool_entities import ( |
| ToolParameter, |
| ToolRuntimeVariablePool, |
| ) |
| from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool |
| from core.tools.tool.tool import Tool |
| from core.tools.tool_manager import ToolManager |
| from extensions.ext_database import db |
| from factories import file_factory |
| from models.model import Conversation, Message, MessageAgentThought, MessageFile |
| from models.tools import ToolConversationVariables |
|
|
# Module-level logger, named after this module per logging best practice.
logger = logging.getLogger(__name__)
|
|
|
|
class BaseAgentRunner(AppRunner):
    """Base class for agent app runners.

    Prepares everything concrete agent strategies need: tool instances and
    their prompt-message schemas, dataset-retrieval tools, prior agent
    history rebuilt as prompt messages, agent-thought persistence, and
    model feature flags (streaming tool calls, vision).
    """

    def __init__(
        self,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
        variables_pool: Optional[ToolRuntimeVariablePool] = None,
        db_variables: Optional[ToolConversationVariables] = None,
        # FIX: annotation was `ModelInstance = None`; a None default requires Optional.
        model_instance: Optional[ModelInstance] = None,
    ) -> None:
        """Initialize shared agent-runner state.

        :param tenant_id: owning tenant id
        :param application_generate_entity: generate entity of this invocation
        :param conversation: current conversation record
        :param app_config: agent-chat app configuration
        :param model_config: model config with resolved credentials
        :param config: agent configuration entity
        :param queue_manager: queue manager for streaming events
        :param message: current message record
        :param user_id: id of the invoking end user / account
        :param memory: optional token-buffer memory for history
        :param prompt_messages: pre-built prompt messages (system prompts kept)
        :param variables_pool: tool runtime variable pool
        :param db_variables: persisted tool conversation variables
        :param model_instance: model instance used for feature detection;
            must not be None in practice (dereferenced below) — TODO confirm
            all call sites always pass it, then drop the default.
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        # Rebuild prior turns (including tool calls) into prompt messages.
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # Callback handler that surfaces tool invocations to observers.
        self.agent_callback = DifyAgentCallbackHandler()
        # Callback used by dataset retrieval to report citation hits.
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )
        # Number of thoughts already persisted for this message; used to
        # position newly created thoughts.
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # Detect whether the model supports streaming tool calls.
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        self.stream_tool_call = bool(
            model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or [])
        )

        # Only pass files through when the model supports vision.
        if model_schema and ModelFeature.VISION in (model_schema.features or []):
            self.files = application_generate_entity.files
        else:
            self.files = []
        self.query: Optional[str] = None
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """Normalize the generate entity in place.

        Ensures the simple prompt template is an empty string rather than
        None so downstream formatting never has to branch on None.
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """Convert an agent tool entity to a prompt message tool schema.

        :param tool: the configured agent tool
        :return: (JSON-schema style prompt tool, resolved runtime tool)
        :raises: whatever ToolManager raises when the tool cannot be resolved
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            # Only LLM-form parameters are exposed to the model.
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            # File-typed parameters cannot be expressed in the schema.
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                # FIX: guard against duplicates, consistent with
                # update_prompt_message_tool below.
                if parameter.name not in message_tool.parameters["required"]:
                    message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """Convert a dataset retriever tool to a prompt message tool schema.

        All dataset-retriever parameters are exposed as strings.
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = "string"

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
        """Initialize all tools available to the agent.

        :return: (tool name -> runtime tool instance, prompt tool schemas)
        """
        tool_instances: dict[str, Tool] = {}
        prompt_messages_tools: list[PromptMessageTool] = []

        for tool in self.app_config.agent.tools if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # FIX: was a silent swallow; keep the best-effort skip but
                # record why the tool was dropped.
                logger.exception("Failed to initialize agent tool %s, skipping", tool.tool_name)
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """Refresh a prompt tool schema from the tool's current runtime parameters.

        Runtime parameters may change between iterations, so the schema is
        rebuilt parameter-by-parameter before each model call.
        """
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """Create and persist a new agent thought record.

        :param message_id: id of the message the thought belongs to
        :param message: prompt message content for this thought
        :param tool_name: name(s) of the tool(s) to be invoked
        :param tool_input: serialized tool input
        :param messages_ids: ids of message files produced so far
        :return: the freshly persisted (and refreshed) thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: str,
        tool_input: Union[str, dict],
        thought: str,
        observation: Union[str, dict],
        tool_invoke_meta: Union[str, dict],
        answer: str,
        messages_ids: list[str],
        # FIX: annotation was `LLMUsage = None`; a None default requires Optional.
        llm_usage: Optional[LLMUsage] = None,
    ) -> None:
        """Update a persisted agent thought with the results of one step.

        Each field is only written when a non-None value is provided, so
        partial updates are safe. Dict payloads are serialized to JSON.

        :raises ValueError: if the thought no longer exists in the database
        """
        # Re-fetch to attach the row to the current session.
        agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        # FIX: previously a missing row crashed later with AttributeError;
        # fail fast with a clear error instead.
        if agent_thought is None:
            raise ValueError("agent thought not found")

        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    # Fall back to ASCII-escaped output for non-UTF-8-safe data.
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation is not None:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            agent_thought.observation = observation

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(";") if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    # Fall back to the raw tool name for both locales.
                    labels[tool] = {"en_US": tool, "zh_Hans": tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """Persist the tool variable pool for the current conversation.

        Note: the ``db_variables`` argument is ignored; the authoritative
        record is always re-fetched by conversation id (the parameter is
        kept for interface compatibility).
        """
        # FIX: previously the parameter was shadowed by this query result;
        # use a distinct local name for clarity.
        db_record = (
            db.session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == self.message.conversation_id,
            )
            .first()
        )
        # FIX: a missing record crashed with AttributeError; skip gracefully.
        if db_record is None:
            logger.warning(
                "tool conversation variables not found for conversation %s", self.message.conversation_id
            )
            db.session.close()
            return

        db_record.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        db_record.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """Rebuild prior conversation turns as prompt messages.

        Keeps any system messages from ``prompt_messages``, then replays
        every earlier message in the thread as user prompts, assistant
        tool calls (reconstructed from persisted agent thoughts) and tool
        responses.
        """
        result: list[PromptMessage] = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )

        # Restore chronological order restricted to the current thread.
        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            # Legacy rows may store non-JSON input; default to empty args.
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            # Non-JSON observation: attribute it to every tool.
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        """Build the user prompt for a historic message, attaching its files.

        Falls back to a plain text prompt when the message has no files or
        file upload is not configured.
        """
        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
        if files:
            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())

            if file_extra_config:
                file_objs = file_factory.build_from_message_files(
                    message_files=files, tenant_id=self.tenant_id, config=file_extra_config
                )
            else:
                file_objs = []

            if not file_objs:
                return UserPromptMessage(content=message.query)
            else:
                prompt_message_contents: list[PromptMessageContent] = []
                prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                for file_obj in file_objs:
                    prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))

                return UserPromptMessage(content=prompt_message_contents)
        else:
            return UserPromptMessage(content=message.query)
|
|