scrapeRL / backend /app /core /action.py
NeerajCodz's picture
feat: add core RL environment models (observation, action, reward, env)
ab65628
"""Action model for the RL environment."""
from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_validator
class ActionType(str, Enum):
    """Closed set of action kinds understood by the environment.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase wire value (for example ``ActionType.CLICK == "click"``).
    """

    # --- Navigation ---
    NAVIGATE = "navigate"
    GO_BACK = "go_back"
    GO_FORWARD = "go_forward"
    REFRESH = "refresh"

    # --- Page interaction ---
    CLICK = "click"
    FILL = "fill"
    SELECT = "select"
    SCROLL = "scroll"
    HOVER = "hover"

    # --- Data extraction ---
    EXTRACT_FIELD = "extract_field"
    EXTRACT_TABLE = "extract_table"
    EXTRACT_LIST = "extract_list"

    # --- Search ---
    SEARCH_PAGE = "search_page"
    SEARCH_ENGINE = "search_engine"

    # --- Verification ---
    VERIFY_FACT = "verify_fact"
    VERIFY_FIELD = "verify_field"

    # --- Memory ---
    STORE_MEMORY = "store_memory"
    RECALL_MEMORY = "recall_memory"

    # --- External tools ---
    MCP_TOOL_CALL = "mcp_tool_call"

    # --- Planning ---
    CREATE_PLAN = "create_plan"
    UPDATE_PLAN = "update_plan"

    # --- Inter-agent communication ---
    SEND_MESSAGE = "send_message"

    # --- Episode control ---
    WAIT = "wait"
    DONE = "done"
    FAIL = "fail"
class NavigateParams(BaseModel):
    """Parameter schema for navigation actions.

    Only ``url`` is required; the remaining fields have defaults.
    """

    # Destination to navigate to (required).
    url: str
    # Optional wait condition. Its accepted values are not defined in this
    # module -- NOTE(review): confirm against the navigation handler.
    wait_for: str | None = None
    # Timeout; presumably milliseconds, going by the ``_ms`` suffix.
    timeout_ms: int = 30_000
class ClickParams(BaseModel):
    """Parameter schema for click actions.

    Only ``selector`` is required; the remaining fields have defaults.
    """

    # Selector of the element to click (required).
    selector: str
    # Button name; any string is accepted, defaulting to "left".
    button: str = "left"
    # Number of clicks to perform; defaults to a single click.
    click_count: int = 1
    # Post-click delay; presumably milliseconds, going by the ``_ms`` suffix.
    wait_after_ms: int = 500
class FillParams(BaseModel):
    """Parameter schema for form fill actions.

    ``selector`` and ``value`` are both required.
    """

    # Selector of the input to fill (required).
    selector: str
    # Text to enter (required).
    value: str
    # Flag defaulting to True; by its name it asks for the field to be
    # cleared before filling -- NOTE(review): confirm in the fill handler.
    clear_first: bool = True
class SelectParams(BaseModel):
    """Parameter schema for select dropdown actions.

    ``selector`` is required. ``value``, ``label`` and ``index`` all default
    to ``None``; this model does not itself require that any of them be set.
    """

    # Selector of the dropdown element (required).
    selector: str
    # Three optional ways to identify the option to pick.
    value: str | None = None
    label: str | None = None
    index: int | None = None
class ScrollParams(BaseModel):
    """Parameter schema for scroll actions.

    Every field has a default, so an empty parameter set is valid.
    """

    # Scroll direction; any string is accepted, defaulting to "down".
    direction: str = "down"
    # Either an integer or a string; the default is the string "page".
    amount: int | str = "page"
    # Optional selector -- presumably the element to scroll, with None
    # meaning the page itself; NOTE(review): confirm in the scroll handler.
    selector: str | None = None
class ExtractFieldParams(BaseModel):
    """Parameter schema for field extraction actions.

    Only ``field_name`` is required; the remaining fields have defaults.
    """

    # Name of the field being extracted (required).
    field_name: str
    # Optional selector locating the value.
    selector: str | None = None
    # Extraction strategy name; any string is accepted, defaulting to "text".
    extraction_method: str = "text"
    # Optional attribute name -- presumably used by attribute-based
    # extraction; NOTE(review): confirm in the extraction handler.
    attribute: str | None = None
    # Optional regular-expression pattern.
    regex_pattern: str | None = None
    # Optional post-processing step name; valid values are not defined here.
    post_process: str | None = None
class ExtractTableParams(BaseModel):
    """Parameter schema for table extraction actions.

    Only ``table_selector`` is required; the remaining fields default to
    ``None``.
    """

    # Selector of the table to extract (required).
    table_selector: str
    # Optional explicit list of header names.
    headers: list[str] | None = None
    # Optional selector for rows.
    row_selector: str | None = None
    # Optional string-to-string mapping; presumably column name -> cell
    # selector -- NOTE(review): confirm in the extraction handler.
    cell_selectors: dict[str, str] | None = None
class ExtractListParams(BaseModel):
    """Parameter schema for list extraction actions.

    All three fields are required; none has a default.
    """

    # Selector of the element containing the list.
    container_selector: str
    # Selector matching each list item.
    item_selector: str
    # String-to-string mapping; presumably field name -> selector applied
    # per item -- NOTE(review): confirm in the extraction handler.
    field_selectors: dict[str, str]
class SearchPageParams(BaseModel):
    """Parameter schema for searching within the current page.

    Only ``query`` is required.
    """

    # What to search for (required).
    query: str
    # Search mode name; any string is accepted, defaulting to "text".
    search_type: str = "text"
class SearchEngineParams(BaseModel):
    """Parameter schema for search engine queries.

    Only ``query`` is required.
    """

    # Query string to submit (required).
    query: str
    # Engine name; any string is accepted, defaulting to "google".
    engine: str = "google"
    # Number of results requested; defaults to 10.
    num_results: int = 10
class VerifyFactParams(BaseModel):
    """Parameter schema for fact verification.

    Only ``claim`` is required.
    """

    # The statement to verify (required).
    claim: str
    # Optional list of source strings.
    sources: list[str] | None = None
    # Defaults to 0.8. No bounds are declared on this field, so values
    # outside [0, 1] pass model validation.
    confidence_threshold: float = 0.8
class VerifyFieldParams(BaseModel):
    """Parameter schema for field verification.

    Only ``field_name`` is required.
    """

    # Name of the field to verify (required).
    field_name: str
    # Optional expected type, given as a string.
    expected_type: str | None = None
    # Optional expected format, given as a string.
    expected_format: str | None = None
    # Rule strings; each instance gets its own fresh empty list by default.
    validation_rules: list[str] = Field(default_factory=list)
class MemoryParams(BaseModel):
    """Parameter schema for memory operations.

    Only ``key`` is required.
    """

    # Memory key (required).
    key: str
    # Payload of any type; defaults to None.
    value: Any | None = None
    # Memory category name; any string is accepted, defaulting to "working".
    memory_type: str = "working"
    # Optional time-to-live; presumably seconds, going by the suffix.
    ttl_seconds: int | None = None
class MCPToolCallParams(BaseModel):
    """Parameter schema for MCP tool calls.

    Only ``tool_name`` is required.
    """

    # Name of the tool to invoke (required).
    tool_name: str
    # Arguments for the tool; each instance gets its own fresh empty dict
    # by default.
    arguments: dict[str, Any] = Field(default_factory=dict)
class PlanParams(BaseModel):
    """Parameter schema for planning actions.

    Both fields are optional and default to ``None``.
    """

    # Optional free-text description of the plan.
    plan_description: str | None = None
    # Optional list of step dicts; the shape of each step is not defined in
    # this module.
    steps: list[dict[str, Any]] | None = None
class MessageParams(BaseModel):
    """Parameter schema for inter-agent messages.

    ``target_agent`` and ``message_type`` are required.
    """

    # Recipient agent, given as a string (required).
    target_agent: str
    # Message kind, given as a string (required); valid values are not
    # defined in this module.
    message_type: str
    # Message payload; each instance gets its own fresh empty dict by default.
    content: dict[str, Any] = Field(default_factory=dict)
class WaitParams(BaseModel):
    """Parameter schema for wait actions.

    Every field has a default, so an empty parameter set is valid.
    """

    # Wait duration; presumably milliseconds, going by the ``_ms`` suffix.
    duration_ms: int = 1_000
    # Optional selector to wait for.
    wait_for_selector: str | None = None
    # Flag defaulting to False; by its name it requests waiting for a
    # navigation -- NOTE(review): confirm in the wait handler.
    wait_for_navigation: bool = False
class DoneParams(BaseModel):
    """Parameter schema for completion.

    Every field has a default, so an empty parameter set is valid.
    """

    # Success flag; defaults to True.
    success: bool = True
    # Optional message string.
    message: str | None = None
    # Optional result payload as a dict.
    final_result: dict[str, Any] | None = None
# Parameter keys each action type must supply with a non-None value.
# Mirrors the fields declared without defaults on the corresponding *Params
# models in this module; keep the two in sync. Action types that are absent
# from this table have no required keys.
_REQUIRED_PARAMS: dict[ActionType, tuple[str, ...]] = {
    ActionType.NAVIGATE: ("url",),
    ActionType.CLICK: ("selector",),
    ActionType.FILL: ("selector", "value"),
    ActionType.SELECT: ("selector",),
    ActionType.EXTRACT_FIELD: ("field_name",),
    ActionType.EXTRACT_TABLE: ("table_selector",),
    ActionType.EXTRACT_LIST: (
        "container_selector",
        "item_selector",
        "field_selectors",
    ),
    ActionType.SEARCH_PAGE: ("query",),
    ActionType.SEARCH_ENGINE: ("query",),
    ActionType.VERIFY_FACT: ("claim",),
    ActionType.VERIFY_FIELD: ("field_name",),
    ActionType.STORE_MEMORY: ("key",),
    ActionType.RECALL_MEMORY: ("key",),
    ActionType.MCP_TOOL_CALL: ("tool_name",),
    ActionType.SEND_MESSAGE: ("target_agent", "message_type"),
}


class Action(BaseModel):
    """
    Represents an action to be taken in the environment.

    An action consists of:
    - action_type: The type of action
    - parameters: Action-specific parameters
    - reasoning: Why this action was chosen
    - confidence: How confident the agent is
    - agent_id: Optional ID of the agent that produced the action
    - plan_step: Optional index of the plan step the action belongs to

    ``parameters`` is a free-form dict; it is not checked against the
    per-action ``*Params`` models when the action is constructed. Call
    ``validate_params`` to check that the required keys are present.
    """

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "action_type": "extract_field",
                "parameters": {
                    "field_name": "price",
                    "selector": ".product-price",
                    "extraction_method": "text",
                },
                "reasoning": "The price element is visible with class .product-price",
                "confidence": 0.92,
            }
        }
    )

    action_type: ActionType = Field(..., description="Type of action to execute")
    parameters: dict[str, Any] = Field(
        default_factory=dict,
        description="Action-specific parameters",
    )
    reasoning: str | None = Field(
        default=None,
        description="Agent's reasoning for this action",
    )
    confidence: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Confidence in this action (0-1)",
    )
    agent_id: str | None = Field(
        default=None,
        description="ID of the agent that produced this action",
    )
    plan_step: int | None = Field(
        default=None,
        description="Which step of the plan this corresponds to",
    )

    @field_validator("confidence")
    @classmethod
    def validate_confidence(cls, v: float) -> float:
        """Clamp confidence into the closed interval [0, 1].

        This is a default ("after") validator, so it runs once the field's
        ``ge=0.0`` / ``le=1.0`` constraints have already been enforced.
        Out-of-range input is therefore rejected with a validation error
        before reaching this method, and the clamp is a defensive no-op.
        """
        return max(0.0, min(1.0, v))

    @classmethod
    def navigate(cls, url: str, **kwargs: Any) -> "Action":
        """Create a navigate action.

        Args:
            url: Stored under the ``"url"`` parameter key.
            **kwargs: Extra entries merged into ``parameters`` unchanged.
        """
        return cls(
            action_type=ActionType.NAVIGATE,
            parameters={"url": url, **kwargs},
        )

    @classmethod
    def click(cls, selector: str, **kwargs: Any) -> "Action":
        """Create a click action.

        Args:
            selector: Stored under the ``"selector"`` parameter key.
            **kwargs: Extra entries merged into ``parameters`` unchanged.
        """
        return cls(
            action_type=ActionType.CLICK,
            parameters={"selector": selector, **kwargs},
        )

    @classmethod
    def extract_field(
        cls,
        field_name: str,
        selector: str | None = None,
        **kwargs: Any,
    ) -> "Action":
        """Create an extract field action.

        The ``"selector"`` key is always present in ``parameters``, holding
        ``None`` when no selector is given.

        Args:
            field_name: Stored under the ``"field_name"`` parameter key.
            selector: Stored under the ``"selector"`` parameter key.
            **kwargs: Extra entries merged into ``parameters`` unchanged.
        """
        return cls(
            action_type=ActionType.EXTRACT_FIELD,
            parameters={"field_name": field_name, "selector": selector, **kwargs},
        )

    @classmethod
    def search_engine(cls, query: str, engine: str = "google", **kwargs: Any) -> "Action":
        """Create a search engine action.

        Args:
            query: Stored under the ``"query"`` parameter key.
            engine: Stored under the ``"engine"`` parameter key.
            **kwargs: Extra entries merged into ``parameters`` unchanged.
        """
        return cls(
            action_type=ActionType.SEARCH_ENGINE,
            parameters={"query": query, "engine": engine, **kwargs},
        )

    @classmethod
    def done(cls, success: bool = True, message: str | None = None) -> "Action":
        """Create a done action.

        The ``"message"`` key is always present in ``parameters``, holding
        ``None`` when no message is given.
        """
        return cls(
            action_type=ActionType.DONE,
            parameters={"success": success, "message": message},
        )

    @classmethod
    def wait(cls, duration_ms: int = 1000) -> "Action":
        """Create a wait action with ``duration_ms`` as its only parameter."""
        return cls(
            action_type=ActionType.WAIT,
            parameters={"duration_ms": duration_ms},
        )

    @classmethod
    def mcp_tool_call(cls, tool_name: str, **arguments: Any) -> "Action":
        """Create an MCP tool call action.

        Keyword arguments are collected into a nested dict under the
        ``"arguments"`` parameter key rather than merged at the top level.
        """
        return cls(
            action_type=ActionType.MCP_TOOL_CALL,
            parameters={"tool_name": tool_name, "arguments": arguments},
        )

    def get_param(self, key: str, default: Any = None) -> Any:
        """Return ``parameters[key]``, or ``default`` when the key is absent."""
        return self.parameters.get(key, default)

    def validate_params(self) -> list[str]:
        """Check that this action carries the parameters its type requires.

        A key counts as missing when it is absent from ``parameters`` or is
        present with the value ``None``. Only presence is checked; value
        types are not validated.

        Returns:
            One ``"Missing required parameter: <name>"`` message per missing
            key, in table order; an empty list when nothing is missing or
            the action type has no required keys.
        """
        required = _REQUIRED_PARAMS.get(self.action_type, ())
        # .get() folds the "absent" and "explicitly None" cases into one test.
        return [
            f"Missing required parameter: {name}"
            for name in required
            if self.parameters.get(name) is None
        ]