# Source path: Human_parser/head_extractor/src/mmengine/runner/activation_checkpointing.py
# (uploaded by codyshen via huggingface_hub, commit 6ed4a9c)
# Copyright (c) OpenMMLab. All rights reserved.
from functools import wraps
from operator import attrgetter
from typing import List, Union
import torch
from torch.utils.checkpoint import checkpoint
def wrap_forward(forward):
    """Return a checkpointed version of ``forward``.

    Calling the returned function passes ``forward`` and all positional
    arguments to :func:`torch.utils.checkpoint.checkpoint`. This means the
    call's intermediate activations are handled by torch's checkpointing
    mechanism instead of being kept by a plain call. ``functools.wraps``
    copies ``forward``'s metadata (``__name__``, ``__doc__``, ...) onto the
    result and exposes the original as ``__wrapped__``.

    Args:
        forward: The callable to checkpoint, typically a module's bound
            ``forward`` method.

    Returns:
        A callable that accepts positional arguments only.

    Note:
        Keyword arguments are not accepted. Calling the returned function
        with any keyword raises ``TypeError``.
    """

    @wraps(forward)
    def checkpointed_forward(*inputs):
        # Let torch drive (re)computation of the original forward.
        return checkpoint(forward, *inputs)

    return checkpointed_forward
def turn_on_activation_checkpointing(model: torch.nn.Module,
                                     modules: Union[List[str], str]):
    """Enable activation checkpointing on selected submodules of ``model``.

    Each name is resolved with :func:`operator.attrgetter`, so dotted paths
    such as ``'backbone.layer1'`` are supported. The resolved module's
    ``forward`` is replaced, in place as an instance attribute, with a
    checkpointed wrapper.

    Args:
        model: The root module whose submodules are looked up.
        modules: A single attribute path or a list of attribute paths.

    Raises:
        AttributeError: If a path does not resolve on ``model``. Modules
            listed before the failing path have already been wrapped.
    """
    if isinstance(modules, str):
        modules = [modules]
    # Track modules by identity so a module named twice, or reachable through
    # two aliased paths, is wrapped only once. Double wrapping would nest
    # checkpoints and recompute the same forward redundantly in backward.
    # ``id`` is stable here because ``model`` keeps every submodule alive.
    wrapped_ids = set()
    for module_name in modules:
        module = attrgetter(module_name)(model)
        if id(module) in wrapped_ids:
            continue
        wrapped_ids.add(id(module))
        module.forward = wrap_forward(module.forward)