| | from enum import Enum |
| | from pydantic import Field, model_validator |
| | from datetime import datetime |
| | from typing import Optional, Callable, Any, List, Union |
| |
|
| | from .module import BaseModule |
| | from .module_utils import generate_id, get_timestamp |
| |
|
class MessageType(Enum):
    """
    Closed set of message kinds that a ``Message`` can carry.

    Every member's value is its lowercase name; that plain string is what
    ``Message.to_dict`` writes out for ``msg_type`` and what
    ``Message.validate_data`` converts back into a member.
    """

    REQUEST = "request"
    RESPONSE = "response"
    COMMAND = "command"
    ERROR = "error"
    # Default for ``Message.msg_type``; ``Message.to_str`` omits the type line for it.
    UNKNOWN = "unknown"
    INPUT = "input"
| |
|
| |
|
class Message(BaseModule):
    """
    The base class for messages.

    Attributes:
        content (Any): the content of the message; it should be convertible to a
            string with ``str()``, which is how ``to_str`` renders it.
        agent (str): the sender of the message, normally set to the agent name.
        action (str): the trigger of the message, normally set to the action name.
        prompt (Union[str, List[dict]]): the prompt used to obtain the generated text.
        next_actions (List[str]): the following actions.
        msg_type (MessageType): the type of the message, such as ``MessageType.REQUEST``,
            ``MessageType.RESPONSE`` or ``MessageType.COMMAND``. Defaults to ``MessageType.UNKNOWN``;
            a plain string value such as ``"request"`` is also accepted on input.
        wf_goal (str): the goal of the whole workflow.
        wf_task (str): the name of a task in the workflow, i.e., the ``name`` of a WorkFlowNode instance.
        wf_task_desc (str): the description of a task in the workflow, i.e., the ``description`` of a WorkFlowNode instance.
        message_id (str): the unique identifier of the message; equality and hashing are based on it.
        timestamp (str): the timestamp of the message.
        conversation_id (str): the identifier of the conversation; a fresh id is generated by default.
    """

    content: Any
    agent: Optional[str] = None
    action: Optional[str] = None
    prompt: Optional[Union[str, List[dict]]] = None
    next_actions: Optional[List[str]] = None
    msg_type: Optional[MessageType] = MessageType.UNKNOWN
    wf_goal: Optional[str] = None
    wf_task: Optional[str] = None
    wf_task_desc: Optional[str] = None
    message_id: Optional[str] = Field(default_factory=generate_id)
    timestamp: Optional[str] = Field(default_factory=get_timestamp)
    conversation_id: Optional[str] = Field(default_factory=generate_id)

    def __str__(self) -> str:
        return self.to_str()

    def __eq__(self, other: object) -> bool:
        """Two messages are equal when they share the same ``message_id``."""
        if not isinstance(other, Message):
            # Defer to Python's default handling instead of raising
            # AttributeError on objects that have no ``message_id``.
            return NotImplemented
        return self.message_id == other.message_id

    def __hash__(self) -> int:
        """Hash on ``message_id`` so hashing is consistent with ``__eq__``."""
        # ``__hash__`` must return an int; returning the raw string id makes
        # ``hash(msg)`` (and any set/dict usage) raise TypeError.
        return hash(self.message_id)

    def to_str(self) -> str:
        """
        Render the message as human-readable text, one ``Label: value`` part per line.

        Parts whose source field is empty are omitted, and the type line is
        skipped when ``msg_type`` is ``MessageType.UNKNOWN``.

        Returns:
            str: the newline-joined parts (an empty string if every part is omitted).
        """
        msg_part = []
        if self.timestamp:
            msg_part.append(f"[{self.timestamp}]")
        if self.agent:
            msg_part.append(f"Agent: {self.agent}")
        if self.msg_type and self.msg_type != MessageType.UNKNOWN:
            msg_part.append(f"Type: {self.msg_type}")
        if self.action:
            msg_part.append(f"Action: {self.action}")
        if self.wf_goal:
            msg_part.append(f"Goal: {self.wf_goal}")
        if self.wf_task:
            msg_part.append(f"Task: {self.wf_task} ({self.wf_task_desc or 'No description'})")
        if self.content:
            msg_part.append(f"Content: {str(self.content)}")
        return "\n".join(msg_part)

    def to_dict(self, exclude_none: bool = True, ignore: Optional[List[str]] = None, **kwargs) -> dict:
        """
        Convert the Message to a dictionary for saving.

        Args:
            exclude_none (bool): forwarded to ``BaseModule.to_dict``.
            ignore (Optional[List[str]]): forwarded to ``BaseModule.to_dict``; ``None`` is treated as an empty list.
            **kwargs: further keyword arguments forwarded to ``BaseModule.to_dict``.

        Returns:
            dict: the serialized message, with ``msg_type`` stored as its plain string value.
        """
        # Use a fresh list per call rather than a mutable default shared across calls.
        ignore = [] if ignore is None else ignore
        data = super().to_dict(exclude_none=exclude_none, ignore=ignore, **kwargs)
        # Only enum members have ``.value``; guard against a non-enum value
        # slipping in (e.g. via unvalidated attribute assignment).
        if isinstance(self.msg_type, MessageType):
            data["msg_type"] = self.msg_type.value
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_data(cls, data: Any) -> Any:
        """
        Convert a non-empty string ``msg_type`` (e.g. ``"request"``) into the
        matching ``MessageType`` member before field validation.

        Raises:
            ValueError: if the string is not a valid ``MessageType`` value.
        """
        # In "before" mode ``data`` is the raw input and is not guaranteed to be
        # a dict, so only touch mappings.
        if isinstance(data, dict):
            raw_type = data.get("msg_type")
            if isinstance(raw_type, str) and raw_type:
                # Build a new dict so the caller's input is not mutated by validation.
                data = {**data, "msg_type": MessageType(raw_type)}
        return data

    @classmethod
    def sort_by_timestamp(cls, messages: List['Message'], reverse: bool = False) -> List['Message']:
        """
        Sort the messages in place by timestamp.

        Args:
            messages (List[Message]): the messages to be sorted; the list itself is reordered.
            reverse (bool): If True, sort the messages in descending order. Otherwise, sort the messages in ascending order.

        Returns:
            List[Message]: the same (now sorted) list object.

        Raises:
            ValueError / TypeError: if a timestamp does not parse with the format below (or is None).
        """
        # NOTE(review): assumes get_timestamp() produces "%Y-%m-%d %H:%M:%S" strings — confirm.
        messages.sort(key=lambda msg: datetime.strptime(msg.timestamp, "%Y-%m-%d %H:%M:%S"), reverse=reverse)
        return messages

    @classmethod
    def sort(cls, messages: List['Message'], key: Optional[Callable[['Message'], Any]] = None, reverse: bool = False) -> List['Message']:
        """
        Sort the messages in place using ``key``, or by timestamp when ``key`` is None.

        Args:
            messages (List[Message]): the messages to be sorted; the list itself is reordered.
            key (Optional[Callable[['Message'], Any]]): the function used to sort messages.
            reverse (bool): If True, sort the messages in descending order. Otherwise, sort the messages in ascending order.

        Returns:
            List[Message]: the same (now sorted) list object.
        """
        if key is None:
            return cls.sort_by_timestamp(messages, reverse=reverse)
        messages.sort(key=key, reverse=reverse)
        return messages

    @classmethod
    def merge(cls, messages: List[List['Message']], sort: bool = False, key: Optional[Callable[['Message'], Any]] = None, reverse: bool = False) -> List['Message']:
        """
        Merge several message lists into one new list.

        Args:
            messages (List[List[Message]]): the message lists to be merged; they are not modified.
            sort (bool): whether to sort the merged messages.
            key (Optional[Callable[['Message'], Any]]): the function used to sort messages (timestamp when None).
            reverse (bool): If True, sort the messages in descending order. Otherwise, sort the messages in ascending order.

        Returns:
            List[Message]: a new list holding all messages, in input order unless ``sort`` is True.
        """
        # Flatten in linear time; ``sum(lists, [])`` re-copies the accumulator on
        # every addition and is quadratic in the total number of messages.
        merged_messages = [msg for msg_list in messages for msg in msg_list]
        if sort:
            merged_messages = cls.sort(merged_messages, key=key, reverse=reverse)
        return merged_messages
| | |
| |
|
| |
|