renpas22 commited on
Commit
9e7779a
·
1 Parent(s): ccd696b

Convert learning_rate to float explicitly, so a string-valued config entry (e.g. "1e-5") does not break the AdamW optimizer and PRMTrainer setup

Browse files
Files changed (1) hide show
  1. src/reasoning/step_level_cot.py +2 -2
src/reasoning/step_level_cot.py CHANGED
@@ -468,7 +468,7 @@ class StepLevelCoTTrainer:
468
  # Setup optimizer
469
  optimizer = torch.optim.AdamW(
470
  self.model.parameters(),
471
- lr=learning_rate,
472
  weight_decay=getattr(self.config, 'weight_decay', 0.01),
473
  )
474
 
@@ -576,7 +576,7 @@ class StepLevelCoTTrainer:
576
  from .prm import PRMTrainer
577
  prm_trainer = PRMTrainer(
578
  model=self.prm,
579
- learning_rate=learning_rate,
580
  weight_decay=getattr(self.config, 'weight_decay', 0.01),
581
  warmup_steps=getattr(self.config, 'warmup_steps', 500),
582
  )
 
468
  # Setup optimizer
469
  optimizer = torch.optim.AdamW(
470
  self.model.parameters(),
471
+ lr=float(learning_rate),
472
  weight_decay=getattr(self.config, 'weight_decay', 0.01),
473
  )
474
 
 
576
  from .prm import PRMTrainer
577
  prm_trainer = PRMTrainer(
578
  model=self.prm,
579
+ learning_rate=float(learning_rate),
580
  weight_decay=getattr(self.config, 'weight_decay', 0.01),
581
  warmup_steps=getattr(self.config, 'warmup_steps', 500),
582
  )