"""Rollout utilities for GRPO training."""

from collections.abc import Mapping, Sequence
import logging
from typing import Any
import uuid

try:
    from sql_env.models import SQLAction, SQLObservation
except ImportError:
    from models import SQLAction, SQLObservation

from sql_env.server.sql_environment import SQLEnvironment

try:
    from sql_env.training.prompts import format_observation, get_system_prompt
except ImportError:
    from training.prompts import format_observation, get_system_prompt

# Action keywords recognised at the start of a model output line.
_ACTION_TYPES = ("DESCRIBE", "SAMPLE", "QUERY", "ANSWER")
# Number of prior (observation, action) pairs replayed in each prompt.
_MAX_HISTORY_PAIRS = 3

_LOGGER = logging.getLogger(__name__)


def _parse_action_line(line: str) -> SQLAction | None:
    """Parse one line into a structured action.

    Parameters
    ----------
    line
        Candidate line that may contain a model action.

    Returns
    -------
    SQLAction | None
        Parsed action when line matches supported action syntax,
        otherwise ``None``. A line that holds only the keyword (with or
        without a trailing colon) has no argument and yields ``None``.
    """
    stripped = line.strip()
    if not stripped:
        return None

    upper = stripped.upper()
    for action_type in _ACTION_TYPES:
        if not upper.startswith(action_type):
            continue

        tail = stripped[len(action_type) :]
        # The keyword must stand on its own: prose such as "Answering ..."
        # or "QUERYING ..." merely starts with a keyword and must not be
        # treated as an action (it would otherwise submit a bogus ANSWER).
        if tail and (tail[0].isalpha() or tail[0] == "_"):
            continue

        remainder = tail.lstrip()
        if remainder.startswith(":"):
            remainder = remainder[1:].lstrip()

        if not remainder:
            return None

        return SQLAction(action_type=action_type, argument=remainder)

    return None


def parse_model_output(text: str | None) -> SQLAction:
    """Extract an ``SQLAction`` from free-form model output.

    The parser accepts both ``ACTION argument`` and ``ACTION: argument``
    formats (case-insensitive), scans multi-line output, and falls back to
    ``QUERY`` with raw text when parsing fails.

    Parameters
    ----------
    text
        Raw model output text. ``None`` is treated as an empty string.

    Returns
    -------
    SQLAction
        Parsed structured action, or a ``QUERY`` fallback action.
    """
    raw_text = "" if text is None else str(text)

    # First line that parses wins.
    for line in raw_text.splitlines():
        parsed = _parse_action_line(line)
        if parsed is not None:
            return parsed

    # Retry on the whole text so a keyword on one line followed by its
    # argument on the next line (e.g. "QUERY\nSELECT 1") is still accepted.
    parsed = _parse_action_line(raw_text)
    if parsed is not None:
        return parsed

    _LOGGER.warning("Unparseable model output; falling back to QUERY action")
    return SQLAction(action_type="QUERY", argument=raw_text)


def _build_environment(config: Any, tokenizer: Any) -> SQLEnvironment:
    """Construct a local SQL environment instance for training rollouts.

    Parameters
    ----------
    config
        Object providing ``questions_path``, ``db_dir`` and ``step_budget``.
    tokenizer
        Tokenizer forwarded unchanged to ``SQLEnvironment``.
    """
    return SQLEnvironment(
        questions_path=config.questions_path,
        db_dir=config.db_dir,
        tokenizer=tokenizer,
        step_budget=config.step_budget,
    )


def _trim_history(history_pairs: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Keep only the most recent observation/action pairs.

    Returns the input list itself when it already fits within
    ``_MAX_HISTORY_PAIRS``; otherwise a new list with the trailing pairs.
    """
    if len(history_pairs) <= _MAX_HISTORY_PAIRS:
        return history_pairs
    return history_pairs[-_MAX_HISTORY_PAIRS:]


def _build_messages(
    question_text: str,
    observation: SQLObservation,
    history_pairs: list[tuple[str, str]],
) -> list[dict[str, str]]:
    """Build chat messages for one model generation step.

    The result is: the system prompt, then the trimmed history replayed as
    alternating user (observation) / assistant (action) turns, then a final
    user turn holding the question and the current formatted observation.
    """
    current_observation = format_observation(observation)
    messages: list[dict[str, str]] = [
        {"role": "system", "content": get_system_prompt()}
    ]

    for prior_observation, prior_action in _trim_history(history_pairs):
        messages.append({"role": "user", "content": prior_observation})
        messages.append({"role": "assistant", "content": prior_action})

    messages.append(
        {
            "role": "user",
            "content": f"Training Question: {question_text}\n\n{current_observation}",
        }
    )
    return messages


def _extract_generated_text(generated: Any, tokenizer: Any) -> str:
    """Normalize model.generate output into plain text.

    Parameters
    ----------
    generated
        Output of ``model.generate``: a string, a sequence of strings, a
        sequence of token ids, a batch (sequence of sequences) of token ids,
        or any object exposing ``tolist()`` that yields one of those.
    tokenizer
        Object whose ``decode`` method (when present) turns token ids into
        text.

    Returns
    -------
    str
        Stripped text for the first (or only) generated sample.
    """
    if hasattr(generated, "tolist"):
        generated = generated.tolist()

    if isinstance(generated, str):
        return generated.strip()

    if isinstance(generated, Sequence) and generated:
        first_item = generated[0]
        if isinstance(first_item, str):
            return first_item.strip()
        if hasattr(tokenizer, "decode"):
            # Batched output -> decode the first row. Flat output (a single
            # row of token ids) -> decode the whole row rather than only
            # its first token.
            token_ids = first_item if isinstance(first_item, Sequence) else generated
            return str(tokenizer.decode(token_ids, skip_special_tokens=True)).strip()

    if hasattr(tokenizer, "decode"):
        try:
            return str(tokenizer.decode(generated, skip_special_tokens=True)).strip()
        except (TypeError, ValueError):
            return str(generated).strip()

    return str(generated).strip()


def _generate_action_text(
    messages: list[dict[str, str]], model: Any, tokenizer: Any, config: Any
) -> str:
    """Render chat messages and ask the model for the next action.

    Parameters
    ----------
    messages
        Chat messages as produced by ``_build_messages``.
    model
        Object exposing ``generate`` (and optionally ``parameters``).
    tokenizer
        Object exposing ``apply_chat_template`` and, optionally, ``__call__``
        and ``decode``.
    config
        Object providing ``max_new_tokens``.

    Returns
    -------
    str
        Stripped text of the newly generated continuation.
    """
    rendered_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    if callable(getattr(tokenizer, "__call__", None)):
        tokenized = tokenizer(rendered_prompt, return_tensors="pt")
        # Accept any mapping, not just ``dict``: tokenizer outputs such as
        # Hugging Face's ``BatchEncoding`` are mappings but not ``dict``
        # subclasses and would otherwise fall through to the raw-string path.
        if isinstance(tokenized, Mapping) and "input_ids" in tokenized:
            try:
                model_device = next(model.parameters()).device
                prepared_inputs = {
                    key: value.to(model_device) if hasattr(value, "to") else value
                    for key, value in tokenized.items()
                }
            except (StopIteration, AttributeError, TypeError):
                # Model without parameters/device info: use inputs as-is.
                prepared_inputs = tokenized

            generated = model.generate(
                **prepared_inputs,
                max_new_tokens=config.max_new_tokens,
            )

            input_ids = prepared_inputs.get("input_ids")
            generated_values = (
                generated.tolist() if hasattr(generated, "tolist") else generated
            )
            input_values = (
                input_ids.tolist() if hasattr(input_ids, "tolist") else input_ids
            )

            if (
                isinstance(generated_values, Sequence)
                and generated_values
                and isinstance(input_values, Sequence)
                and input_values
                and hasattr(tokenizer, "decode")
            ):
                generated_first = generated_values[0]
                input_first = input_values[0]
                if isinstance(generated_first, Sequence) and isinstance(
                    input_first, Sequence
                ):
                    # Drop the echoed prompt tokens; keep only the new ones.
                    new_tokens = generated_first[len(input_first) :]
                    return str(
                        tokenizer.decode(new_tokens, skip_special_tokens=True)
                    ).strip()

            return _extract_generated_text(generated, tokenizer)

    # Fallback for tokenizers that are not callable or do not return a
    # mapping with "input_ids": hand the rendered prompt to the model.
    generated = model.generate(rendered_prompt, max_new_tokens=config.max_new_tokens)
    return _extract_generated_text(generated, tokenizer)


def _reset_for_prompt(env: Any, question_text: str, seed: int | None) -> SQLObservation:
    """Reset environment while preferring the requested question when possible.

    When ``env.questions`` is a list containing entries whose
    ``question_text`` equals the prompt, the list is temporarily narrowed to
    those entries for the duration of ``env.reset`` and restored afterwards
    (also on error). Otherwise a plain ``env.reset(seed=seed)`` is performed.
    """
    questions = getattr(env, "questions", None)
    if not isinstance(questions, list):
        return env.reset(seed=seed)

    matching_questions = [
        question
        for question in questions
        if getattr(question, "question_text", None) == question_text
    ]
    if not matching_questions:
        return env.reset(seed=seed)

    original_questions = list(questions)
    try:
        env.questions = matching_questions
        return env.reset(seed=seed)
    finally:
        env.questions = original_questions


def play_episode(
    question_text: str,
    model: Any,
    tokenizer: Any,
    config: Any,
    env: Any,
    episode_seed: int | None = None,
) -> dict[str, Any]:
    """Run one environment episode and collect completion + metadata.

    Parameters
    ----------
    question_text
        Prompt/question the episode should be played for.
    model, tokenizer
        Forwarded to ``_generate_action_text``.
    config
        Object providing ``step_budget`` and ``max_new_tokens``.
    env
        Environment exposing ``reset(seed=...)`` and ``step(action)``.
    episode_seed
        Seed passed to the environment reset.

    Returns
    -------
    dict[str, Any]
        Keys: ``prompt``, ``completion`` and ``content`` (newline-joined
        action lines), ``metadata``, ``correct``, ``progress`` and
        ``operational`` (successful steps minus repeated actions).
    """
    observation = _reset_for_prompt(env, question_text, seed=episode_seed)
    history_pairs: list[tuple[str, str]] = []
    action_lines: list[str] = []
    seen_actions: set[str] = set()
    operational_signals: list[dict[str, bool]] = []
    cumulative_progress = 0.0
    answer_correct = False

    for _ in range(config.step_budget):
        formatted_observation = format_observation(observation)
        messages = _build_messages(
            question_text=question_text,
            observation=observation,
            history_pairs=history_pairs,
        )
        model_output = _generate_action_text(messages, model, tokenizer, config)
        action = parse_model_output(model_output)

        action_line = f"{action.action_type}: {action.argument}"
        action_key = f"{action.action_type}|{action.argument}"
        is_repeat = action_key in seen_actions
        seen_actions.add(action_key)

        observation = env.step(action)

        # Only non-negative QUERY rewards count towards progress.
        if action.action_type == "QUERY" and observation.reward is not None:
            cumulative_progress += max(0.0, float(observation.reward))

        action_lines.append(action_line)
        history_pairs.append((formatted_observation, action_line))

        signal = {
            "exec_ok": not bool(observation.error),
            "new_info": action.action_type in {"DESCRIBE", "SAMPLE"}
            and not bool(observation.error),
            "repeat": is_repeat,
        }
        operational_signals.append(signal)

        if action.action_type == "ANSWER":
            # Tolerate a missing result instead of raising AttributeError.
            normalized_result = str(observation.result or "").strip().lower()
            answer_correct = normalized_result.startswith("answer submitted: correct")

        if observation.done:
            break

    operational_score = float(
        sum(1.0 for signal in operational_signals if signal["exec_ok"])
        - sum(1.0 for signal in operational_signals if signal["repeat"])
    )

    metadata = {
        "episode_id": getattr(getattr(env, "state", None), "episode_id", None)
        or str(uuid.uuid4()),
        "step_count": len(action_lines),
        "done": bool(observation.done),
        "answer_correct": answer_correct,
        "cumulative_progress": cumulative_progress,
        "operational_signals": operational_signals,
    }

    completion_text = "\n".join(action_lines)
    return {
        "prompt": question_text,
        "completion": completion_text,
        "content": completion_text,
        "metadata": metadata,
        "correct": answer_correct,
        "progress": cumulative_progress,
        "operational": operational_score,
    }


def rollout_func(
    prompts: list[str],
    model: Any,
    tokenizer: Any,
    config: Any,
) -> list[dict[str, Any]]:
    """Play SQLEnv episodes for a batch of prompt strings.

    A single environment is built and reused for all prompts. When
    ``config.seed`` is set, episode ``idx`` is seeded with
    ``int(config.seed) + idx``; otherwise episodes are unseeded.

    Returns
    -------
    list[dict[str, Any]]
        One ``play_episode`` result per prompt, in input order.
    """
    env = _build_environment(config, tokenizer)
    rollouts: list[dict[str, Any]] = []

    for idx, prompt in enumerate(prompts):
        episode_seed = (
            None if getattr(config, "seed", None) is None else int(config.seed) + idx
        )
        rollouts.append(
            play_episode(
                question_text=prompt,
                model=model,
                tokenizer=tokenizer,
                config=config,
                env=env,
                episode_seed=episode_seed,
            )
        )

    return rollouts