| |
| import torch |
|
|
| from detectron2.config import CfgNode |
| from detectron2.solver import build_lr_scheduler as build_d2_lr_scheduler |
|
|
| from .lr_scheduler import WarmupPolyLR |
|
|
|
|
def build_lr_scheduler(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Build a LR scheduler from config.

    The scheduler is chosen by ``cfg.SOLVER.LR_SCHEDULER_NAME``:

    * ``"WarmupPolyLR"`` is constructed locally from the ``cfg.SOLVER``
      settings: ``MAX_ITER``, ``WARMUP_FACTOR``, ``WARMUP_ITERS``,
      ``WARMUP_METHOD``, ``POLY_LR_POWER`` and ``POLY_LR_CONSTANT_ENDING``.
    * Any other name is forwarded unchanged to detectron2's
      ``build_lr_scheduler``.

    Args:
        cfg: config node; this function only reads fields under ``cfg.SOLVER``.
        optimizer: optimizer handed to the constructed scheduler.

    Returns:
        The scheduler instance built for ``optimizer``.
    """
    solver_cfg = cfg.SOLVER

    # Only the poly schedule is implemented in this project; all other
    # scheduler names are resolved by detectron2's own builder.
    if solver_cfg.LR_SCHEDULER_NAME != "WarmupPolyLR":
        return build_d2_lr_scheduler(cfg, optimizer)

    return WarmupPolyLR(
        optimizer,
        solver_cfg.MAX_ITER,
        warmup_factor=solver_cfg.WARMUP_FACTOR,
        warmup_iters=solver_cfg.WARMUP_ITERS,
        warmup_method=solver_cfg.WARMUP_METHOD,
        power=solver_cfg.POLY_LR_POWER,
        constant_ending=solver_cfg.POLY_LR_CONSTANT_ENDING,
    )
|
|