| from collections import namedtuple
|
|
|
| import torch
|
| from modules import devices, shared
|
|
|
# The tracked module currently resident on the compute device (None when there is none).
# Updated by the forward pre-hooks that setup_for_low_vram installs; cleared by
# send_everything_to_cpu.
module_in_gpu = None

cpu = torch.device("cpu")

# Pairs a module to hook with the object that should actually be moved when the hook fires.
# NOTE: the default must be the None object, not the string 'None': a (truthy) string would be
# registered as a parent, and the hook would then try to call .to() on it.
ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=[None])
|
|
|
def send_everything_to_cpu():
    """Return the currently resident tracked module (if any) to the CPU.

    After this call no tracked module is recorded as living on the compute device.
    """
    global module_in_gpu

    resident = module_in_gpu
    if resident is not None:
        resident.to(cpu)

    module_in_gpu = None
|
|
|
|
|
def is_needed(sd_model):
    """Return a truthy value when CPU/GPU module swapping should be used for ``sd_model``.

    Swapping is needed if ``cmd_opts.lowvram`` or ``cmd_opts.medvram`` is set. It is also needed
    if ``cmd_opts.medvram_sdxl`` is set and the model has a ``conditioner`` attribute, which this
    module uses to recognise SDXL models.
    """
    opts = shared.cmd_opts
    return opts.lowvram or opts.medvram or (opts.medvram_sdxl and hasattr(sd_model, 'conditioner'))
|
|
|
|
|
def apply(sd_model):
    """Turn low-VRAM module swapping on or off for ``sd_model`` according to the command-line options.

    Parallel processing is allowed only when swapping is not needed. When swapping is needed,
    the hooks are installed by ``setup_for_low_vram``. ``sd_model.model`` is swapped as a single
    unit unless ``cmd_opts.lowvram`` is set. When swapping is not needed, the model's
    ``lowvram`` flag is cleared.
    """
    swapping = is_needed(sd_model)
    shared.parallel_processing_allowed = not swapping

    if not swapping:
        sd_model.lowvram = False
        return

    whole_model_swap = not shared.cmd_opts.lowvram
    setup_for_low_vram(sd_model, whole_model_swap)
|
|
|
|
|
def _to_device_excluding(root, held_out, device):
    """Move ``root`` to ``device`` while the ``(obj, field)`` attributes in ``held_out`` are detached.

    Each listed attribute is read (``None`` if absent) and set to ``None``, so that
    ``root.to()`` does not move it. It is then put back. Restoration happens in a ``finally``,
    so a failed transfer (for example an out-of-memory error) cannot leave the model with those
    submodules permanently replaced by ``None``. Attributes are restored in reverse order,
    which also yields the right result if the same ``(obj, field)`` pair appears twice.
    """
    # Materialize: the pairs are walked twice (detach, then restore), so a one-shot
    # iterator such as a generator would otherwise silently skip the restore step.
    held_out = list(held_out)
    stashed = []
    try:
        for obj, field in held_out:
            stashed.append(getattr(obj, field, None))
            setattr(obj, field, None)

        root.to(device)
    finally:
        # zip() truncates to what was actually stashed if detaching failed part-way.
        for (obj, field), module in reversed(list(zip(held_out, stashed))):
            setattr(obj, field, module)


def setup_for_low_vram(sd_model, use_medvram):
    """Keep the large submodules of ``sd_model`` on the CPU and swap them in on demand.

    Swappable submodules are:
      - ``first_stage_model``;
      - the text-conditioning module (``conditioner``, ``cond_stage_model.model``,
        ``cond_stage_model.transformer``, or the fields from ``medvram_fields()``);
      - the optional ``depth_model`` / ``embedder``;
      - ``model``.

    These stay on the CPU while the rest of ``sd_model`` is moved to ``devices.device``. Each
    swappable part gets a forward pre-hook that moves it (or its registered parent) to
    ``devices.device``, after first sending the previously resident tracked module back to the
    CPU. As a result, at most one tracked module is on the device at a time.

    Args:
        sd_model: the loaded model. It is marked with ``lowvram = True``; calling this again on
            a model already marked is a no-op.
        use_medvram: if true, ``sd_model.model`` is swapped as one unit. Otherwise its
            ``diffusion_model`` is swapped piecewise: ``time_embed``, each input block, the
            middle block, and each output block.
    """
    if getattr(sd_model, 'lowvram', False):
        return

    # Set before anything else so that a repeated call returns early above.
    sd_model.lowvram = True

    # Maps a hooked submodule to the object that should actually be moved when its hook fires
    # (e.g. an inner text-encoder module -> its owning cond_stage_model).
    parents = {}

    def send_me_to_gpu(module, _):
        """Forward pre-hook: move ``module`` (or its registered parent) to ``devices.device``.

        Whatever tracked module was resident before is sent back to the CPU first, so only one
        of the modules carrying this hook occupies the device at a time. The second argument
        (the forward inputs) is ignored.
        """
        global module_in_gpu

        module = parents.get(module, module)

        if module_in_gpu is module:
            return

        if module_in_gpu is not None:
            module_in_gpu.to(cpu)

        module.to(devices.device)
        module_in_gpu = module

    first_stage_model = sd_model.first_stage_model
    first_stage_model_encode = sd_model.first_stage_model.encode
    first_stage_model_decode = sd_model.first_stage_model.decode

    # Forward pre-hooks fire only when a module is called, not when encode()/decode() are
    # invoked directly, so those two methods are wrapped to trigger the swap themselves.
    def first_stage_model_encode_wrap(x, *args, **kwargs):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_encode(x, *args, **kwargs)

    def first_stage_model_decode_wrap(z, *args, **kwargs):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_decode(z, *args, **kwargs)

    # (obj, field) pairs that must NOT be moved by the bulk sd_model.to() below.
    to_remain_in_cpu = [
        (sd_model, 'first_stage_model'),
        (sd_model, 'depth_model'),
        (sd_model, 'embedder'),
        (sd_model, 'model'),
    ]

    is_sdxl = hasattr(sd_model, 'conditioner')
    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')

    if hasattr(sd_model, 'medvram_fields'):
        # The model provides its own complete list, replacing the defaults above.
        to_remain_in_cpu = sd_model.medvram_fields()
    elif is_sdxl:
        to_remain_in_cpu.append((sd_model, 'conditioner'))
    elif is_sd2:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
    else:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))

    # Move everything except the swappable submodules to the compute device.
    _to_device_excluding(sd_model, to_remain_in_cpu, devices.device)

    # Hook the text-conditioning module(s).
    if hasattr(sd_model, "cond_stage_model") and hasattr(sd_model.cond_stage_model, "medvram_modules"):
        for entry in sd_model.cond_stage_model.medvram_modules():
            if isinstance(entry, ModuleWithParent):
                module, parent = entry.module, entry.parent
            else:
                module, parent = entry, None

            # Skip missing (None) entries.
            if module is None:
                continue

            module.register_forward_pre_hook(send_me_to_gpu)
            if parent is not None:
                parents[module] = parent

    elif is_sdxl:
        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
    elif is_sd2:
        # token_embedding is hooked separately as well; presumably it can be called outside
        # model's own forward -- TODO confirm. Both map to cond_stage_model so it moves as a whole.
        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
        sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
        parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
        parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
    else:
        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    # Hook the first-stage model (both via forward and via encode/decode) and optional extras.
    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
    if getattr(sd_model, 'depth_model', None) is not None:
        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
    if getattr(sd_model, 'embedder', None) is not None:
        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)

    if use_medvram:
        # Swap the whole denoising model as a single unit.
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
        return

    diff_model = sd_model.model.diffusion_model

    # Move sd_model.model (and the small parts of diffusion_model) to the device, but keep the
    # large block containers on the CPU; they are swapped in one by one through their hooks.
    block_fields = ('input_blocks', 'middle_block', 'output_blocks', 'time_embed')
    _to_device_excluding(sd_model.model, [(diff_model, name) for name in block_fields], devices.device)

    diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
    for block in diff_model.input_blocks:
        block.register_forward_pre_hook(send_me_to_gpu)
    diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
    for block in diff_model.output_blocks:
        block.register_forward_pre_hook(send_me_to_gpu)
|
|
|
|
|
def is_enabled(sd_model):
    """Return whether low-VRAM swapping is active for ``sd_model``.

    Reads the ``lowvram`` flag written by ``apply`` / ``setup_for_low_vram``. A model that never
    went through them has no such attribute. It is reported as not enabled instead of raising
    ``AttributeError``, matching the ``getattr`` guard used by ``setup_for_low_vram``.
    """
    return getattr(sd_model, 'lowvram', False)
|
|
|