# Copyright (c) OpenMMLab. All rights reserved.
from functools import wraps
from operator import attrgetter
from typing import List, Union
import torch
from torch.utils.checkpoint import checkpoint
def wrap_forward(forward):
    """Build a drop-in replacement for ``forward`` that executes it under
    activation checkpointing.

    The returned callable has the same metadata as ``forward`` (via
    :func:`functools.wraps`) and delegates every call to
    :func:`torch.utils.checkpoint.checkpoint`, which discards intermediate
    activations on the forward pass and recomputes them during backward to
    trade compute for memory.

    Args:
        forward (callable): The original (positional-argument) forward
            function of a module.

    Returns:
        callable: The checkpointed forward function.
    """

    def _checkpointed(*inputs):
        # checkpoint re-runs ``forward`` during backward instead of
        # storing its intermediate activations.
        return checkpoint(forward, *inputs)

    return wraps(forward)(_checkpointed)
def turn_on_activation_checkpointing(model: torch.nn.Module,
                                     modules: Union[List[str], str]):
    """Enable activation checkpointing on selected submodules of ``model``.

    Each named submodule (dotted paths such as ``'encoder.layer1'`` are
    supported) has its ``forward`` monkey-patched so that every call runs
    through :func:`torch.utils.checkpoint.checkpoint`, recomputing
    activations in the backward pass to save memory.

    Args:
        model (torch.nn.Module): The model containing the target submodules.
        modules (list[str] or str): Name(s) of the submodule(s) to patch;
            a single string is treated as a one-element list.
    """

    def _make_checkpointed(fwd):
        # Bind ``fwd`` per submodule so each wrapper closes over its own
        # original forward.
        @wraps(fwd)
        def inner(*args):
            return checkpoint(fwd, *args)

        return inner

    names = [modules] if isinstance(modules, str) else modules
    for name in names:
        submodule = attrgetter(name)(model)
        submodule.forward = _make_checkpointed(submodule.forward)