File size: 4,812 Bytes
5dd1bb4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | from typing import Any, Dict, Iterable
import torch
from openenv.core.client_types import StepResult
from openenv.core.env_server.interfaces import Message
from openenv.core.env_client import EnvClient
from .models import SQLAction, SQLObservation, SQLState
class SQLEnvClient(EnvClient[SQLAction, SQLObservation, SQLState]):
    """Client for interacting with the SQLEnv environment server.

    Translates between the project's typed objects (``SQLAction``,
    ``SQLObservation``, ``SQLState``) and the plain-dict payloads exchanged
    with the server, and converts chat-style messages into ``SQLAction``
    instances.
    """

    @staticmethod
    def _coerce_text(value: Any) -> str:
        """Return ``str(value)``, mapping ``None`` (JSON ``null``) to ``""``."""
        return "" if value is None else str(value)

    @staticmethod
    def _coerce_int(value: Any) -> int:
        """Return ``int(value)``, mapping ``None`` (JSON ``null``) to ``0``."""
        return 0 if value is None else int(value)

    def _step_payload(self, action: SQLAction) -> Dict[str, Any]:
        """Convert a SQLAction into the payload for the step endpoint.

        Args:
            action: The action whose ``action_type``, ``argument`` and
                ``metadata`` attributes are copied into the payload.

        Returns:
            A dict with the keys ``"action_type"``, ``"argument"`` and
            ``"metadata"``.
        """
        return {
            "action_type": action.action_type,
            "argument": action.argument,
            "metadata": action.metadata,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SQLObservation]:
        """Parse the response from the step endpoint into a StepResult.

        The observation fields are read from ``payload["observation"]`` when
        that value is a dict; otherwise ``payload`` itself is treated as the
        observation. ``done`` and ``reward`` are taken from the top level of
        ``payload`` when present, falling back to the observation dict.

        Fields that are missing or explicitly ``None`` fall back to an empty
        value (``""``, ``0``, ``[]`` or ``{}``) so that a JSON ``null`` never
        becomes the literal string ``"None"`` or raises ``TypeError``.

        Args:
            payload: Decoded response body from the step endpoint.

        Returns:
            A ``StepResult`` wrapping the reconstructed ``SQLObservation``.
        """
        obs_data = payload.get("observation")
        if not isinstance(obs_data, dict):
            # Flat payload: observation fields live at the top level.
            obs_data = payload

        done = bool(payload.get("done", obs_data.get("done", False)))
        reward = payload.get("reward", obs_data.get("reward"))

        metadata = obs_data.get("metadata")
        if metadata is None:
            metadata = {}

        observation = SQLObservation(
            question=self._coerce_text(obs_data.get("question")),
            schema_info=self._coerce_text(obs_data.get("schema_info")),
            result=self._coerce_text(obs_data.get("result")),
            error=self._coerce_text(obs_data.get("error")),
            step_count=self._coerce_int(obs_data.get("step_count")),
            budget_remaining=self._coerce_int(obs_data.get("budget_remaining")),
            action_history=list(obs_data.get("action_history") or []),
            done=done,
            reward=reward,
            metadata=metadata,
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> SQLState:
        """Parse a state payload into a SQLState.

        Each entry of ``payload["history_tokens"]`` is converted back into a
        ``torch.Tensor``; the remaining fields are copied from the payload
        with defaults for missing keys.

        Args:
            payload: Decoded state dict from the server.

        Returns:
            The reconstructed ``SQLState``.
        """
        # A missing or null history is treated as empty.
        history_messages = payload.get("history_messages") or []
        history_tokens_data = payload.get("history_tokens") or []

        # Convert the serialized token lists back to tensors.
        # NOTE(review): an empty entry becomes torch.tensor([]), which uses
        # the default floating dtype, whereas a non-empty list of ints is
        # inferred as an integer tensor. Confirm whether consumers need a
        # uniform dtype before changing this.
        history_tokens = [
            torch.tensor(token_list) if token_list else torch.tensor([])
            for token_list in history_tokens_data
        ]

        return SQLState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            history_messages=history_messages,
            history_tokens=history_tokens,
            # NOTE(review): this default is lower-case while the action types
            # produced elsewhere in this class are upper-case ("QUERY", ...).
            # Left as-is; confirm against the SQLState model.
            current_action_type=payload.get("current_action_type", "query"),
        )

    def _detect_action_type(self, message_content: str) -> str:
        """Detect the action type from user message content.

        The check is a case-insensitive keyword heuristic, evaluated in
        order: an ``"answer "`` prefix, then schema-related keywords, then
        data-sampling keywords. Keywords are matched as substrings.

        Args:
            message_content: The user's message text.

        Returns:
            One of ``"ANSWER"``, ``"DESCRIBE"``, ``"SAMPLE"`` or ``"QUERY"``
            (the fallback when nothing matches).
        """
        content_lower = message_content.lower()

        if content_lower.startswith("answer "):
            return "ANSWER"

        # The multi-word entries are already covered by the shorter ones
        # ("columns", "rows"); they are kept to document the intended phrases.
        describe_keywords = (
            "describe",
            "schema",
            "columns",
            "structure",
            "what columns",
            "show columns",
        )
        if any(keyword in content_lower for keyword in describe_keywords):
            return "DESCRIBE"

        sample_keywords = (
            "sample",
            "example",
            "rows",
            "data",
            "show me",
            "few rows",
            "how many",
        )
        if any(keyword in content_lower for keyword in sample_keywords):
            return "SAMPLE"

        return "QUERY"

    def message_to_action(
        self,
        message: Message,
        tokenizer: Any,
        history_messages: Iterable[Message] | None = None,
    ) -> SQLAction:
        """Convert a user Message into a SQLAction.

        For a non-blank message whose role is ``"user"`` (case-insensitive):

        * If the first whitespace-delimited word is ``DESCRIBE``, ``SAMPLE``,
          ``QUERY`` or ``ANSWER`` (any case), that word becomes the action
          type and the rest of the message (without the separating
          whitespace) becomes the argument. Any whitespace, including tabs
          and newlines, separates the prefix from the argument.
        * Otherwise the action type is inferred by ``_detect_action_type``
          and the stripped message is the argument.

        Any other message (different role, or blank content) yields a
        ``"QUERY"`` action whose argument is the unmodified content.

        Args:
            message: Mapping with ``"role"`` and ``"content"`` keys.
            tokenizer: Currently unused; kept in the signature.
            history_messages: Currently unused; kept in the signature.

        Returns:
            The ``SQLAction`` derived from the message.

        Raises:
            ValueError: If ``"role"`` or ``"content"`` is missing, or if the
                content is ``None``.
        """
        if "role" not in message:
            raise ValueError("Message must contain a 'role' key")
        if "content" not in message:
            raise ValueError("Message must contain a 'content' key")
        if message["content"] is None:
            raise ValueError("Message content cannot be None")

        # Neither argument influences the conversion.
        del tokenizer, history_messages

        content = str(message["content"])
        parsed = content.strip()

        # Defaults apply to non-user roles and to blank user messages.
        action_type = "QUERY"
        argument = content

        # str() guards against a non-string role raising AttributeError.
        if str(message["role"]).lower() == "user" and parsed:
            # split(None, 1) breaks on any run of whitespace, so
            # "QUERY\nSELECT ..." and "DESCRIBE\tusers" are recognised and
            # the argument never carries leading whitespace.
            prefix, *rest = parsed.split(None, 1)
            normalized_prefix = prefix.upper()
            if normalized_prefix in {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}:
                action_type = normalized_prefix
                argument = rest[0] if rest else ""
            else:
                action_type = self._detect_action_type(parsed)
                argument = parsed

        return SQLAction(
            action_type=action_type,
            argument=argument,
        )
|