renpas22 commited on
Commit ·
9e7779a
1
Parent(s): ccd696b
Convert learning_rate to float explicitly
Browse files
src/reasoning/step_level_cot.py
CHANGED
|
@@ -468,7 +468,7 @@ class StepLevelCoTTrainer:
|
|
| 468 |
# Setup optimizer
|
| 469 |
optimizer = torch.optim.AdamW(
|
| 470 |
self.model.parameters(),
|
| 471 |
-
lr=learning_rate,
|
| 472 |
weight_decay=getattr(self.config, 'weight_decay', 0.01),
|
| 473 |
)
|
| 474 |
|
|
@@ -576,7 +576,7 @@ class StepLevelCoTTrainer:
|
|
| 576 |
from .prm import PRMTrainer
|
| 577 |
prm_trainer = PRMTrainer(
|
| 578 |
model=self.prm,
|
| 579 |
-
learning_rate=learning_rate,
|
| 580 |
weight_decay=getattr(self.config, 'weight_decay', 0.01),
|
| 581 |
warmup_steps=getattr(self.config, 'warmup_steps', 500),
|
| 582 |
)
|
|
|
|
| 468 |
# Setup optimizer
|
| 469 |
optimizer = torch.optim.AdamW(
|
| 470 |
self.model.parameters(),
|
| 471 |
+
lr=float(learning_rate),
|
| 472 |
weight_decay=getattr(self.config, 'weight_decay', 0.01),
|
| 473 |
)
|
| 474 |
|
|
|
|
| 576 |
from .prm import PRMTrainer
|
| 577 |
prm_trainer = PRMTrainer(
|
| 578 |
model=self.prm,
|
| 579 |
+
learning_rate=float(learning_rate),
|
| 580 |
weight_decay=getattr(self.config, 'weight_decay', 0.01),
|
| 581 |
warmup_steps=getattr(self.config, 'warmup_steps', 500),
|
| 582 |
)
|