renpas22 committed on
Commit
917e40e
·
1 Parent(s): f8fc68a

Add **kwargs to train_prm and train_rl to accept config parameters

Browse files
Files changed (1) hide show
  1. src/reasoning/step_level_cot.py +8 -0
src/reasoning/step_level_cot.py CHANGED
@@ -353,6 +353,7 @@ class StepLevelCoTTrainer:
353
  learning_rate: float = 1e-5,
354
  save_steps: int = 500,
355
  eval_steps: int = 500,
 
356
  ) -> None:
357
  """
358
  Train Process Reward Model.
@@ -362,9 +363,12 @@ class StepLevelCoTTrainer:
362
  learning_rate: Learning rate for PRM training
363
  save_steps: Save checkpoint every N steps
364
  eval_steps: Evaluate every N steps
 
365
  """
366
  logger.info("Starting PRM training")
367
  logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
 
 
368
 
369
  # Load datasets using the trainer's dataset loading methods
370
  train_dataset = self.load_step_dataset(split='train')
@@ -449,6 +453,7 @@ class StepLevelCoTTrainer:
449
  learning_rate: float = 5e-6,
450
  save_steps: int = 500,
451
  eval_steps: int = 500,
 
452
  ) -> None:
453
  """
454
  Train policy with reinforcement learning.
@@ -458,9 +463,12 @@ class StepLevelCoTTrainer:
458
  learning_rate: Learning rate for RL training
459
  save_steps: Save checkpoint every N steps
460
  eval_steps: Evaluate every N steps
 
461
  """
462
  logger.info("Starting RL training")
463
  logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
 
 
464
 
465
  # Load dataset using the trainer's dataset loading methods
466
  train_dataset = self.load_step_dataset(split='train')
 
353
  learning_rate: float = 1e-5,
354
  save_steps: int = 500,
355
  eval_steps: int = 500,
356
+ **kwargs, # Accept additional config like hidden_dim, num_layers, dropout, reward_scale
357
  ) -> None:
358
  """
359
  Train Process Reward Model.
 
363
  learning_rate: Learning rate for PRM training
364
  save_steps: Save checkpoint every N steps
365
  eval_steps: Evaluate every N steps
366
+ **kwargs: Additional PRM configuration (hidden_dim, num_layers, dropout, reward_scale)
367
  """
368
  logger.info("Starting PRM training")
369
  logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
370
+ if kwargs:
371
+ logger.info(f"PRM config: {kwargs}")
372
 
373
  # Load datasets using the trainer's dataset loading methods
374
  train_dataset = self.load_step_dataset(split='train')
 
453
  learning_rate: float = 5e-6,
454
  save_steps: int = 500,
455
  eval_steps: int = 500,
456
+ **kwargs, # Accept additional config like gamma, lam, cliprange, etc.
457
  ) -> None:
458
  """
459
  Train policy with reinforcement learning.
 
463
  learning_rate: Learning rate for RL training
464
  save_steps: Save checkpoint every N steps
465
  eval_steps: Evaluate every N steps
466
+ **kwargs: Additional PPO configuration (gamma, lam, cliprange, vf_coef, ent_coef, etc.)
467
  """
468
  logger.info("Starting RL training")
469
  logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
470
+ if kwargs:
471
+ logger.info(f"PPO config: {kwargs}")
472
 
473
  # Load dataset using the trainer's dataset loading methods
474
  train_dataset = self.load_step_dataset(split='train')