renpas22 committed on
Commit · 917e40e · 1 Parent(s): f8fc68a
Add **kwargs to train_prm and train_rl to accept config parameters
src/reasoning/step_level_cot.py CHANGED
@@ -353,6 +353,7 @@ class StepLevelCoTTrainer:
         learning_rate: float = 1e-5,
         save_steps: int = 500,
         eval_steps: int = 500,
+        **kwargs,  # Accept additional config like hidden_dim, num_layers, dropout, reward_scale
     ) -> None:
         """
         Train Process Reward Model.
@@ -362,9 +363,12 @@ class StepLevelCoTTrainer:
             learning_rate: Learning rate for PRM training
             save_steps: Save checkpoint every N steps
             eval_steps: Evaluate every N steps
+            **kwargs: Additional PRM configuration (hidden_dim, num_layers, dropout, reward_scale)
         """
         logger.info("Starting PRM training")
         logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
+        if kwargs:
+            logger.info(f"PRM config: {kwargs}")

         # Load datasets using the trainer's dataset loading methods
         train_dataset = self.load_step_dataset(split='train')
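With this hunk, PRM architecture settings ride through the method's existing entry point instead of forcing a new named parameter per option. A minimal caller sketch follows, assuming the module path from the diff header and a default-constructed trainer; the constructor arguments and the max_steps parameter sit outside the shown hunks:

from src.reasoning.step_level_cot import StepLevelCoTTrainer

trainer = StepLevelCoTTrainer()  # construction details are not shown in this commit

# Extra keyword arguments now land in **kwargs and are logged as "PRM config: {...}".
trainer.train_prm(
    max_steps=1000,      # referenced by the logging call; assumed to be a parameter above the hunk
    learning_rate=1e-5,
    save_steps=500,
    eval_steps=500,
    hidden_dim=1024,     # assumed PRM head width
    num_layers=2,        # assumed PRM head depth
    dropout=0.1,
    reward_scale=1.0,
)

One trade-off worth noting: because **kwargs accepts any key, a mistyped option no longer raises a TypeError, so the new logger.info line is what keeps those silently accepted settings visible.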
@@ -449,6 +453,7 @@ class StepLevelCoTTrainer:
         learning_rate: float = 5e-6,
         save_steps: int = 500,
         eval_steps: int = 500,
+        **kwargs,  # Accept additional config like gamma, lam, cliprange, etc.
     ) -> None:
         """
         Train policy with reinforcement learning.
@@ -458,9 +463,12 @@ class StepLevelCoTTrainer:
             learning_rate: Learning rate for RL training
             save_steps: Save checkpoint every N steps
             eval_steps: Evaluate every N steps
+            **kwargs: Additional PPO configuration (gamma, lam, cliprange, vf_coef, ent_coef, etc.)
         """
         logger.info("Starting RL training")
         logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
+        if kwargs:
+            logger.info(f"PPO config: {kwargs}")

         # Load dataset using the trainer's dataset loading methods
         train_dataset = self.load_step_dataset(split='train')
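The RL entry point gets the same treatment for PPO hyperparameters. A sketch under the same assumptions; the values are common PPO defaults, not taken from this repository:

# Keys such as gamma or cliprange now pass through **kwargs and surface as "PPO config: {...}".
trainer.train_rl(
    learning_rate=5e-6,
    save_steps=500,
    eval_steps=500,
    gamma=0.99,       # discount factor (illustrative)
    lam=0.95,         # GAE lambda (illustrative)
    cliprange=0.2,    # PPO clipping range (illustrative)
    vf_coef=0.1,      # value-loss coefficient (illustrative)
    ent_coef=0.01,    # entropy-bonus coefficient (illustrative)
)

The shown hunks only log the kwargs; whatever forwards them into the actual PRM or PPO configuration happens later in the method bodies, outside this diff.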