| """Rollout utilities for GRPO training.""" |
|
|
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Any
|
|
| try: |
| from sql_env.models import SQLAction, SQLObservation |
| except ImportError: |
| from models import SQLAction, SQLObservation |
|
|
| from sql_env.server.sql_environment import SQLEnvironment |
|
|
| try: |
| from sql_env.training.prompts import format_observation, get_system_prompt |
| except ImportError: |
| from training.prompts import format_observation, get_system_prompt |
|
|
# Recognized action keywords, matched in this order against model output.
_ACTION_TYPES = ("DESCRIBE", "SAMPLE", "QUERY", "ANSWER")
# Number of trailing (observation, action) pairs kept in the chat history.
_MAX_HISTORY_PAIRS = 3
_LOGGER = logging.getLogger(__name__)
|
|
|
|
def _parse_action_line(line: str) -> SQLAction | None:
    """Parse one line into a structured action.

    Parameters
    ----------
    line
        Candidate line that may contain a model action.

    Returns
    -------
    SQLAction | None
        Parsed action when line matches supported action syntax,
        otherwise ``None``.
    """

    stripped = line.strip()
    if not stripped:
        return None

    upper = stripped.upper()
    for action_type in _ACTION_TYPES:
        if not upper.startswith(action_type):
            continue

        # Require a token boundary (end, whitespace, or ":") after the
        # keyword so that e.g. "SAMPLES foo" is not misread as a SAMPLE
        # action with argument "S foo".
        tail = stripped[len(action_type) :]
        if tail and not tail[0].isspace() and not tail.startswith(":"):
            continue

        remainder = tail.lstrip()
        if remainder.startswith(":"):
            remainder = remainder[1:].lstrip()

        # A bare keyword with no argument is not a usable action.
        if not remainder:
            return None

        return SQLAction(action_type=action_type, argument=remainder)

    return None
|
|
|
|
def parse_model_output(text: str | None) -> SQLAction:
    """Extract an ``SQLAction`` from free-form model output.

    Both ``ACTION argument`` and ``ACTION: argument`` forms are accepted
    (case-insensitive).  Each line is tried in order, then the whole text
    as a single candidate; when nothing parses, the raw text is wrapped
    in a ``QUERY`` fallback action.

    Parameters
    ----------
    text
        Raw model output text.

    Returns
    -------
    SQLAction
        Parsed structured action, or a ``QUERY`` fallback action.
    """

    raw_text = str(text) if text is not None else ""

    # Try individual lines first, then the full text (which catches the
    # "keyword on its own line, argument on the next" layout).
    candidates = [*raw_text.splitlines(), raw_text]
    for candidate in candidates:
        action = _parse_action_line(candidate)
        if action is not None:
            return action

    _LOGGER.warning("Unparseable model output; falling back to QUERY action")
    return SQLAction(action_type="QUERY", argument=raw_text)
|
|
|
|
def _build_environment(config: Any, tokenizer: Any) -> SQLEnvironment:
    """Construct a local SQL environment instance for training rollouts."""

    env_kwargs = {
        "questions_path": config.questions_path,
        "db_dir": config.db_dir,
        "tokenizer": tokenizer,
        "step_budget": config.step_budget,
    }
    return SQLEnvironment(**env_kwargs)
|
|
|
|
| def _trim_history(history_pairs: list[tuple[str, str]]) -> list[tuple[str, str]]: |
| """Keep only the most recent observation/action pairs.""" |
|
|
| if len(history_pairs) <= _MAX_HISTORY_PAIRS: |
| return history_pairs |
| return history_pairs[-_MAX_HISTORY_PAIRS:] |
|
|
|
|
def _build_messages(
    question_text: str,
    observation: SQLObservation,
    history_pairs: list[tuple[str, str]],
) -> list[dict[str, str]]:
    """Assemble the chat transcript for one model generation step.

    Layout: system prompt, then the trimmed observation/action history as
    alternating user/assistant turns, then a final user turn carrying the
    question plus the freshly formatted observation.
    """

    rendered_observation = format_observation(observation)
    messages: list[dict[str, str]] = [
        {"role": "system", "content": get_system_prompt()}
    ]

    for past_observation, past_action in _trim_history(history_pairs):
        messages.extend(
            (
                {"role": "user", "content": past_observation},
                {"role": "assistant", "content": past_action},
            )
        )

    final_turn = f"Training Question: {question_text}\n\n{rendered_observation}"
    messages.append({"role": "user", "content": final_turn})
    return messages
|
|
|
|
| def _extract_generated_text(generated: Any, tokenizer: Any) -> str: |
| """Normalize model.generate output into plain text.""" |
|
|
| if hasattr(generated, "tolist"): |
| generated = generated.tolist() |
|
|
| if isinstance(generated, str): |
| return generated.strip() |
|
|
| if isinstance(generated, Sequence) and generated: |
| first_item = generated[0] |
| if isinstance(first_item, str): |
| return first_item.strip() |
| if hasattr(tokenizer, "decode"): |
| return str(tokenizer.decode(first_item, skip_special_tokens=True)).strip() |
|
|
| if hasattr(tokenizer, "decode"): |
| try: |
| return str(tokenizer.decode(generated, skip_special_tokens=True)).strip() |
| except (TypeError, ValueError): |
| return str(generated).strip() |
|
|
| return str(generated).strip() |
|
|
|
|
| def _generate_action_text( |
| messages: list[dict[str, str]], model: Any, tokenizer: Any, config: Any |
| ) -> str: |
| """Render chat messages and ask the model for the next action.""" |
|
|
| rendered_prompt = tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True, |
| ) |
| if callable(getattr(tokenizer, "__call__", None)): |
| tokenized = tokenizer(rendered_prompt, return_tensors="pt") |
| if isinstance(tokenized, dict) and "input_ids" in tokenized: |
| try: |
| model_device = next(model.parameters()).device |
| prepared_inputs = { |
| key: value.to(model_device) if hasattr(value, "to") else value |
| for key, value in tokenized.items() |
| } |
| except (StopIteration, AttributeError, TypeError): |
| prepared_inputs = tokenized |
|
|
| generated = model.generate( |
| **prepared_inputs, |
| max_new_tokens=config.max_new_tokens, |
| ) |
|
|
| input_ids = prepared_inputs.get("input_ids") |
| generated_values = ( |
| generated.tolist() if hasattr(generated, "tolist") else generated |
| ) |
| input_values = ( |
| input_ids.tolist() if hasattr(input_ids, "tolist") else input_ids |
| ) |
| if ( |
| isinstance(generated_values, Sequence) |
| and generated_values |
| and isinstance(input_values, Sequence) |
| and input_values |
| and hasattr(tokenizer, "decode") |
| ): |
| generated_first = generated_values[0] |
| input_first = input_values[0] |
| if isinstance(generated_first, Sequence) and isinstance( |
| input_first, Sequence |
| ): |
| new_tokens = generated_first[len(input_first) :] |
| return str( |
| tokenizer.decode(new_tokens, skip_special_tokens=True) |
| ).strip() |
|
|
| return _extract_generated_text(generated, tokenizer) |
|
|
| generated = model.generate(rendered_prompt, max_new_tokens=config.max_new_tokens) |
| return _extract_generated_text(generated, tokenizer) |
|
|
|
|
| def _reset_for_prompt(env: Any, question_text: str, seed: int | None) -> SQLObservation: |
| """Reset environment while preferring the requested question when possible.""" |
|
|
| questions = getattr(env, "questions", None) |
| if not isinstance(questions, list): |
| return env.reset(seed=seed) |
|
|
| matching_questions = [ |
| question |
| for question in questions |
| if getattr(question, "question_text", None) == question_text |
| ] |
| if not matching_questions: |
| return env.reset(seed=seed) |
|
|
| original_questions = list(questions) |
| try: |
| env.questions = matching_questions |
| return env.reset(seed=seed) |
| finally: |
| env.questions = original_questions |
|
|
|
|
def play_episode(
    question_text: str,
    model: Any,
    tokenizer: Any,
    config: Any,
    env: Any,
    episode_seed: int | None = None,
) -> dict[str, Any]:
    """Run one environment episode and collect completion + metadata.

    Parameters
    ----------
    question_text
        Natural-language question the model must answer this episode.
    model
        Generation backend passed through to ``_generate_action_text``.
    tokenizer
        Tokenizer passed through to ``_generate_action_text``.
    config
        Settings object; ``step_budget`` caps the number of steps taken.
    env
        Environment exposing ``reset``/``step``; reused across episodes.
    episode_seed
        Optional seed forwarded to the environment reset.

    Returns
    -------
    dict[str, Any]
        Rollout record: prompt, completion text (one action per line),
        per-episode metadata, and scalar reward components
        (``correct``, ``progress``, ``operational``).
    """

    observation = _reset_for_prompt(env, question_text, seed=episode_seed)
    history_pairs: list[tuple[str, str]] = []
    action_lines: list[str] = []
    seen_actions: set[str] = set()
    operational_signals: list[dict[str, bool]] = []
    cumulative_progress = 0.0
    answer_correct = False

    for _ in range(config.step_budget):
        # Snapshot the observation text before stepping so the history pair
        # records exactly what the model saw when choosing this action.
        formatted_observation = format_observation(observation)
        messages = _build_messages(
            question_text=question_text,
            observation=observation,
            history_pairs=history_pairs,
        )
        model_output = _generate_action_text(messages, model, tokenizer, config)
        action = parse_model_output(model_output)
        action_line = f"{action.action_type}: {action.argument}"

        # Track exact (type, argument) repeats; they lower the operational
        # score below.
        action_key = f"{action.action_type}|{action.argument}"
        is_repeat = action_key in seen_actions
        seen_actions.add(action_key)

        observation = env.step(action)
        # Only non-negative QUERY rewards accumulate as "progress".
        if action.action_type == "QUERY" and observation.reward is not None:
            cumulative_progress += max(0.0, float(observation.reward))

        action_lines.append(action_line)
        history_pairs.append((formatted_observation, action_line))

        # Per-step signals consumed by the operational score computation.
        signal = {
            "exec_ok": not bool(observation.error),
            "new_info": action.action_type in {"DESCRIBE", "SAMPLE"}
            and not bool(observation.error),
            "repeat": is_repeat,
        }
        operational_signals.append(signal)

        if action.action_type == "ANSWER":
            # NOTE(review): assumes observation.result is a string after an
            # ANSWER step — confirm against SQLObservation's contract.
            normalized_result = observation.result.strip().lower()
            answer_correct = normalized_result.startswith("answer submitted: correct")

        if observation.done:
            break

    # Net operational score: successful executions minus repeated actions.
    operational_score = float(
        sum(1.0 for signal in operational_signals if signal["exec_ok"])
        - sum(1.0 for signal in operational_signals if signal["repeat"])
    )

    metadata = {
        # Prefer the environment's own episode id; fall back to a new UUID.
        "episode_id": getattr(getattr(env, "state", None), "episode_id", None)
        or str(uuid.uuid4()),
        "step_count": len(action_lines),
        "done": bool(observation.done),
        "answer_correct": answer_correct,
        "cumulative_progress": cumulative_progress,
        "operational_signals": operational_signals,
    }

    completion_text = "\n".join(action_lines)
    return {
        "prompt": question_text,
        # "completion" and "content" carry the same text so trainers that
        # expect either field name both work.
        "completion": completion_text,
        "content": completion_text,
        "metadata": metadata,
        "correct": answer_correct,
        "progress": cumulative_progress,
        "operational": operational_score,
    }
|
|
|
|
def rollout_func(
    prompts: list[str],
    model: Any,
    tokenizer: Any,
    config: Any,
) -> list[dict[str, Any]]:
    """Play SQLEnv episodes for a batch of prompt strings.

    One environment instance is built and reused for the whole batch;
    when ``config.seed`` is set, each episode gets a distinct derived
    seed (``seed + index``) so rollouts stay reproducible.
    """

    env = _build_environment(config, tokenizer)
    base_seed = getattr(config, "seed", None)

    results: list[dict[str, Any]] = []
    for offset, prompt_text in enumerate(prompts):
        seed_for_episode = None if base_seed is None else int(base_seed) + offset
        episode = play_episode(
            question_text=prompt_text,
            model=model,
            tokenizer=tokenizer,
            config=config,
            env=env,
            episode_seed=seed_for_episode,
        )
        results.append(episode)
    return results
|
|