| from enum import Enum |
| from pydantic import Field |
| from typing import Union, Optional, List |
| from ..core.module import BaseModule |
| from ..core.message import Message, MessageType |
| from ..models.base_model import LLMOutputParser |
|
|
|
|
class TrajectoryState(str, Enum):
    """
    Status of one step in an execution trajectory.

    Because the enum also derives from ``str``, each member compares equal to
    its plain string value (e.g. ``TrajectoryState.FAILED == "FAILED"``).
    """

    # The step ran to completion.
    COMPLETED = "COMPLETED"
    # The step ended in failure.
    FAILED = "FAILED"
|
|
|
|
class TrajectoryStep(BaseModule):
    """
    A single recorded entry in an environment's execution trajectory.

    Attributes:
        message: The message recorded at this step; may be None.
        status: Whether the step completed or failed.
        error: Optional error description attached to the step.
    """

    # Annotated Optional because the default is None. A bare ``Message``
    # annotation misstates the field type, and if BaseModule validates its
    # fields (presumably pydantic, given the Field usage elsewhere in this
    # file) it would reject an explicit ``message=None``.
    message: Optional[Message] = None
    status: TrajectoryState
    error: Optional[str] = None
|
|
|
|
def _last_n(items: list, n: Optional[int]) -> list:
    """Return the last ``n`` entries of ``items``; all of them when ``n`` is None."""
    if n is None:
        return items
    # ``items[-0:]`` is the *whole* list, so a non-positive count must be
    # handled explicitly in order to return nothing.
    if n <= 0:
        return []
    return items[-n:]


class Environment(BaseModule):
    """
    Responsible for storing and managing intermediate states of execution.

    Three views of the run are maintained:

    * ``trajectory`` - every message passed to :meth:`update`, wrapped in a
      :class:`TrajectoryStep`, in insertion order.
    * ``task_execution_history`` - ``wf_task`` names of RESPONSE messages,
      with consecutive repeats collapsed.
    * ``execution_data`` - a flat key/value store merged from structured
      message content and from context extraction.
    """

    trajectory: List[TrajectoryStep] = Field(default_factory=list)
    task_execution_history: List[str] = Field(default_factory=list)
    execution_data: dict = Field(default_factory=dict)

    def update(
        self,
        message: Message,
        state: Optional[TrajectoryState] = None,
        error: Optional[str] = None,
        **kwargs,
    ):
        """
        Record a message in the trajectory and refresh the derived views.

        The message is appended to ``trajectory`` as a new step and is then used
        to update ``task_execution_history`` and ``execution_data``.

        Args:
            message (Message): The message to record.
            state (TrajectoryState, optional): Status of the step. Defaults to
                ``TrajectoryState.COMPLETED`` when not given.
            error (str, optional): Error description to attach to the step.
            **kwargs: Accepted for interface compatibility; currently ignored.
        """
        state = state or TrajectoryState.COMPLETED
        step = TrajectoryStep(message=message, status=state, error=error)
        self.trajectory.append(step)
        self.update_task_execution_history(message=message)
        self.update_execution_data(message=message)

    def update_task_execution_history(self, message: Message):
        """
        Append ``message.wf_task`` to the task history for RESPONSE messages.

        Consecutive responses from the same task are recorded only once, so the
        history lists task transitions rather than every response.
        """
        if message.wf_task is None or message.msg_type != MessageType.RESPONSE:
            return
        history = self.task_execution_history
        if not history or message.wf_task != history[-1]:
            history.append(message.wf_task)

    def update_execution_data(self, message: Message):
        """
        Merge structured content from ``message`` into ``execution_data``.

        ``LLMOutputParser`` content contributes the result of its
        ``get_structured_data()``; ``dict`` content is merged directly. Other
        content types are ignored. Existing keys are overwritten.
        """
        content = message.content
        if isinstance(content, LLMOutputParser):
            self.execution_data.update(content.get_structured_data())
        elif isinstance(content, dict):
            self.execution_data.update(content)

    def update_execution_data_from_context_extraction(self, extracted_data: dict):
        """
        Add extracted key/value pairs to ``execution_data`` without overwriting.

        Keys already present keep their existing values; only new keys are added.
        """
        for key, value in extracted_data.items():
            self.execution_data.setdefault(key, value)

    def get_task_messages(
        self,
        tasks: Union[str, List[str]],
        n: Optional[int] = None,
        include_inputs: bool = False,
        **kwargs,
    ) -> List[Message]:
        """
        Retrieve messages related to the specified tasks, in trajectory order.

        Args:
            tasks: A task name or list of task names; a message matches when its
                ``wf_task`` is one of them.
            n: If given, return only the last ``n`` matching messages
                (``n <= 0`` returns an empty list).
            include_inputs: If True, also include INPUT messages regardless of
                their task.
            **kwargs: Accepted for interface compatibility; currently ignored.

        Returns:
            List[Message]: A list of messages related to the tasks.
        """
        if isinstance(tasks, str):
            tasks = [tasks]
        message_list = []
        for step in self.trajectory:
            message = step.message
            belongs_to_task = message.wf_task is not None and message.wf_task in tasks
            is_included_input = include_inputs and message.msg_type == MessageType.INPUT
            # Each step is visited once, so a single combined condition avoids
            # the O(n) equality-based ``in message_list`` check per step.
            if belongs_to_task or is_included_input:
                message_list.append(message)
        return _last_n(message_list, n)

    def get(self, n: Optional[int] = None) -> List[Message]:
        """
        Return the most recent ``n`` messages (all messages when ``n`` is None).

        Raises:
            AssertionError: If ``n`` is negative.
        """
        assert n is None or n >= 0, "n must be None or a non-negative int"
        all_messages = [step.message for step in self.trajectory]
        return _last_n(all_messages, n)

    def get_last_executed_task(self) -> Optional[str]:
        """Return the most recently recorded task name, or None if there is none."""
        if self.task_execution_history:
            return self.task_execution_history[-1]
        return None

    def get_all_execution_data(self) -> dict:
        """Return the live ``execution_data`` dict (not a copy)."""
        return self.execution_data

    def get_execution_data(self, params: Union[str, List[str]]) -> dict:
        """
        Return a dict holding the requested keys from ``execution_data``.

        Args:
            params: A single key or a list of keys to look up.

        Raises:
            KeyError: If any requested key is absent; the message lists the
                available keys.
        """
        if isinstance(params, str):
            params = [params]
        data = {}
        for param in params:
            if param not in self.execution_data:
                raise KeyError(f"Couldn't find execution data with key '{param}'. Available execution data: {list(self.execution_data.keys())}")
            data[param] = self.execution_data[param]
        return data
|
|
|
|