# File size: 4,812 Bytes
# 5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from typing import Any, Dict, Iterable

import torch
from openenv.core.client_types import StepResult

from openenv.core.env_server.interfaces import Message
from openenv.core.env_client import EnvClient

from .models import SQLAction, SQLObservation, SQLState


class SQLEnvClient(EnvClient[SQLAction, SQLObservation, SQLState]):
    """Client for interacting with the SQLEnv environment server.

    Translates between the typed ``SQLAction`` / ``SQLObservation`` /
    ``SQLState`` models and the dict payloads exchanged with the server, and
    maps chat ``Message`` objects onto ``SQLAction`` instances.
    """

    def _step_payload(self, action: SQLAction) -> Dict[str, Any]:
        """Convert a SQLAction into the payload for the step endpoint.

        Args:
            action: The action to send.

        Returns:
            A dict with ``action_type``, ``argument`` and ``metadata`` keys.
        """
        return {
            "action_type": action.action_type,
            "argument": action.argument,
            "metadata": action.metadata,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SQLObservation]:
        """Parse the response from the step endpoint into a StepResult.

        Observation fields are read from ``payload["observation"]`` when that
        is a dict, otherwise from the top level of ``payload`` itself.
        Top-level ``done`` / ``reward`` take precedence over the ones inside
        the observation.

        Missing keys and explicit ``None`` (JSON ``null``) values are treated
        the same. Text fields become ``""``, counters become ``0``, the action
        history becomes ``[]`` and metadata becomes ``{}``. This prevents, for
        example, a null ``error`` from turning into the truthy string
        ``"None"``.

        Args:
            payload: Decoded response body from the step endpoint.

        Returns:
            A ``StepResult`` wrapping the parsed ``SQLObservation``.
        """
        obs_data = payload.get("observation")
        if not isinstance(obs_data, dict):
            obs_data = payload

        done = bool(payload.get("done", obs_data.get("done", False)))
        reward = payload.get("reward", obs_data.get("reward"))

        def _text(key: str) -> str:
            # Null-safe str(): None must not become the literal "None".
            value = obs_data.get(key)
            return "" if value is None else str(value)

        def _count(key: str) -> int:
            # Null-safe int(): int(None) would raise TypeError.
            value = obs_data.get(key)
            return 0 if value is None else int(value)

        observation = SQLObservation(
            question=_text("question"),
            schema_info=_text("schema_info"),
            result=_text("result"),
            error=_text("error"),
            step_count=_count("step_count"),
            budget_remaining=_count("budget_remaining"),
            action_history=list(obs_data.get("action_history") or []),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> SQLState:
        """Parse the state endpoint response into a SQLState.

        ``history_tokens`` arrives as a list of plain lists and is converted
        back into one tensor per entry. Empty entries become empty
        ``torch.long`` tensors. A bare ``torch.tensor([])`` would be float32,
        while a non-empty list of Python ints is inferred as int64 (this
        assumes token ids are integers).

        Args:
            payload: Decoded response body from the state endpoint.

        Returns:
            The reconstructed ``SQLState``.
        """
        # A missing key and an explicit null are treated alike.
        history_messages = payload.get("history_messages") or []

        history_tokens = [
            torch.tensor(tokens) if tokens else torch.tensor([], dtype=torch.long)
            for tokens in payload.get("history_tokens") or []
        ]

        return SQLState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            history_messages=history_messages,
            history_tokens=history_tokens,
            # NOTE(review): this default is lowercase "query", while
            # message_to_action produces uppercase action types ("QUERY").
            # Confirm which casing SQLState expects.
            current_action_type=payload.get("current_action_type", "query"),
        )

    def _detect_action_type(self, message_content: str) -> str:
        """Guess the action type from free-form message content.

        Matching is case-insensitive and substring-based. The checks run in
        this order:

        1. A leading ``"answer "`` gives ``"ANSWER"``.
        2. Schema-related keywords give ``"DESCRIBE"``.
        3. Data-related keywords give ``"SAMPLE"``.
        4. Anything else gives ``"QUERY"``.

        Args:
            message_content: The raw message text.

        Returns:
            One of ``"ANSWER"``, ``"DESCRIBE"``, ``"SAMPLE"`` or ``"QUERY"``.
        """
        content_lower = message_content.lower()

        if content_lower.startswith("answer "):
            return "ANSWER"

        describe_keywords = (
            "describe",
            "schema",
            "columns",
            "structure",
            "what columns",
            "show columns",
        )
        if any(keyword in content_lower for keyword in describe_keywords):
            return "DESCRIBE"

        # NOTE(review): "data" and "how many" are broad substrings. For
        # example, a counting question is routed to SAMPLE rather than QUERY.
        # Confirm this is intended.
        sample_keywords = (
            "sample",
            "example",
            "rows",
            "data",
            "show me",
            "few rows",
            "how many",
        )
        if any(keyword in content_lower for keyword in sample_keywords):
            return "SAMPLE"

        return "QUERY"

    def message_to_action(
        self,
        message: Message,
        tokenizer: Any,
        history_messages: Iterable[Message] | None = None,
    ) -> SQLAction:
        """Convert a chat Message into a SQLAction.

        For non-empty ``user`` messages, the first whitespace-delimited word
        is checked against the explicit verbs ``DESCRIBE``, ``SAMPLE``,
        ``QUERY`` and ``ANSWER`` (case-insensitive).

        - If it matches, that verb is the action type and the rest of the
          stripped message is the argument. The rest is empty when there is
          none. Any whitespace, including a newline, may follow the verb.
        - Otherwise the type comes from ``_detect_action_type`` and the whole
          stripped message is the argument.

        Messages from other roles, and blank user messages, become a
        ``QUERY`` whose argument is the unmodified content.

        Args:
            message: Mapping with ``role`` and ``content`` keys.
            tokenizer: Unused; accepted for interface compatibility.
            history_messages: Unused; accepted for interface compatibility.

        Returns:
            The resulting ``SQLAction``.

        Raises:
            ValueError: If ``role`` or ``content`` is missing, or ``content``
                is None.
        """
        if "role" not in message:
            raise ValueError("Message must contain a 'role' key")
        if "content" not in message:
            raise ValueError("Message must contain a 'content' key")
        if message["content"] is None:
            raise ValueError("Message content cannot be None")

        # Required by the interface but not needed to build SQL actions.
        _ = tokenizer
        _ = history_messages

        content = str(message["content"])
        parsed = content.strip()

        action_type = "QUERY"
        argument = content
        if message["role"].lower() == "user" and parsed:
            # split() without a separator breaks on any whitespace run. This
            # handles "QUERY\nSELECT ..." and "DESCRIBE  users", which
            # partition(" ") did not. `parsed` is non-empty and stripped, so
            # there is always at least one part.
            prefix, *rest = parsed.split(maxsplit=1)
            normalized_prefix = prefix.upper()
            if normalized_prefix in {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}:
                action_type = normalized_prefix
                argument = rest[0] if rest else ""
            else:
                action_type = self._detect_action_type(parsed)
                argument = parsed

        return SQLAction(
            action_type=action_type,
            argument=argument,
        )