| """Contains utility functionality to modify torch modules. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| from torch import nn |
|
|
# All attributes of the ``torch.nn`` namespace whose name contains "Norm".
# The tuple form lets it be passed directly as the second argument of
# ``isinstance``.
NORM_LAYER_TYPES = tuple(
    layer_cls for attr_name, layer_cls in vars(nn).items() if "Norm" in attr_name
)

# Narrower selection: only the attributes whose name contains "BatchNorm".
BATCH_NORM_LAYER_TYPES = tuple(
    layer_cls for attr_name, layer_cls in vars(nn).items() if "BatchNorm" in attr_name
)
|
|
|
|
def freeze_norm_layer(module: nn.Module) -> nn.Module:
    """Freeze every normalization layer found inside ``module``.

    Each entry yielded by ``module.modules()`` that is an instance of one of
    ``NORM_LAYER_TYPES`` is modified in place:

    * ``requires_grad_(False)`` is called on it, and
    * a forward pre-hook is registered that calls ``eval()`` on the layer
      before each of its forward passes.

    Args:
        module: Root module to traverse.

    Returns:
        The same ``module`` object that was passed in.
    """

    def _force_eval(layer: nn.Module, _inputs: Any) -> None:
        # Invoked by torch right before ``layer.forward``; the positional
        # inputs are not needed, only the layer itself.
        layer.eval()

    for child in module.modules():
        if not isinstance(child, NORM_LAYER_TYPES):
            continue
        child.requires_grad_(False)
        # Re-assert eval mode on every forward call rather than once here,
        # so the layer is in eval mode whenever it actually runs.
        child.register_forward_pre_hook(_force_eval)

    return module
|
|