| # Copyright (c) OpenMMLab. All rights reserved. | |
| from functools import wraps | |
| from operator import attrgetter | |
| from typing import List, Union | |
| import torch | |
| from torch.utils.checkpoint import checkpoint | |
def wrap_forward(forward):
    """Wrap a ``forward`` callable so that it runs under checkpointing.

    Args:
        forward: The callable to wrap, typically a module's bound
            ``forward`` method.

    Returns:
        A callable that passes its positional arguments to
        ``torch.utils.checkpoint.checkpoint(forward, *args)`` and returns
        the result. Only positional arguments are accepted.
    """

    # Preserve the wrapped callable's metadata (__name__, __doc__,
    # __wrapped__, ...) so introspection and debugging tools still see the
    # original ``forward`` rather than an anonymous ``wrapper``.
    @wraps(forward)
    def wrapper(*args):
        return checkpoint(forward, *args)

    return wrapper
def turn_on_activation_checkpointing(model: torch.nn.Module,
                                     modules: Union[List[str], str]):
    """Patch the ``forward`` of selected submodules to use checkpointing.

    Args:
        model: The root module whose attributes are looked up.
        modules: One attribute path or a list of attribute paths, each
            resolved relative to ``model`` with ``operator.attrgetter``
            (so dotted paths such as ``"backbone.layer1"`` are accepted).

    Each resolved object has its ``forward`` attribute replaced, in place,
    by ``wrap_forward(<its current forward>)``. Nothing is returned.
    """
    # Accept a bare string as shorthand for a one-element list.
    target_names = [modules] if isinstance(modules, str) else modules

    for dotted_path in target_names:
        submodule = attrgetter(dotted_path)(model)
        original_forward = submodule.forward
        # Assigning on the instance shadows the class-level ``forward``
        # for this particular object only.
        submodule.forward = wrap_forward(original_forward)