| | """ |
| | Example demonstrating special token usage in chain-of-thought reasoning. |
| | |
| | This script shows how to: |
| | 1. Generate reasoning chains with special tokens |
| | 2. Parse structured reasoning from generated text |
| | 3. Use special tokens for PRM evaluation |
| | """ |
| |
|
import zlib
from typing import List, Dict, Tuple
from typing import Optional

import torch

from src.reasoning.step_data import ReasoningStep, ReasoningChain, StepType, SPECIAL_TOKENS
from src.reasoning.prm import ProcessRewardModel
| |
|
| |
|
class SpecialTokenParser:
    """Parse generated text with special tokens into a structured ReasoningChain."""

    def __init__(self, tokenizer):
        """
        Args:
            tokenizer: Tokenizer whose vocabulary includes every name in
                SPECIAL_TOKENS (ids resolved via convert_tokens_to_ids).
        """
        self.tokenizer = tokenizer
        # Pre-resolve each special-token string to its vocabulary id.
        self.token_ids = {
            name: tokenizer.convert_tokens_to_ids(name)
            for name in SPECIAL_TOKENS
        }

    def parse_reasoning_chain(self, generated_text: str, image_path: str, prompt: str) -> ReasoningChain:
        """
        Parse generated text with special tokens into a ReasoningChain.

        Args:
            generated_text: Generated text containing special tokens
            image_path: Path to the input image
            prompt: Original prompt

        Returns:
            ReasoningChain with parsed steps
        """
        steps = []

        # Without a complete <|reasoning_start|>...<|reasoning_end|> span there
        # is no structure to parse: fall back to treating the raw output as the
        # final answer.
        if "<|reasoning_start|>" not in generated_text or "<|reasoning_end|>" not in generated_text:
            return ReasoningChain(
                chain_id="parsed_0",
                image_path=image_path,
                prompt=prompt,
                steps=[],
                final_answer=generated_text,
                is_correct=False
            )

        # Slice out the reasoning span (markers excluded).
        reasoning_start = generated_text.find("<|reasoning_start|>") + len("<|reasoning_start|>")
        reasoning_end = generated_text.find("<|reasoning_end|>")
        reasoning_text = generated_text[reasoning_start:reasoning_end]

        # Extract the final answer.
        # BUG FIX: the old check (`answer_start > 0`) could never detect a
        # missing <|answer_start|> — find() returns -1, and -1 + len(token)
        # is still positive — so a garbage slice was silently produced.
        # Test the raw find() results against -1 instead.
        answer_idx = generated_text.find("<|answer_start|>")
        answer_end = generated_text.find("<|answer_end|>")
        if answer_idx != -1 and answer_end != -1:
            final_answer = generated_text[answer_idx + len("<|answer_start|>"):answer_end]
        else:
            final_answer = ""

        # Each step is introduced by <|step_start|>; the leading split element
        # is whatever preceded the first marker and is skipped when blank.
        step_texts = reasoning_text.split("<|step_start|>")

        for i, step_text in enumerate(step_texts):
            if not step_text.strip():
                continue

            step = self._parse_step(step_text, i)
            if step:
                steps.append(step)

        # BUG FIX: built-in hash() is salted per process (PYTHONHASHSEED), so
        # chain ids were not reproducible across runs; crc32 is deterministic.
        return ReasoningChain(
            chain_id=f"parsed_{zlib.crc32(generated_text.encode('utf-8')) % 10000}",
            image_path=image_path,
            prompt=prompt,
            steps=steps,
            final_answer=final_answer.strip(),
            is_correct=False
        )

    def _parse_step(self, step_text: str, step_id: int) -> Optional[ReasoningStep]:
        """Parse a single step from text; returns None when parsing fails."""
        try:
            # Step type defaults to INFERENCE when absent or unrecognized.
            step_type = StepType.INFERENCE
            if "<|step_type|>" in step_text:
                type_start = step_text.find("<|step_type|>") + len("<|step_type|>")
                type_end = step_text.find("<|", type_start)
                step_type_str = step_text[type_start:type_end].strip()
                try:
                    step_type = StepType(step_type_str)
                except ValueError:
                    pass  # unknown type string: keep the default

            # Dependencies are a comma-separated list of step ids.
            dependencies = []
            if "<|depends_on|>" in step_text:
                dep_start = step_text.find("<|depends_on|>") + len("<|depends_on|>")
                dep_end = step_text.find("<|", dep_start)
                deps_str = step_text[dep_start:dep_end].strip()
                dependencies = [int(d) for d in deps_str.split(",") if d.strip().isdigit()]

            # Human-readable description between explicit delimiters.
            description = ""
            if "<|description_start|>" in step_text:
                desc_start = step_text.find("<|description_start|>") + len("<|description_start|>")
                desc_end = step_text.find("<|description_end|>")
                description = step_text[desc_start:desc_end].strip()

            # Model-reported confidence; defaults to 0.5 when absent/invalid.
            confidence = 0.5
            if "<|confidence_start|>" in step_text:
                conf_start = step_text.find("<|confidence_start|>") + len("<|confidence_start|>")
                conf_end = step_text.find("<|confidence_end|>")
                try:
                    confidence = float(step_text[conf_start:conf_end].strip())
                except ValueError:
                    pass  # malformed number: keep the default

            return ReasoningStep(
                step_id=step_id,
                step_type=step_type,
                description=description,
                confidence=confidence,
                dependencies=dependencies
            )

        except Exception as e:
            # Best-effort parsing: a malformed step is dropped, not fatal.
            print(f"Error parsing step: {e}")
            return None
| |
|
| |
|
class SpecialTokenGenerator:
    """Generate text with special tokens for structured reasoning."""

    def __init__(self, model, tokenizer, prm: ProcessRewardModel = None):
        """
        Args:
            model: Language model exposing a HuggingFace-style ``generate``.
            tokenizer: Tokenizer whose vocabulary contains the structural tokens.
            prm: Optional process reward model for scoring steps.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.prm = prm

        # Cache the ids of the structural tokens inserted during generation.
        self.reasoning_start_id = tokenizer.convert_tokens_to_ids("<|reasoning_start|>")
        self.reasoning_end_id = tokenizer.convert_tokens_to_ids("<|reasoning_end|>")
        self.step_start_id = tokenizer.convert_tokens_to_ids("<|step_start|>")
        self.step_end_id = tokenizer.convert_tokens_to_ids("<|step_end|>")
        # "ки" is appended after each step as a PRM scoring marker
        # (see _generate_step).
        self.ki_id = tokenizer.convert_tokens_to_ids("ки")

    def generate_with_structure(
        self,
        prompt: str,
        image_features: torch.Tensor,
        max_steps: int = 5,
        temperature: float = 0.7,
    ) -> Tuple[str, List[float]]:
        """
        Generate reasoning chain with enforced structure and PRM evaluation.

        Args:
            prompt: Input prompt
            image_features: Visual features. NOTE(review): these are only
                forwarded to _generate_step, which never passes them to the
                model — confirm the intended visual-conditioning wiring.
            max_steps: Maximum reasoning steps
            temperature: Sampling temperature

        Returns:
            (generated_text, step_rewards)
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')

        # BUG FIX: build marker tensors with the same dtype/device as the
        # prompt ids so torch.cat works when the model runs on an accelerator
        # (the old bare torch.tensor([[id]]) was always CPU/default dtype).
        reasoning_start = torch.tensor(
            [[self.reasoning_start_id]], dtype=input_ids.dtype, device=input_ids.device
        )
        current_ids = torch.cat([input_ids, reasoning_start], dim=1)

        step_rewards = []
        generated_text_parts = [prompt, " <|reasoning_start|>"]

        # Generate up to max_steps structured steps, stopping early when the
        # step reward drops below threshold.
        for step_num in range(max_steps):
            step_text, step_reward = self._generate_step(
                current_ids,
                image_features,
                step_num,
                temperature
            )

            generated_text_parts.append(step_text)
            step_rewards.append(step_reward)

            # Feed the newly generated step back as context for the next one.
            step_ids = self.tokenizer.encode(step_text, add_special_tokens=False, return_tensors='pt')
            step_ids = step_ids.to(current_ids.device)
            current_ids = torch.cat([current_ids, step_ids], dim=1)

            # Early exit on a low-quality step.
            if step_reward < 0.3:
                break

        generated_text_parts.append(" <|reasoning_end|>")

        reasoning_end = torch.tensor(
            [[self.reasoning_end_id]], dtype=current_ids.dtype, device=current_ids.device
        )
        current_ids = torch.cat([current_ids, reasoning_end], dim=1)

        answer_text = self._generate_answer(current_ids, temperature)
        generated_text_parts.append(answer_text)

        return "".join(generated_text_parts), step_rewards

    def _generate_step(
        self,
        current_ids: torch.Tensor,
        image_features: torch.Tensor,
        step_num: int,
        temperature: float,
    ) -> Tuple[str, float]:
        """Generate a single reasoning step with enforced token structure.

        Returns:
            (step_text, step_reward) where step_text carries the
            <|step_start|>...<|step_end|> wrapper and the " ки" PRM marker.
        """
        # NOTE(review): image_features is accepted but never forwarded to the
        # model here — confirm whether visual conditioning happens upstream.
        step_start = torch.tensor(
            [[self.step_start_id]], dtype=current_ids.dtype, device=current_ids.device
        )
        step_ids = torch.cat([current_ids, step_start], dim=1)

        parts = [" <|step_start|>"]

        # Step type: generate until the next special-token opener.
        parts.append(" <|step_type|>")
        type_text = self._generate_until_token(step_ids, "<|", max_new_tokens=10, temperature=temperature)
        parts.append(type_text)

        # Free-text description of the step.
        parts.append(" <|description_start|>")
        desc_text = self._generate_until_token(step_ids, "<|description_end|>", max_new_tokens=100, temperature=temperature)
        parts.append(desc_text)
        parts.append("<|description_end|>")

        # Self-reported confidence value.
        parts.append(" <|confidence_start|>")
        conf_text = self._generate_until_token(step_ids, "<|confidence_end|>", max_new_tokens=5, temperature=temperature)
        parts.append(conf_text)
        parts.append("<|confidence_end|>")

        # PRM step marker.
        parts.append(" ки")

        # TODO(review): PRM scoring is not implemented. The original code
        # assembled the step text when self.prm was set but never called the
        # reward model, so the neutral default reward was always returned;
        # the dead assignment has been removed.
        step_reward = 0.5

        parts.append(" <|step_end|>")

        return "".join(parts), step_reward

    def _generate_until_token(
        self,
        input_ids: torch.Tensor,
        stop_token: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Sample up to max_new_tokens and truncate at the first stop_token."""
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
        )

        # Decode only the continuation, keeping special tokens so the stop
        # marker is visible in the text.
        generated_text = self.tokenizer.decode(outputs[0][input_ids.size(1):], skip_special_tokens=False)

        if stop_token in generated_text:
            generated_text = generated_text[:generated_text.find(stop_token)]

        return generated_text

    def _generate_answer(self, current_ids: torch.Tensor, temperature: float) -> str:
        """Generate the final answer wrapped in answer markers."""
        parts = [" <|answer_start|>"]

        # BUG FIX: pass do_sample=True so the temperature argument actually
        # takes effect (HF generate ignores temperature under greedy
        # decoding), matching the setup used in _generate_until_token.
        outputs = self.model.generate(
            current_ids,
            max_new_tokens=100,
            temperature=temperature,
            do_sample=True,
        )

        answer_text = self.tokenizer.decode(outputs[0][current_ids.size(1):], skip_special_tokens=False)

        # Truncate at the closing marker if the model emitted one.
        if "<|answer_end|>" in answer_text:
            answer_text = answer_text[:answer_text.find("<|answer_end|>")]

        parts.append(answer_text)
        parts.append("<|answer_end|>")

        return "".join(parts)
| |
|
| |
|
def example_usage():
    """Example demonstrating special token usage."""
    from transformers import AutoTokenizer, AutoModel

    # Load a checkpoint whose tokenizer already knows the special tokens.
    model_path = "path/to/model_with_special_tokens"
    tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    lm = AutoModel.from_pretrained(model_path, trust_remote_code=True)

    gen = SpecialTokenGenerator(lm, tok)

    # Run one structured generation pass on a dummy visual embedding.
    question = "How many red objects are in the image?"
    features = torch.randn(1, 768)

    text, rewards = gen.generate_with_structure(
        prompt=question,
        image_features=features,
        max_steps=5,
        temperature=0.7,
    )

    print("Generated reasoning:")
    print(text)
    print(f"\nStep rewards: {rewards}")

    # Round-trip: recover the structured chain from the raw generated text.
    chain = SpecialTokenParser(tok).parse_reasoning_chain(
        generated_text=text,
        image_path="example.jpg",
        prompt=question,
    )

    print(f"\nParsed {len(chain.steps)} steps:")
    for st in chain.steps:
        print(f"  Step {st.step_id} ({st.step_type.value}): {st.description}")
    print(f"Final answer: {chain.final_answer}")
|
| |
|
# Entry point: run the end-to-end demo when executed as a script.
if __name__ == "__main__":
    example_usage()
| |
|