"""
Step-Level Chain of Thought Trainer for Vision-Language Models.
Integrates PRM, RL training, and inference-time scaling for fine-grained reasoning.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Optional, Any
import torch
import yaml
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import QwenImagePipeline
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from .step_data import StepDataset, ReasoningChain, SPECIAL_TOKENS
from .prm import ProcessRewardModel, PRMTrainer
from .rl_trainer import RLReasoningTrainer, RLConfig
from .inference_scaling import InferenceTimeScaling, InferenceConfig
logger = get_logger(__name__, log_level="INFO")
class StepLevelCoTTrainer:
"""
Complete training pipeline for step-level chain-of-thought reasoning.
Training phases:
1. Supervised Fine-tuning (SFT) on reasoning chains
2. PRM training on step-level rewards
3. RL fine-tuning with PRM rewards
4. Inference-time scaling evaluation
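
    Illustrative usage (a sketch; assumes a YAML config with the fields
    read in ``__init__``):

        trainer = StepLevelCoTTrainer("configs/step_cot.yaml")
        trainer.train_sft(max_steps=10_000)
        trainer.train_prm(max_steps=5_000)
        trainer.train_rl(max_steps=8_000)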
"""
def __init__(
self,
config_path: str,
device: str = "cuda",
):
"""
Initialize trainer.
Args:
config_path: Path to training configuration YAML
device: Device for training
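
        Illustrative config sketch (field names are the ones this module
        reads; values are placeholders, not tuned recommendations)::

            pretrained_model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
            output_dir: ./outputs
            mixed_precision: bf16
            report_to: tensorboard
            gradient_accumulation_steps: 4
            train_batch_size: 1
            gradient_checkpointing: true
            max_reasoning_steps: 8
            vision_dim: 1024
            text_dim: 1024
            prm_hidden_dim: 512
            prm_num_heads: 8
            prm_dropout: 0.1
            ppo_clip_epsilon: 0.2
            value_loss_coef: 0.5
            entropy_coef: 0.01
            gamma: 0.99
            gae_lambda: 0.95
            rl_learning_rate: 5.0e-6
            num_inference_samples: 8
            inference_temperature: 0.8
            aggregation_method: best_of_n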
"""
self.device = device
# Load configuration
with open(config_path, 'r') as f:
config_dict = yaml.safe_load(f)
self.config = OmegaConf.create(config_dict)
# Initialize accelerator
self.accelerator = Accelerator(
gradient_accumulation_steps=self.config.gradient_accumulation_steps,
mixed_precision=self.config.mixed_precision,
log_with=self.config.report_to,
project_dir=self.config.output_dir,
)
# Load model and tokenizer
self._load_model()
# Initialize PRM
self.prm = ProcessRewardModel(
vision_dim=self.config.vision_dim,
text_dim=self.config.text_dim,
hidden_dim=self.config.prm_hidden_dim,
num_heads=self.config.prm_num_heads,
dropout=self.config.prm_dropout,
max_steps=self.config.max_reasoning_steps,
)
# Initialize RL trainer with explicit type conversion
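        # (yaml.safe_load can parse values like "5e-6" as strings under
        # YAML 1.1 rules, so numeric fields are coerced explicitly here)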
rl_config = RLConfig(
clip_epsilon=float(self.config.ppo_clip_epsilon),
value_loss_coef=float(self.config.value_loss_coef),
entropy_coef=float(self.config.entropy_coef),
gamma=float(self.config.gamma),
gae_lambda=float(self.config.gae_lambda),
learning_rate=float(self.config.rl_learning_rate),
batch_size=int(self.config.train_batch_size),
use_prm_rewards=True,
normalize_rewards=True,
)
self.rl_trainer = RLReasoningTrainer(
policy_model=self.model,
prm_model=self.prm,
config=rl_config,
device=device,
)
# Initialize inference scaler
# Ensure numeric types are properly converted
try:
num_samples = int(self.config.num_inference_samples)
temperature = float(self.config.inference_temperature)
aggregation = str(self.config.aggregation_method)
inference_config = InferenceConfig(
num_samples=num_samples,
temperature=temperature,
aggregation=aggregation,
use_prm_scores=True,
)
self.inference_scaler = InferenceTimeScaling(
model=self.model,
prm=self.prm,
config=inference_config,
device=device,
)
except Exception as e:
logger.error(f"Failed to initialize inference scaler: {e}")
logger.error(f"Config values - num_samples: {self.config.num_inference_samples} (type: {type(self.config.num_inference_samples)})")
logger.error(f"Config values - temperature: {self.config.inference_temperature} (type: {type(self.config.inference_temperature)})")
logger.error(f"Config values - aggregation: {self.config.aggregation_method} (type: {type(self.config.aggregation_method)})")
raise
logger.info("StepLevelCoTTrainer initialized")
def _load_model(self) -> None:
"""Load vision-language model (supports both VLM and diffusion models)."""
logger.info(f"Loading model: {self.config.pretrained_model_name_or_path}")
model_name = self.config.pretrained_model_name_or_path.lower()
# Detect model type
is_vlm = any(x in model_name for x in ["qwen2", "qwen-vl", "internvl", "llava", "cogvlm"])
is_diffusion = any(x in model_name for x in ["flux", "qwen-image", "stable-diffusion", "sd"])
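        # Note: these substring checks are heuristic (e.g. "sd" matches any
        # name containing those letters), and VLM detection takes priority.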
if is_vlm:
# Load as Vision-Language Model
from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
logger.info("Loading as Vision-Language Model...")
# Prepare quantization config if enabled
quantization_config = None
if self.config.get("quantize", False):
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)
# Load processor and model
self.tokenizer = AutoProcessor.from_pretrained(
self.config.pretrained_model_name_or_path,
trust_remote_code=True
)
            if self.config.mixed_precision == "fp16":
                torch_dtype = torch.float16
            elif self.config.mixed_precision == "bf16":
                torch_dtype = torch.bfloat16
            else:
                torch_dtype = torch.float32
self.model = AutoModelForVision2Seq.from_pretrained(
self.config.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
quantization_config=quantization_config,
trust_remote_code=True,
device_map="auto"
)
# For VLMs, model itself is the main component
self.pipeline = None
self.vae = None
self.text_encoder = None
elif is_diffusion:
# Load as Diffusion Model
logger.info("Loading as Diffusion Model...")
self.pipeline = QwenImagePipeline.from_pretrained(
self.config.pretrained_model_name_or_path,
torch_dtype=torch.float16 if self.config.mixed_precision == "fp16" else torch.float32,
)
self.model = self.pipeline.transformer
self.vae = self.pipeline.vae
self.text_encoder = self.pipeline.text_encoder
self.tokenizer = self.pipeline.tokenizer
else:
# Try loading as VLM by default
            logger.warning("Unknown model type, attempting to load as VLM...")
from transformers import AutoProcessor, AutoModelForVision2Seq
self.tokenizer = AutoProcessor.from_pretrained(
self.config.pretrained_model_name_or_path,
trust_remote_code=True
)
self.model = AutoModelForVision2Seq.from_pretrained(
self.config.pretrained_model_name_or_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto"
)
self.pipeline = None
self.vae = None
self.text_encoder = None
# Add special tokens for chain-of-thought reasoning
self._add_special_tokens()
# Enable gradient checkpointing if available
if self.config.gradient_checkpointing:
if hasattr(self.model, 'enable_gradient_checkpointing'):
self.model.enable_gradient_checkpointing()
elif hasattr(self.model, 'gradient_checkpointing_enable'):
self.model.gradient_checkpointing_enable()
else:
logger.warning("Gradient checkpointing requested but not available for this model")
logger.info("Model loaded successfully")
def _add_special_tokens(self) -> None:
"""Add special tokens for chain-of-thought reasoning."""
from .step_data import SPECIAL_TOKENS, SPECIAL_TOKEN_DESCRIPTIONS
logger.info(f"Adding {len(SPECIAL_TOKENS)} special tokens for chain-of-thought reasoning...")
# Get actual tokenizer (handle both Tokenizer and Processor objects)
if hasattr(self.tokenizer, 'tokenizer'):
# This is a processor, extract the tokenizer
actual_tokenizer = self.tokenizer.tokenizer
else:
# This is already a tokenizer
actual_tokenizer = self.tokenizer
# Add special tokens to tokenizer
num_added_toks = actual_tokenizer.add_special_tokens(
{"additional_special_tokens": SPECIAL_TOKENS}
)
if num_added_toks > 0:
logger.info(f"Added {num_added_toks} new tokens. Resizing model embeddings...")
# Get current embedding size
old_size = self.model.get_input_embeddings().weight.size(0)
# Resize token embeddings (use actual tokenizer's vocab size)
new_size = len(actual_tokenizer)
self.model.resize_token_embeddings(new_size)
# Initialize new token embeddings with semantic initialization
self._initialize_new_token_embeddings(num_added_toks, SPECIAL_TOKEN_DESCRIPTIONS)
logger.info(f"Resized embeddings from {old_size} to {new_size}")
else:
logger.info("Special tokens already present in tokenizer.")
def _initialize_new_token_embeddings(self, num_added_toks: int, descriptions: list) -> None:
"""Initialize newly added token embeddings with semantic information."""
# Get actual tokenizer (handle both Tokenizer and Processor objects)
if hasattr(self.tokenizer, 'tokenizer'):
actual_tokenizer = self.tokenizer.tokenizer
else:
actual_tokenizer = self.tokenizer
with torch.no_grad():
input_embeddings = self.model.get_input_embeddings()
# Check if model has language model head
has_lm_head = hasattr(self.model, 'lm_head')
if has_lm_head:
lm_head = self.model.lm_head
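            # Note: enumerate(reversed(descriptions), start=1) pairs the last
            # description with embedding row -1 (the most recently added
            # token), the second-to-last with row -2, and so on, so each new
            # row receives the mean embedding of its own description's tokens.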
for i, description in enumerate(reversed(descriptions), start=1):
# Tokenize the description using actual tokenizer
tokenized = actual_tokenizer.tokenize(description)
if not tokenized:
continue
tokenized_ids = actual_tokenizer.convert_tokens_to_ids(tokenized)
# Average embeddings of description tokens
new_embedding = input_embeddings.weight[tokenized_ids].mean(dim=0)
input_embeddings.weight[-i, :] = new_embedding.clone().detach()
# Initialize lm_head if it exists
if has_lm_head:
new_lm_head = lm_head.weight[tokenized_ids].mean(dim=0)
lm_head.weight[-i, :] = new_lm_head.clone().detach()
logger.info(f"Initialized {num_added_toks} new token embeddings with semantic information.")
def load_step_dataset(
self,
data_dir: str,
split: str = "train",
) -> StepDataset:
"""
Load step-level reasoning dataset.
Args:
data_dir: Directory containing reasoning chain JSON files
split: Dataset split (train/val/test)
Returns:
StepDataset instance
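
        Expected layout (inferred from the path handling below):
        <data_dir>/train/, <data_dir>/val/, <data_dir>/test/, each
        holding reasoning-chain JSON files.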
"""
data_path = Path(data_dir) / split
logger.info(f"Loading dataset from {data_path}")
dataset = StepDataset.from_json_dir(
str(data_path),
tokenizer=self.tokenizer,
max_steps=self.config.max_reasoning_steps,
include_visual_features=True,
filter_incorrect=False, # Include all chains for learning
)
stats = dataset.get_statistics()
logger.info(f"Dataset loaded: {stats}")
return dataset
def train_sft(
self,
max_steps: int = 10000,
learning_rate: float = 2e-5,
save_steps: int = 1000,
eval_steps: int = 500,
) -> None:
"""
Supervised Fine-Tuning phase for the base model.
Args:
max_steps: Maximum training steps
learning_rate: Learning rate for SFT
save_steps: Save checkpoint every N steps
eval_steps: Evaluate every N steps
"""
logger.info("Starting Supervised Fine-Tuning (SFT)")
logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
# Import dataset loaders
from .flux_reason_dataset import FLUXReasonDataset
from .finevision_dataset import FineVisionDataset
from torch.utils.data import DataLoader
# Get batch size with default
batch_size = getattr(self.config, 'train_batch_size', 1)
# Load FLUX-Reason-6M dataset (70% weight)
logger.info("Loading FLUX-Reason-6M dataset...")
flux_dataset = FLUXReasonDataset(
dataset_name="LucasFang/FLUX-Reason-6M",
split="train",
streaming=True,
max_samples=int(max_steps * batch_size * 0.7),
)
# Load FineVision dataset (30% weight)
logger.info("Loading FineVision dataset...")
finevision_dataset = FineVisionDataset(
dataset_name="HuggingFaceM4/FineVision",
subset="llavar_gpt4_20k", # Use LLaVA subset as default
split="train",
streaming=True,
max_samples=int(max_steps * batch_size * 0.3),
generate_synthetic_reasoning=True,
num_synthetic_steps=3,
)
# Collate function to convert ReasoningChain to model inputs
def collate_fn(batch):
# Filter out None samples
batch = [b for b in batch if b is not None]
            if not batch:
                # Return an empty batch structure instead of None so the
                # Accelerate-wrapped DataLoader does not crash; the training
                # loop below skips batches with no tokens.
                empty = torch.zeros((0, 0), dtype=torch.long)
                return {
                    'input_ids': empty,
                    'attention_mask': empty.clone(),
                    'labels': empty.clone(),
                }
# Extract text and images
texts = []
images = []
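            # Each chain is rendered into one training string of the form:
            #   Question: <prompt>
            #   <|reasoning_start|>
            #   <|step_start|><step 1 description><|step_end|>
            #   ...
            #   <|reasoning_end|>
            #   Answer: <final answer>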
for chain in batch:
# Format: question + reasoning steps + answer
text = f"Question: {chain.prompt}\n"
text += "<|reasoning_start|>\n"
for step in chain.steps:
text += f"<|step_start|>{step.description}<|step_end|>\n"
text += "<|reasoning_end|>\n"
text += f"Answer: {chain.final_answer}"
texts.append(text)
# Handle missing images - append None which the processor should handle
if hasattr(chain, 'image'):
images.append(chain.image if chain.image is not None else None)
else:
images.append(None)
            # Tokenize: use the processor (with images) for vision-language
            # models, otherwise fall back to the plain tokenizer. A wrapped
            # `.tokenizer` attribute marks a processor; both objects are
            # callable, so callability alone cannot distinguish them.
            if hasattr(self.tokenizer, 'tokenizer'):
                # Filter out None images - if all are None, pass None instead of a list
                valid_images = [img for img in images if img is not None]
                images_param = valid_images if valid_images else None
                inputs = self.tokenizer(
                    text=texts,
                    images=images_param,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=2048,
                )
            else:
                inputs = self.tokenizer(
                    texts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=2048,
                )
# Set labels (same as input_ids for causal LM)
inputs['labels'] = inputs['input_ids'].clone()
return inputs
# Interleave datasets with weighted sampling
def interleave_datasets(flux_ds, fv_ds, flux_weight=0.7):
flux_iter = iter(flux_ds)
fv_iter = iter(fv_ds)
while True:
# Sample from FLUX 70% of the time
if torch.rand(1).item() < flux_weight:
try:
yield next(flux_iter)
except StopIteration:
flux_iter = iter(flux_ds)
yield next(flux_iter)
else:
try:
yield next(fv_iter)
except StopIteration:
fv_iter = iter(fv_ds)
yield next(fv_iter)
# Create interleaved dataset
from torch.utils.data import IterableDataset
class InterleavedDataset(IterableDataset):
def __init__(self, ds1, ds2, weight1=0.7):
self.ds1 = ds1
self.ds2 = ds2
self.weight1 = weight1
def __iter__(self):
return interleave_datasets(self.ds1, self.ds2, self.weight1)
mixed_dataset = InterleavedDataset(flux_dataset, finevision_dataset, 0.7)
# Create dataloader
train_dataloader = DataLoader(
mixed_dataset,
batch_size=batch_size,
collate_fn=collate_fn,
num_workers=0,
)
# Setup optimizer
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=float(learning_rate),
weight_decay=getattr(self.config, 'weight_decay', 0.01),
)
# Check for existing checkpoint to resume from
resume_step = 0
checkpoint_path = self._find_latest_checkpoint("sft")
if checkpoint_path:
try:
resume_step = self._load_checkpoint(checkpoint_path, "sft")
logger.info(f"Resuming SFT training from step {resume_step}")
except Exception as e:
logger.warning(f"Failed to load checkpoint: {e}. Starting from scratch.")
resume_step = 0
# Prepare with accelerator
self.model, optimizer, train_dataloader = self.accelerator.prepare(
self.model, optimizer, train_dataloader
)
# Training loop
self.model.train()
global_step = resume_step
total_loss = 0
progress_bar = tqdm(
total=max_steps,
initial=resume_step,
desc="SFT Training",
disable=not self.accelerator.is_local_main_process,
)
        for batch in train_dataloader:
            # Skip empty batches produced by the collate function
            if batch is None or batch['input_ids'].numel() == 0:
                continue
if global_step >= max_steps:
break
            # Forward and backward pass, honoring the Accelerator's
            # gradient accumulation configuration
            with self.accelerator.accumulate(self.model):
                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'],
                )
                loss = outputs.loss
                self.accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            total_loss += loss.item()
            global_step += 1
            progress_bar.update(1)
            # Average over steps taken in this run (total_loss resets on resume)
            steps_this_run = max(global_step - resume_step, 1)
            progress_bar.set_postfix({'loss': loss.item(), 'avg_loss': total_loss / steps_this_run})
# Save checkpoint
if global_step % save_steps == 0:
logger.info(f"Saving checkpoint at step {global_step}")
self._save_checkpoint(global_step, "sft")
# Logging
logging_steps = getattr(self.config, 'logging_steps', 100)
if global_step % logging_steps == 0:
logger.info(f"Step {global_step}/{max_steps} - Loss: {loss.item():.4f}")
progress_bar.close()
        logger.info(f"SFT training completed - Avg loss: {total_loss / max(global_step - resume_step, 1):.4f}")
def train_prm(
self,
max_steps: int = 5000,
learning_rate: float = 1e-5,
save_steps: int = 500,
eval_steps: int = 500,
**kwargs, # Accept additional config like hidden_dim, num_layers, dropout, reward_scale
) -> None:
"""
Train Process Reward Model.
Args:
max_steps: Maximum training steps
learning_rate: Learning rate for PRM training
save_steps: Save checkpoint every N steps
eval_steps: Evaluate every N steps
**kwargs: Additional PRM configuration (hidden_dim, num_layers, dropout, reward_scale)
"""
logger.info("Starting PRM training")
logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
if kwargs:
logger.info(f"PRM config: {kwargs}")
# Import dataset loaders
from .flux_reason_dataset import FLUXReasonDataset
from torch.utils.data import DataLoader
# Load FLUX-Reason-6M for PRM training (has reasoning chains)
logger.info("Loading FLUX-Reason-6M for PRM training...")
batch_size = getattr(self.config, 'train_batch_size', 1)
train_dataset = FLUXReasonDataset(
dataset_name="LucasFang/FLUX-Reason-6M",
split="train",
streaming=True,
max_samples=max_steps * batch_size,
)
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
)
# Setup PRM trainer
from .prm import PRMTrainer
prm_trainer = PRMTrainer(
model=self.prm,
learning_rate=float(learning_rate),
weight_decay=getattr(self.config, 'weight_decay', 0.01),
warmup_steps=getattr(self.config, 'warmup_steps', 500),
)
# Prepare with accelerator
prm_trainer.model, prm_trainer.optimizer, train_dataloader = self.accelerator.prepare(
prm_trainer.model, prm_trainer.optimizer, train_dataloader
)
# Training loop
prm_trainer.model.train()
global_step = 0
total_loss = 0
progress_bar = tqdm(
total=max_steps,
desc="PRM Training",
disable=not self.accelerator.is_local_main_process,
)
for batch in train_dataloader:
if global_step >= max_steps:
break
# Extract features for PRM
# PRM scores individual reasoning steps
try:
# Get vision features and step embeddings from batch
vision_features = batch.get('image_features', batch.get('pixel_values'))
step_texts = batch.get('reasoning_steps', [])
                # Placeholder: uniform positive reward for every step; a real
                # signal would come from step-level correctness labels
                target_rewards = torch.ones(len(step_texts), dtype=torch.float32)
# Train step
metrics = prm_trainer.train_step(
vision_features,
step_texts,
target_rewards,
)
total_loss += metrics['loss']
global_step += 1
progress_bar.update(1)
progress_bar.set_postfix({'loss': metrics['loss'], 'avg_loss': total_loss / global_step})
# Save checkpoint
if global_step % save_steps == 0:
logger.info(f"Saving PRM checkpoint at step {global_step}")
self._save_checkpoint(global_step, "prm")
except Exception as e:
logger.warning(f"PRM training step failed: {e}, skipping batch")
continue
progress_bar.close()
logger.info(f"PRM training completed - Avg loss: {total_loss / max(global_step, 1):.4f}")
def train_rl(
self,
max_steps: int = 8000,
learning_rate: float = 5e-6,
save_steps: int = 500,
eval_steps: int = 500,
**kwargs, # Accept additional config like gamma, lam, cliprange, etc.
) -> None:
"""
Train policy with reinforcement learning.
Args:
max_steps: Maximum training steps
learning_rate: Learning rate for RL training
save_steps: Save checkpoint every N steps
eval_steps: Evaluate every N steps
**kwargs: Additional PPO configuration (gamma, lam, cliprange, vf_coef, ent_coef, etc.)
"""
logger.info("Starting RL training")
logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
if kwargs:
logger.info(f"PPO config: {kwargs}")
# Import dataset loaders
from .flux_reason_dataset import FLUXReasonDataset
from torch.utils.data import DataLoader
# Get batch size with default
batch_size = getattr(self.config, 'train_batch_size', 1)
# Load dataset for RL training
logger.info("Loading FLUX-Reason-6M for RL training...")
train_dataset = FLUXReasonDataset(
dataset_name="LucasFang/FLUX-Reason-6M",
split="train",
streaming=True,
max_samples=max_steps * batch_size,
)
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
)
# Training loop with PPO
global_step = 0
total_reward = 0
progress_bar = tqdm(
total=max_steps,
desc="RL Training",
disable=not self.accelerator.is_local_main_process,
)
for batch in train_dataloader:
if global_step >= max_steps:
break
try:
# Generate reasoning chains with current policy
# Score with PRM
# Update policy with PPO
# For now, use the RL trainer
self.rl_trainer.policy.train()
# Simple training step (actual PPO would be more complex)
vision_features = batch.get('image_features', batch.get('pixel_values'))
# Generate and score chains
# This is simplified - full PPO would involve:
# 1. Generate multiple chains
# 2. Score with PRM
# 3. Compute advantages
# 4. Update policy
# Placeholder reward (would come from PRM)
reward = 1.0
total_reward += reward
global_step += 1
progress_bar.update(1)
progress_bar.set_postfix({
'reward': reward,
'avg_reward': total_reward / global_step
})
# Save checkpoint
if global_step % save_steps == 0:
logger.info(f"Saving RL checkpoint at step {global_step}")
self._save_checkpoint(global_step, "rl")
except Exception as e:
logger.warning(f"RL training step failed: {e}, skipping batch")
continue
progress_bar.close()
logger.info(f"RL training completed - Avg reward: {total_reward / max(global_step, 1):.4f}")
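
    # Hedged sketch: the comments in train_rl outline a full PPO update
    # (generate chains, score with the PRM, compute advantages, update the
    # policy). The helper below illustrates only the core clipped objective;
    # it is not wired into the loop above, and names/shapes are illustrative.
    @staticmethod
    def _ppo_clipped_loss_sketch(
        old_log_probs: torch.Tensor,
        new_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        clip_epsilon: float = 0.2,
    ) -> torch.Tensor:
        """Clipped PPO surrogate loss; all tensors assumed shape (batch,)."""
        # Probability ratio between the new and old policies
        ratio = (new_log_probs - old_log_probs).exp()
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
        # PPO maximizes the surrogate, so the loss is its negation
        return -torch.min(unclipped, clipped).mean()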
def evaluate_inference_scaling(
self,
test_dataset: StepDataset,
num_samples: int = 10,
) -> Dict[str, Any]:
"""
Evaluate inference-time scaling performance.
Args:
test_dataset: Test dataset
num_samples: Number of test samples
Returns:
Evaluation metrics
"""
logger.info("Evaluating inference-time scaling")
results = []
for idx in range(min(num_samples, len(test_dataset))):
sample = test_dataset[idx]
# Run inference with scaling
best_chain, all_chains = self.inference_scaler(
image_features=sample['visual_features'][0],
prompt_ids=sample['prompt_ids'],
prompt_mask=sample['prompt_mask'],
image_path=sample['image_path'],
prompt_text="", # Would decode from IDs
)
results.append({
'chain_id': sample['chain_id'],
'best_reward': best_chain.total_reward,
'num_steps': len(best_chain.steps),
'avg_reward_all': sum(c.total_reward for c in all_chains) / len(all_chains),
'std_reward_all': torch.tensor([c.total_reward for c in all_chains]).std().item(),
})
# Aggregate metrics
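        # improvement_over_mean is an averaged z-score: how many standard
        # deviations the selected best chain's reward sits above the mean
        # reward of the sampled chains, per test sample.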
metrics = {
'avg_best_reward': sum(r['best_reward'] for r in results) / len(results),
'avg_num_steps': sum(r['num_steps'] for r in results) / len(results),
'improvement_over_mean': sum(
(r['best_reward'] - r['avg_reward_all']) / (r['std_reward_all'] + 1e-8)
for r in results
) / len(results),
}
logger.info(f"Inference scaling metrics: {metrics}")
return metrics
def _evaluate_prm(
self,
prm_trainer: PRMTrainer,
val_dataset: StepDataset,
) -> Dict[str, float]:
"""Evaluate PRM on validation set."""
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
batch_size=self.config.train_batch_size,
shuffle=False,
)
total_metrics = {'mse': 0, 'mae': 0, 'correlation': 0}
for batch in val_dataloader:
vision_features = batch['visual_features']
step_embeddings = batch['step_descriptions']
target_rewards = batch['step_rewards']
metrics = prm_trainer.evaluate(
vision_features,
step_embeddings,
target_rewards,
)
for key in total_metrics:
total_metrics[key] += metrics[key]
# Average
for key in total_metrics:
total_metrics[key] /= len(val_dataloader)
return total_metrics
def _save_checkpoint(self, step: int, phase: str) -> None:
"""Save checkpoint for a specific training phase and upload to Hub."""
save_path = Path(self.config.output_dir) / f"checkpoint-{phase}-{step}"
save_path.mkdir(parents=True, exist_ok=True)
if phase == "sft":
# Save the main model
self.accelerator.unwrap_model(self.model).save_pretrained(save_path)
if hasattr(self, 'tokenizer'):
self.tokenizer.save_pretrained(save_path)
# Upload to HuggingFace Hub for persistence
try:
from huggingface_hub import HfApi
api = HfApi()
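                # Note: uploading requires write access to the target repo
                # (e.g. a token from `huggingface-cli login` or HF_TOKEN).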
repo_id = getattr(self.config, 'hub_model_id', 'Mulebot/qwen-cot-checkpoints')
logger.info(f"Uploading checkpoint to {repo_id}...")
api.upload_folder(
folder_path=str(save_path),
repo_id=repo_id,
path_in_repo=f"checkpoint-{phase}-{step}",
repo_type="model",
commit_message=f"Save {phase} checkpoint at step {step}",
)
logger.info(f"Checkpoint uploaded to hub: {repo_id}/checkpoint-{phase}-{step}")
except Exception as e:
logger.warning(f"Failed to upload checkpoint to hub: {e}")
elif phase == "prm":
# Save PRM model
torch.save(
self.accelerator.unwrap_model(self.prm).state_dict(),
save_path / "prm_model.pt",
)
elif phase == "rl":
# Save RL policy
torch.save(
self.accelerator.unwrap_model(self.rl_trainer.policy).state_dict(),
save_path / "rl_policy.pt",
)
logger.info(f"Checkpoint saved to {save_path}")
def _find_latest_checkpoint(self, phase: str) -> Optional[Path]:
"""Find the latest checkpoint for a specific phase from Hub or local."""
# First, try to find checkpoints on HuggingFace Hub
try:
from huggingface_hub import HfApi, hf_hub_download
api = HfApi()
repo_id = getattr(self.config, 'hub_model_id', 'Mulebot/qwen-cot-checkpoints')
logger.info(f"Checking for checkpoints on {repo_id}...")
# List all files in the repo
try:
files = api.list_repo_files(repo_id, repo_type="model")
# Find checkpoint directories
checkpoint_dirs = set()
                for file in files:
                    parts = file.split("/")
                    # Only count top-level directories named checkpoint-<phase>-<step>
                    if len(parts) > 1 and parts[0].startswith(f"checkpoint-{phase}-"):
                        checkpoint_dirs.add(parts[0])
if checkpoint_dirs:
# Extract step numbers
checkpoint_steps = []
for ckpt_dir in checkpoint_dirs:
try:
step = int(ckpt_dir.split("-")[-1])
checkpoint_steps.append((step, ckpt_dir))
except ValueError:
continue
if checkpoint_steps:
# Get the latest checkpoint
latest_step, latest_dir = max(checkpoint_steps, key=lambda x: x[0])
logger.info(f"Found checkpoint on hub at step {latest_step}: {latest_dir}")
# Download the checkpoint
local_dir = Path(self.config.output_dir) / latest_dir
local_dir.mkdir(parents=True, exist_ok=True)
                            logger.info("Downloading checkpoint from hub...")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id=repo_id,
repo_type="model",
allow_patterns=f"{latest_dir}/*",
local_dir=self.config.output_dir,
local_dir_use_symlinks=False,
)
logger.info(f"Checkpoint downloaded to {local_dir}")
return local_dir
except Exception as e:
logger.info(f"No checkpoints found on hub or error accessing: {e}")
except ImportError:
logger.warning("huggingface_hub not available, skipping hub checkpoint check")
# Fall back to local checkpoints
checkpoint_dir = Path(self.config.output_dir)
if not checkpoint_dir.exists():
return None
# Find all checkpoints for this phase
checkpoints = list(checkpoint_dir.glob(f"checkpoint-{phase}-*"))
if not checkpoints:
return None
# Extract step numbers and find the latest
checkpoint_steps = []
for ckpt in checkpoints:
try:
step = int(ckpt.name.split("-")[-1])
checkpoint_steps.append((step, ckpt))
except ValueError:
continue
if not checkpoint_steps:
return None
# Return checkpoint with highest step number
latest_step, latest_path = max(checkpoint_steps, key=lambda x: x[0])
logger.info(f"Found local checkpoint at step {latest_step}: {latest_path}")
return latest_path
def _load_checkpoint(self, checkpoint_path: Path, phase: str) -> int:
"""Load checkpoint and return the step number."""
if phase == "sft":
# Load model from checkpoint
from transformers import AutoModelForVision2Seq
logger.info(f"Loading SFT checkpoint from {checkpoint_path}")
# Load the checkpoint model
checkpoint_model = AutoModelForVision2Seq.from_pretrained(
checkpoint_path,
torch_dtype=torch.bfloat16,
device_map=self.accelerator.device,
)
# Copy weights to current model
self.model.load_state_dict(checkpoint_model.state_dict())
del checkpoint_model
elif phase == "prm":
logger.info(f"Loading PRM checkpoint from {checkpoint_path}")
self.prm.load_state_dict(
torch.load(checkpoint_path / "prm_model.pt", map_location=self.accelerator.device)
)
elif phase == "rl":
logger.info(f"Loading RL checkpoint from {checkpoint_path}")
self.rl_trainer.policy.load_state_dict(
torch.load(checkpoint_path / "rl_policy.pt", map_location=self.accelerator.device)
)
# Extract step number from checkpoint path
step = int(checkpoint_path.name.split("-")[-1])
logger.info(f"Resumed from checkpoint at step {step}")
return step
def _save_prm(self, epoch: int) -> None:
"""Save PRM checkpoint."""
save_path = Path(self.config.output_dir) / "prm" / f"epoch_{epoch}"
save_path.mkdir(parents=True, exist_ok=True)
torch.save(
self.prm.state_dict(),
save_path / "prm_model.pt",
)
logger.info(f"PRM saved to {save_path}")
def save_checkpoint(self, step: int) -> None:
"""Save complete training checkpoint."""
save_path = Path(self.config.output_dir) / f"checkpoint-{step}"
save_path.mkdir(parents=True, exist_ok=True)
# Save policy
self.accelerator.save_state(str(save_path / "policy"))
# Save PRM
torch.save(self.prm.state_dict(), save_path / "prm.pt")
# Save RL trainer state
self.rl_trainer.save_checkpoint(str(save_path / "rl_trainer.pt"))
logger.info(f"Checkpoint saved to {save_path}")
def main():
"""Main training entry point."""
parser = argparse.ArgumentParser(description="Train step-level CoT reasoning for vision models")
parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--phase", type=str, choices=["sft", "prm", "rl", "inference", "all"], default="all")
parser.add_argument("--data_dir", type=str, required=True, help="Path to reasoning chain data")
args = parser.parse_args()
# Initialize trainer
trainer = StepLevelCoTTrainer(
config_path=args.config,
device="cuda" if torch.cuda.is_available() else "cpu",
)
    # Load the step-level test split for evaluation (the training phases
    # stream their own datasets internally, so they take no dataset args)
    test_dataset = trainer.load_step_dataset(args.data_dir, split="test")
    # Run training phases
    if args.phase in ["sft", "all"]:
        trainer.train_sft()
    if args.phase in ["prm", "all"]:
        trainer.train_prm()
    if args.phase in ["rl", "all"]:
        trainer.train_rl()
    if args.phase in ["inference", "all"]:
        trainer.evaluate_inference_scaling(test_dataset)
if __name__ == "__main__":
main()
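
# Example invocation (illustrative paths):
#   accelerate launch step_level_cot.py \
#       --config configs/step_cot.yaml --data_dir data/reasoning_chains --phase all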