| | """ |
| | Base Agent for SPARKNET |
| | Defines the core agent interface and functionality |
| | """ |
| |
|
| | from abc import ABC, abstractmethod |
| | from typing import List, Dict, Optional, Any |
| | from dataclasses import dataclass |
| | from datetime import datetime |
| | from loguru import logger |
| | import json |
| |
|
| | from ..llm.ollama_client import OllamaClient |
| | from ..tools.base_tool import BaseTool, ToolRegistry, ToolResult |
| |
|
| |
|
@dataclass
class Message:
    """Message exchanged between agents (or between a user and an agent).

    Attributes:
        role: Chat role, e.g. "user", "assistant", "system", or "agent"
            (the latter is used for inter-agent traffic).
        content: Message text.
        sender: Name of the sending agent, if any.
        timestamp: Creation time; filled with datetime.now() when omitted.
        metadata: Free-form extra data; normalized to an empty dict when
            omitted (consistent with Task.metadata).
    """
    role: str
    content: str
    sender: Optional[str] = None
    timestamp: Optional[datetime] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Fill defaults lazily so every instance gets its own objects;
        # mutable dataclass-field defaults would be shared across instances.
        if self.timestamp is None:
            self.timestamp = datetime.now()
        if self.metadata is None:
            # Consistency with Task.__post_init__: callers can index
            # message.metadata without a None check.
            self.metadata = {}

    def to_dict(self) -> Dict[str, str]:
        """Convert to dictionary for Ollama API.

        The internal "agent" role is not a valid Ollama chat role, so it
        is mapped to "user"; all other roles pass through unchanged.
        """
        return {
            "role": "user" if self.role == "agent" else self.role,
            "content": self.content,
        }
| |
|
| |
|
@dataclass
class Task:
    """A unit of work handed to an agent for execution.

    Attributes:
        id: Unique task identifier.
        description: Human-readable description of the work.
        priority: Scheduling priority; defaults to 0.
        status: Lifecycle state; starts as "pending".
        result: Output produced by the agent once finished.
        error: Failure message, if the task failed.
        metadata: Free-form extra data; never None after construction.
    """
    id: str
    description: str
    priority: int = 0
    status: str = "pending"
    result: Optional[Any] = None
    error: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Normalize a missing metadata mapping to a fresh, per-instance dict.
        self.metadata = self.metadata if self.metadata is not None else {}
| |
|
| |
|
class BaseAgent(ABC):
    """Base class for all SPARKNET agents.

    Bundles an LLM client, a per-agent toolbox (with an optional shared
    ToolRegistry fallback), and a conversation history. Subclasses must
    implement process_task() to define their actual work.
    """

    def __init__(
        self,
        name: str,
        description: str,
        llm_client: OllamaClient,
        model: str,
        system_prompt: str,
        tools: Optional[List[BaseTool]] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize agent.

        Args:
            name: Agent name
            description: Agent description
            llm_client: Ollama client instance
            model: Model to use
            system_prompt: System prompt for the agent
            tools: List of available tools
            temperature: LLM temperature
            max_tokens: Max tokens to generate
        """
        self.name = name
        self.description = description
        self.llm_client = llm_client
        self.model = model
        self.system_prompt = system_prompt
        # Index tools by name for O(1) lookup in execute_tool().
        self.tools: Dict[str, BaseTool] = {tool.name: tool for tool in (tools or [])}
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Conversation history: inputs received plus this agent's replies.
        self.messages: List[Message] = []

        # Optional shared registry, consulted when a tool is not in self.tools.
        self.tool_registry: Optional[ToolRegistry] = None

        logger.info(f"Initialized agent: {self.name} with model {self.model}")

    def add_tool(self, tool: BaseTool):
        """
        Add a tool to the agent's toolbox.

        Replaces any existing tool registered under the same name.

        Args:
            tool: Tool to add
        """
        self.tools[tool.name] = tool
        logger.info(f"Agent {self.name} added tool: {tool.name}")

    def remove_tool(self, tool_name: str):
        """
        Remove a tool from the agent's toolbox.

        No-op (and no log entry) if the tool is not present.

        Args:
            tool_name: Name of tool to remove
        """
        if tool_name in self.tools:
            del self.tools[tool_name]
            logger.info(f"Agent {self.name} removed tool: {tool_name}")

    def set_tool_registry(self, registry: ToolRegistry):
        """
        Set the tool registry for accessing shared tools.

        Args:
            registry: Tool registry instance
        """
        self.tool_registry = registry

    async def call_llm(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Message]] = None,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Call the LLM with a prompt or messages.

        Args:
            prompt: Single prompt string; takes precedence over messages
            messages: List of messages to send as a chat conversation
            temperature: Override temperature for this call only

        Returns:
            LLM response text

        Raises:
            ValueError: If neither prompt nor messages is provided.
        """
        temp = temperature if temperature is not None else self.temperature

        # Explicit None checks: an explicitly provided empty string or
        # empty list counts as "provided" and must not fall through to
        # the ValueError below (the old truthiness test did exactly that).
        if prompt is not None:
            # Single-shot generation using the agent's system prompt.
            # NOTE(review): generate()/chat() are called without await
            # inside an async method; this assumes OllamaClient exposes
            # synchronous methods — confirm against its implementation.
            response = self.llm_client.generate(
                prompt=prompt,
                model=self.model,
                system=self.system_prompt,
                temperature=temp,
                max_tokens=self.max_tokens,
            )
        elif messages is not None:
            # Chat-style call: system prompt first, then the conversation.
            chat_messages = [
                {"role": "system", "content": self.system_prompt}
            ]
            chat_messages.extend([msg.to_dict() for msg in messages])

            response = self.llm_client.chat(
                messages=chat_messages,
                model=self.model,
                temperature=temp,
            )
        else:
            raise ValueError("Either prompt or messages must be provided")

        logger.debug(f"Agent {self.name} received LLM response: {len(response)} chars")
        return response

    async def execute_tool(self, tool_name: str, **kwargs) -> ToolResult:
        """
        Execute a tool by name.

        Looks in the agent's own toolbox first, then falls back to the
        shared tool registry (if one has been set).

        Args:
            tool_name: Name of tool to execute
            **kwargs: Tool parameters

        Returns:
            ToolResult from tool execution; a failed ToolResult (not an
            exception) when the tool cannot be found.
        """
        tool = self.tools.get(tool_name)

        # Fall back to the shared registry for tools not owned by this agent.
        if tool is None and self.tool_registry:
            tool = self.tool_registry.get_tool(tool_name)

        if tool is None:
            logger.error(f"Tool not found: {tool_name}")
            return ToolResult(
                success=False,
                output=None,
                error=f"Tool not found: {tool_name}",
            )

        logger.info(f"Agent {self.name} executing tool: {tool_name}")
        result = await tool.safe_execute(**kwargs)

        return result

    def add_message(self, message: Message):
        """
        Add a message to the agent's history.

        Args:
            message: Message to add
        """
        self.messages.append(message)

    async def receive_message(self, message: Message) -> Optional[str]:
        """
        Receive and process a message from another agent or user.

        The message is recorded in history before processing.

        Args:
            message: Incoming message

        Returns:
            Response or None
        """
        logger.info(f"Agent {self.name} received message from {message.sender}")
        self.add_message(message)

        return await self.process_message(message)

    async def process_message(self, message: Message) -> Optional[str]:
        """
        Process an incoming message. Can be overridden by subclasses.

        Default behavior: send the full history (not just this message)
        to the LLM and record the reply as an assistant message.

        Args:
            message: Message to process (already appended to history by
                receive_message)

        Returns:
            Response or None
        """
        response = await self.call_llm(messages=self.messages)

        self.add_message(
            Message(
                role="assistant",
                content=response,
                sender=self.name,
            )
        )

        return response

    @abstractmethod
    async def process_task(self, task: Task) -> Task:
        """
        Process a task. Must be implemented by subclasses.

        Args:
            task: Task to process

        Returns:
            Updated task with results
        """
        pass

    async def send_message(self, recipient: "BaseAgent", content: str) -> Optional[str]:
        """
        Send a message to another agent.

        Args:
            recipient: Recipient agent
            content: Message content

        Returns:
            Response from recipient
        """
        message = Message(
            role="agent",
            content=content,
            sender=self.name,
        )

        logger.info(f"Agent {self.name} sending message to {recipient.name}")
        response = await recipient.receive_message(message)

        return response

    def get_available_tools(self) -> List[str]:
        """
        Get list of available tool names.

        Returns:
            Sorted, de-duplicated names from the agent's own toolbox and
            the shared registry (if set).
        """
        tool_names = list(self.tools.keys())

        if self.tool_registry:
            tool_names.extend(self.tool_registry.list_tools())

        # sorted() yields a deterministic order; bare set iteration does not.
        return sorted(set(tool_names))

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Get schemas for all available tools.

        Returns:
            List of tool schemas: own tools first, then registry tools
            (a tool present in both contributes two entries).
        """
        schemas = [tool.get_schema() for tool in self.tools.values()]

        if self.tool_registry:
            schemas.extend(self.tool_registry.get_schemas())

        return schemas

    def clear_history(self):
        """Clear message history."""
        self.messages.clear()
        logger.info(f"Agent {self.name} cleared message history")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get agent statistics.

        Returns:
            Dictionary with agent stats
        """
        return {
            "name": self.name,
            "model": self.model,
            "messages_count": len(self.messages),
            "tools_count": len(self.tools),
        }

    def __repr__(self) -> str:
        return f"<Agent: {self.name} (model={self.model}, tools={len(self.tools)})>"
| |
|