File size: 766 Bytes
1146a67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from contextlib import contextmanager


@contextmanager
def skip_model_initialization(device=torch.device("meta")):
    """Temporarily make every registered parameter land on ``device``.

    While the context is active, ``torch.nn.Module.register_parameter`` is
    replaced on the class itself, so the patch applies to every module that
    registers a parameter during that time, not only modules created by the
    caller. Each non-``None`` parameter is first registered normally and then
    swapped for a new instance of the same class, built from the registered
    parameter moved to ``device``.

    The original ``register_parameter`` is restored on exit, including when
    the body of the ``with`` statement raises.

    Args:
        device: Target passed to ``Tensor.to`` for each registered parameter.
            Defaults to ``torch.device("meta")``.

    Yields:
        None.
    """
    original_register = torch.nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        # Let the stock implementation run first so its validation and
        # bookkeeping happen exactly as they would outside the context.
        original_register(module, name, param)
        if param is None:
            return

        registered = module._parameters[name]
        # Instance attributes stored on the parameter are forwarded as
        # constructor keyword arguments. Work on a copy: writing
        # "requires_grad" into the live ``__dict__`` would leave a stray
        # entry on the caller's original parameter object.
        ctor_kwargs = dict(registered.__dict__)
        ctor_kwargs["requires_grad"] = param.requires_grad
        module._parameters[name] = type(registered)(
            registered.to(device), **ctor_kwargs
        )

    torch.nn.Module.register_parameter = register_empty_parameter
    try:
        yield
    finally:
        # Always undo the class-level patch, even if the body raised.
        torch.nn.Module.register_parameter = original_register