# sql_env/training/rollout.py
# (Hugging Face Hub page residue preserved as a comment: uploaded by hjerpe
# via huggingface_hub, revision 5dd1bb4 verified)
"""Rollout utilities for GRPO training."""
from collections.abc import Sequence
import logging
from typing import Any
import uuid
try:
from sql_env.models import SQLAction, SQLObservation
except ImportError:
from models import SQLAction, SQLObservation
from sql_env.server.sql_environment import SQLEnvironment
try:
from sql_env.training.prompts import format_observation, get_system_prompt
except ImportError:
from training.prompts import format_observation, get_system_prompt
_ACTION_TYPES = ("DESCRIBE", "SAMPLE", "QUERY", "ANSWER")
_MAX_HISTORY_PAIRS = 3
_LOGGER = logging.getLogger(__name__)
def _parse_action_line(line: str) -> SQLAction | None:
    """Parse one line into a structured action.

    Parameters
    ----------
    line
        Candidate line that may contain a model action.

    Returns
    -------
    SQLAction | None
        Parsed action when the line matches supported action syntax
        (``ACTION arg`` or ``ACTION: arg``, case-insensitive), otherwise
        ``None``.
    """
    candidate = line.strip()
    if not candidate:
        return None
    candidate_upper = candidate.upper()
    # Only the first matching action keyword counts, mirroring a prefix scan.
    matched = next(
        (name for name in _ACTION_TYPES if candidate_upper.startswith(name)),
        None,
    )
    if matched is None:
        return None
    argument = candidate[len(matched):].lstrip()
    if argument.startswith(":"):
        argument = argument[1:].lstrip()
    # An action keyword without an argument is not a usable action.
    if not argument:
        return None
    return SQLAction(action_type=matched, argument=argument)
def parse_model_output(text: str | None) -> SQLAction:
    """Extract an ``SQLAction`` from free-form model output.

    The parser accepts both ``ACTION argument`` and ``ACTION: argument``
    formats (case-insensitive), scans multi-line output, and falls back to
    ``QUERY`` with raw text when parsing fails.

    Parameters
    ----------
    text
        Raw model output text.

    Returns
    -------
    SQLAction
        Parsed structured action, or a ``QUERY`` fallback action.
    """
    raw_text = str(text) if text is not None else ""
    # Try each individual line first, then the whole text as one candidate.
    for candidate in (*raw_text.splitlines(), raw_text):
        action = _parse_action_line(candidate)
        if action is not None:
            return action
    _LOGGER.warning("Unparseable model output; falling back to QUERY action")
    return SQLAction(action_type="QUERY", argument=raw_text)
def _build_environment(config: Any, tokenizer: Any) -> SQLEnvironment:
    """Construct a local SQL environment instance for training rollouts."""
    # Pull all constructor arguments from the training config in one place.
    env_kwargs = {
        "questions_path": config.questions_path,
        "db_dir": config.db_dir,
        "tokenizer": tokenizer,
        "step_budget": config.step_budget,
    }
    return SQLEnvironment(**env_kwargs)
def _trim_history(history_pairs: list[tuple[str, str]]) -> list[tuple[str, str]]:
"""Keep only the most recent observation/action pairs."""
if len(history_pairs) <= _MAX_HISTORY_PAIRS:
return history_pairs
return history_pairs[-_MAX_HISTORY_PAIRS:]
def _build_messages(
    question_text: str,
    observation: SQLObservation,
    history_pairs: list[tuple[str, str]],
) -> list[dict[str, str]]:
    """Build chat messages for one model generation step."""
    rendered_observation = format_observation(observation)
    # System prompt first, then the trimmed history as alternating
    # user/assistant turns, then the current question + observation.
    conversation: list[dict[str, str]] = [
        {"role": "system", "content": get_system_prompt()}
    ]
    for past_observation, past_action in _trim_history(history_pairs):
        conversation.extend(
            (
                {"role": "user", "content": past_observation},
                {"role": "assistant", "content": past_action},
            )
        )
    final_content = f"Training Question: {question_text}\n\n{rendered_observation}"
    conversation.append({"role": "user", "content": final_content})
    return conversation
def _extract_generated_text(generated: Any, tokenizer: Any) -> str:
"""Normalize model.generate output into plain text."""
if hasattr(generated, "tolist"):
generated = generated.tolist()
if isinstance(generated, str):
return generated.strip()
if isinstance(generated, Sequence) and generated:
first_item = generated[0]
if isinstance(first_item, str):
return first_item.strip()
if hasattr(tokenizer, "decode"):
return str(tokenizer.decode(first_item, skip_special_tokens=True)).strip()
if hasattr(tokenizer, "decode"):
try:
return str(tokenizer.decode(generated, skip_special_tokens=True)).strip()
except (TypeError, ValueError):
return str(generated).strip()
return str(generated).strip()
def _generate_action_text(
    messages: list[dict[str, str]], model: Any, tokenizer: Any, config: Any
) -> str:
    """Render chat messages and ask the model for the next action.

    Parameters
    ----------
    messages
        Chat messages in ``{"role": ..., "content": ...}`` form.
    model
        Object exposing ``generate`` (HF-style model, or any callable
        wrapper that accepts a prompt string in the fallback path).
    tokenizer
        Object exposing ``apply_chat_template`` and, ideally,
        ``__call__`` and ``decode``.
    config
        Object providing ``max_new_tokens``.

    Returns
    -------
    str
        Decoded continuation text — only the newly generated tokens when
        they can be isolated from the prompt, otherwise the full decoded
        output via ``_extract_generated_text``.
    """
    rendered_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Preferred path: tokenize the rendered prompt ourselves so the prompt
    # tokens can be sliced off the generated sequence afterwards.
    if callable(getattr(tokenizer, "__call__", None)):
        tokenized = tokenizer(rendered_prompt, return_tensors="pt")
        if isinstance(tokenized, dict) and "input_ids" in tokenized:
            try:
                # Move tensors onto the model's device; fall back to the raw
                # tensors when the model has no parameters (e.g. test doubles).
                model_device = next(model.parameters()).device
                prepared_inputs = {
                    key: value.to(model_device) if hasattr(value, "to") else value
                    for key, value in tokenized.items()
                }
            except (StopIteration, AttributeError, TypeError):
                prepared_inputs = tokenized
            generated = model.generate(
                **prepared_inputs,
                max_new_tokens=config.max_new_tokens,
            )
            input_ids = prepared_inputs.get("input_ids")
            # Normalize tensors to nested lists so slicing works uniformly.
            generated_values = (
                generated.tolist() if hasattr(generated, "tolist") else generated
            )
            input_values = (
                input_ids.tolist() if hasattr(input_ids, "tolist") else input_ids
            )
            if (
                isinstance(generated_values, Sequence)
                and generated_values
                and isinstance(input_values, Sequence)
                and input_values
                and hasattr(tokenizer, "decode")
            ):
                generated_first = generated_values[0]
                input_first = input_values[0]
                if isinstance(generated_first, Sequence) and isinstance(
                    input_first, Sequence
                ):
                    # HF-style generate output includes the prompt tokens;
                    # keep only the continuation before decoding.
                    new_tokens = generated_first[len(input_first) :]
                    return str(
                        tokenizer.decode(new_tokens, skip_special_tokens=True)
                    ).strip()
            # Could not isolate the new tokens; decode whatever came back.
            return _extract_generated_text(generated, tokenizer)
    # Fallback: hand the rendered prompt string to generate directly
    # (e.g. a tokenizer without __call__, or non-dict tokenizer output).
    generated = model.generate(rendered_prompt, max_new_tokens=config.max_new_tokens)
    return _extract_generated_text(generated, tokenizer)
def _reset_for_prompt(env: Any, question_text: str, seed: int | None) -> SQLObservation:
    """Reset environment while preferring the requested question when possible."""
    question_pool = getattr(env, "questions", None)
    # Without a mutable question list there is nothing to steer — plain reset.
    if not isinstance(question_pool, list):
        return env.reset(seed=seed)
    wanted = [
        question
        for question in question_pool
        if getattr(question, "question_text", None) == question_text
    ]
    if not wanted:
        return env.reset(seed=seed)
    saved_pool = list(question_pool)
    try:
        # Temporarily narrow the pool so reset must pick the target question.
        env.questions = wanted
        return env.reset(seed=seed)
    finally:
        # Always restore the full pool, even if reset raises.
        env.questions = saved_pool
def play_episode(
    question_text: str,
    model: Any,
    tokenizer: Any,
    config: Any,
    env: Any,
    episode_seed: int | None = None,
) -> dict[str, Any]:
    """Run one environment episode and collect completion + metadata.

    Parameters
    ----------
    question_text
        Question the episode should answer; used both in prompting and as
        the returned ``prompt`` field.
    model, tokenizer
        Generation backend forwarded to ``_generate_action_text``.
    config
        Object providing at least ``step_budget`` and ``max_new_tokens``.
    env
        Environment exposing ``reset``/``step`` (see ``SQLEnvironment``).
    episode_seed
        Optional seed forwarded to ``env.reset``.

    Returns
    -------
    dict[str, Any]
        Rollout record: completion text (both ``completion`` and
        ``content``), per-step metadata, and reward-related scalars
        (``correct``, ``progress``, ``operational``).
    """
    observation = _reset_for_prompt(env, question_text, seed=episode_seed)
    history_pairs: list[tuple[str, str]] = []  # (observation, action) per step
    action_lines: list[str] = []  # rendered actions, joined into completion
    seen_actions: set[str] = set()  # keys of actions already taken (repeat detection)
    operational_signals: list[dict[str, bool]] = []
    cumulative_progress = 0.0  # sum of non-negative QUERY-step rewards
    answer_correct = False
    for _ in range(config.step_budget):
        # Snapshot the observation BEFORE stepping so the history pairs each
        # action with the observation it was chosen from.
        formatted_observation = format_observation(observation)
        messages = _build_messages(
            question_text=question_text,
            observation=observation,
            history_pairs=history_pairs,
        )
        model_output = _generate_action_text(messages, model, tokenizer, config)
        action = parse_model_output(model_output)
        action_line = f"{action.action_type}: {action.argument}"
        action_key = f"{action.action_type}|{action.argument}"
        # A repeat is the exact same action (type + argument) seen earlier
        # in this episode.
        is_repeat = action_key in seen_actions
        seen_actions.add(action_key)
        observation = env.step(action)
        # Only QUERY steps contribute progress, and negative rewards are
        # clipped to zero before accumulating.
        if action.action_type == "QUERY" and observation.reward is not None:
            cumulative_progress += max(0.0, float(observation.reward))
        action_lines.append(action_line)
        history_pairs.append((formatted_observation, action_line))
        signal = {
            "exec_ok": not bool(observation.error),
            "new_info": action.action_type in {"DESCRIBE", "SAMPLE"}
            and not bool(observation.error),
            "repeat": is_repeat,
        }
        operational_signals.append(signal)
        if action.action_type == "ANSWER":
            # NOTE(review): assumes observation.result is a str for ANSWER
            # steps — confirm against SQLEnvironment.step.
            normalized_result = observation.result.strip().lower()
            answer_correct = normalized_result.startswith("answer submitted: correct")
        if observation.done:
            break
    # Net score: clean executions minus repeated actions.
    operational_score = float(
        sum(1.0 for signal in operational_signals if signal["exec_ok"])
        - sum(1.0 for signal in operational_signals if signal["repeat"])
    )
    metadata = {
        # Prefer the environment's own episode id; fall back to a fresh UUID.
        "episode_id": getattr(getattr(env, "state", None), "episode_id", None)
        or str(uuid.uuid4()),
        "step_count": len(action_lines),
        "done": bool(observation.done),
        "answer_correct": answer_correct,
        "cumulative_progress": cumulative_progress,
        "operational_signals": operational_signals,
    }
    completion_text = "\n".join(action_lines)
    return {
        "prompt": question_text,
        "completion": completion_text,
        "content": completion_text,
        "metadata": metadata,
        "correct": answer_correct,
        "progress": cumulative_progress,
        "operational": operational_score,
    }
def rollout_func(
    prompts: list[str],
    model: Any,
    tokenizer: Any,
    config: Any,
) -> list[dict[str, Any]]:
    """Play SQLEnv episodes for a batch of prompt strings."""
    # One shared environment is reused across all episodes in the batch.
    env = _build_environment(config, tokenizer)
    base_seed = getattr(config, "seed", None)
    results: list[dict[str, Any]] = []
    for offset, prompt_text in enumerate(prompts):
        # Derive a distinct, deterministic seed per episode when configured.
        seed = None if base_seed is None else int(base_seed) + offset
        episode = play_episode(
            question_text=prompt_text,
            model=model,
            tokenizer=tokenizer,
            config=config,
            env=env,
            episode_seed=seed,
        )
        results.append(episode)
    return results