| |
| |
| |
| |
| |
|
|
| import hashlib |
| import math |
| import os |
| from collections import defaultdict |
| from io import BytesIO |
| from typing import List, Optional, Type, Union |
|
|
| import safetensors.torch |
| import torch |
| import torch.utils.checkpoint |
| from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear |
| from safetensors.torch import load_file |
| from transformers import T5EncoderModel |
|
|
|
|
class LoRAModule(torch.nn.Module):
    """
    Low-rank adapter that hooks into an existing Linear / Conv2d module.

    Replaces the forward method of the original module instead of replacing
    the module itself: after ``apply_to()`` the wrapped module returns its own
    output plus ``lora_up(lora_down(x)) * multiplier * scale``.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """Build the down/up projection pair for ``org_module``.

        If alpha == 0 or None, alpha is rank (no scaling).

        Args:
            lora_name: Name under which this adapter is registered / saved.
            org_module: Module to adapt. Conv2d (or a subclass of it) gets a
                convolutional adapter; anything else is treated as Linear and
                must expose ``in_features`` / ``out_features``.
            multiplier: Global strength applied to the adapter output.
            lora_dim: Rank of the low-rank decomposition.
            alpha: Scaling numerator; the effective scale is ``alpha / lora_dim``.
            dropout: Probability of element dropout on the down projection.
            rank_dropout: Probability of dropping whole rank channels.
            module_dropout: Probability of skipping the adapter entirely.
        """
        super().__init__()
        self.lora_name = lora_name

        # isinstance (rather than an exact class-name match) also covers Conv2d
        # subclasses; the name check is kept for look-alike classes.
        is_conv2d = isinstance(org_module, torch.nn.Conv2d) or org_module.__class__.__name__ == "Conv2d"

        self.lora_dim = lora_dim
        if is_conv2d:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            # The down projection mirrors the original geometry; up is 1x1.
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            # .cpu() so that an alpha tensor living on an accelerator can be converted.
            alpha = alpha.detach().float().cpu().numpy()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        # Stored as a buffer so alpha is part of the state dict.
        self.register_buffer("alpha", torch.tensor(alpha))

        # Up projection starts at zero, so a fresh adapter leaves the output unchanged.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        """Redirect the original module's ``forward`` to this adapter's ``forward``."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        # Drop the reference so the original module is not a submodule of the adapter.
        del self.org_module

    def forward(self, x, *args, **kwargs):
        """Return the original output plus the scaled low-rank update.

        Extra positional / keyword arguments are accepted but not forwarded to
        the original forward.
        """
        weight_dtype = x.dtype
        org_forwarded = self.org_forward(x)

        # Module dropout: occasionally skip the adapter altogether while training.
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x.to(self.lora_down.weight.dtype))

        # Element dropout on the low-rank activations.
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # Rank dropout: zero whole rank channels per sample.
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)
            lx = lx * mask

            # Rescale to compensate for the dropped ranks.
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
|
|
|
|
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files.

    Hashes up to 0x10000 bytes read from offset 0x100000 of the seekable
    binary stream ``b`` and returns the first 8 hex characters of the SHA-256.
    """
    b.seek(0x100000)
    window = b.read(0x10000)
    full_digest = hashlib.sha256(window).hexdigest()
    return full_digest[:8]
|
|
|
|
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files.

    Reads the first 8 bytes of the seekable binary stream ``b`` as a
    little-endian length ``n``, skips those ``n`` bytes, and returns the full
    SHA-256 hex digest of everything that follows.
    """
    block_size = 1024 * 1024
    digest = hashlib.sha256()

    b.seek(0)
    header_len = int.from_bytes(b.read(8), "little")

    # Hash only the data after the 8-byte length field and the header it describes.
    b.seek(header_len + 8)
    while True:
        block = b.read(block_size)
        if block == b"":
            break
        digest.update(block)

    return digest.hexdigest()
|
|
|
|
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Returns a ``(model_hash, legacy_hash)`` tuple computed over an in-memory
    safetensors serialization of ``tensors``.
    """
    # Only metadata entries whose key starts with "ss_" take part in the
    # serialization that gets hashed; every other entry is left out.
    hashed_metadata = {}
    for key, value in metadata.items():
        if key.startswith("ss_"):
            hashed_metadata[key] = value

    serialized = safetensors.torch.save(tensors, hashed_metadata)
    buffer = BytesIO(serialized)

    return addnet_hash_safetensors(buffer), addnet_hash_legacy(buffer)
|
|
|
|
class LoRANetwork(torch.nn.Module):
    """Set of LoRA modules created for a text encoder and a transformer.

    Modules whose class name is listed in the ``*_TARGET_REPLACE_MODULE``
    constants are scanned, and every Linear / 1x1 Conv2d found below them gets
    one adapter of ``module_class``.
    """

    TRANSFORMER_TARGET_REPLACE_MODULE = [
        "CogVideoXTransformer3DModel", "WanTransformer3DModel",
        "Wan2_2Transformer3DModel", "FluxTransformer2DModel", "QwenImageTransformer2DModel"
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"]
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(
        self,
        text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
        skip_name: str = None,
        varbose: Optional[bool] = False,
    ) -> None:
        """Create (but do not yet apply) the LoRA modules.

        Args:
            text_encoder: One text encoder, or a list/tuple of them; ``None``
                entries are ignored.
            unet: Root module searched for transformer targets.
            multiplier: Strength handed to every adapter.
            lora_dim: Rank used for every adapter.
            alpha: Alpha used for every adapter.
            dropout: Element dropout probability handed to every adapter.
            module_class: Adapter class to instantiate.
            skip_name: Children whose name contains this substring get no adapter.
            varbose: Accepted for compatibility; not used here.

        Raises:
            AssertionError: If two adapters end up with the same name.
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout

        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
        print(f"neuron dropout: p={self.dropout}")

        def create_modules(
            is_unet: bool,
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
        ):
            """Return ``(loras, skipped_names)`` for every eligible child of the target modules."""
            prefix = (
                self.LORA_PREFIX_TRANSFORMER
                if is_unet
                else self.LORA_PREFIX_TEXT_ENCODER
            )
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    child_cls = child_module.__class__.__name__
                    is_linear = child_cls in ("Linear", "LoRACompatibleLinear")
                    is_conv2d = child_cls in ("Conv2d", "LoRACompatibleConv")
                    is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                    if skip_name is not None and skip_name in child_name:
                        continue
                    if not (is_linear or is_conv2d):
                        continue

                    # Adapter names are the dotted module path with "." turned into "_".
                    lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")

                    dim = None
                    alpha = None
                    if is_linear or is_conv2d_1x1:
                        dim = self.lora_dim
                        alpha = self.alpha

                    # Convolutions that are not 1x1 get no rank and are passed over silently.
                    if dim is None or dim == 0:
                        if is_linear or is_conv2d_1x1:
                            skipped.append(lora_name)
                        continue

                    lora = module_class(
                        lora_name,
                        child_module,
                        self.multiplier,
                        dim,
                        alpha,
                        dropout=dropout,
                    )
                    loras.append(lora)
            return loras, skipped

        # Accept tuples as well as lists; anything else is a single encoder.
        text_encoders = text_encoder if isinstance(text_encoder, (list, tuple)) else [text_encoder]

        self.text_encoder_loras = []
        for encoder in text_encoders:
            if encoder is not None:
                encoder_loras, _ = create_modules(False, encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
                self.text_encoder_loras.extend(encoder_loras)
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras, _ = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # Names become submodule / state-dict keys, so they must be unique.
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        """Hook the created adapters into their modules and register them as submodules.

        ``text_encoder`` and ``unet`` are not used by this method. Passing
        ``False`` for a flag discards the adapters of that part.
        """
        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    def set_multiplier(self, multiplier):
        """Set the strength on the network and on every adapter."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Load adapter weights from ``file`` (non-strict) and return the load result."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            # NOTE(security): torch.load may unpickle arbitrary objects; only
            # use non-safetensors checkpoints from trusted sources.
            weights_sd = torch.load(file, map_location="cpu")
        info = self.load_state_dict(weights_sd, False)
        return info

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """Return optimizer parameter groups for the text-encoder and transformer adapters.

        A group only carries an ``"lr"`` entry when the matching learning rate
        is not ``None``. ``default_lr`` is accepted but not used here.
        """
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        """No-op."""
        pass

    def get_trainable_params(self):
        """Return all parameters of the network."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save the adapter state dict to ``file``.

        Args:
            file: Target path; a ``.safetensors`` suffix selects safetensors,
                anything else uses ``torch.save``.
            dtype: If not ``None``, tensors are copied to CPU and cast to it.
            metadata: Optional dict stored with safetensors files. The
                ``sshs_*`` hash entries are added to this dict in place.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype)

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            # Precalculate the hashes and store them alongside the weights.
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
|
|
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
    transformer,
    neuron_dropout: Optional[float] = None,
    skip_name: str = None,
    **kwargs,
):
    """Build a ``LoRANetwork`` for the given text encoder(s) and transformer.

    ``network_dim`` falls back to 4 and ``network_alpha`` to 1.0 when they are
    ``None``. Additional keyword arguments are accepted and ignored.
    """
    rank = 4 if network_dim is None else network_dim
    alpha = 1.0 if network_alpha is None else network_alpha

    return LoRANetwork(
        text_encoder,
        transformer,
        multiplier=multiplier,
        lora_dim=rank,
        alpha=alpha,
        dropout=neuron_dropout,
        skip_name=skip_name,
        varbose=True,
    )
|
|
def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False, sub_transformer_name="transformer"):
    """Merge LoRA weights into the matching layers of a pipeline, in place.

    For every layer with both ``lora_up.weight`` and ``lora_down.weight`` the
    layer weight is increased by ``multiplier * alpha * (up @ down)``, where
    ``alpha`` is the stored alpha divided by the rank (1.0 when none is stored).

    Args:
        pipeline: Object exposing ``transformer``, ``text_encoder`` and the
            attribute named by ``sub_transformer_name``.
        lora_path: safetensors file to load when ``state_dict`` is ``None``.
        multiplier: Strength of the merge.
        device: Device on which the update is computed.
        dtype: Dtype in which the update is computed.
        state_dict: Already-loaded LoRA tensors; takes precedence over ``lora_path``.
        transformer_only: Skip every ``lora_te`` (text encoder) layer.
        sub_transformer_name: Pipeline attribute that receives non-text-encoder layers.

    Returns:
        The same ``pipeline`` object.
    """
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    if state_dict is None:
        state_dict = load_file(lora_path)

    # Group tensors per target layer: updates[layer_name][element_name] = tensor.
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        # Normalize "diffusion_model."-style keys to the "lora_unet__..." naming.
        if "diffusion_model" in key:
            key = key.replace("diffusion_model.", "lora_unet__")
            key = key.replace("blocks.", "blocks_")
            key = key.replace(".self_attn.", "_self_attn_")
            key = key.replace(".cross_attn.", "_cross_attn_")
            key = key.replace(".ffn.", "_ffn_")
        # Normalize "lora_A" / "lora_B" keys: A maps to lora_down, B to lora_up.
        if "lora_A" in key or "lora_B" in key:
            key = "lora_unet__" + key
            key = key.replace("blocks.", "blocks_")
            key = key.replace(".self_attn.", "_self_attn_")
            key = key.replace(".cross_attn.", "_cross_attn_")
            key = key.replace(".ffn.", "_ffn_")
            key = key.replace(".lora_A.default.", ".lora_down.")
            key = key.replace(".lora_B.default.", ".lora_up.")
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    # A transformer on the meta device is taken as the sign that sequential CPU
    # offload is active: hooks are removed for the merge and restored afterwards.
    sequential_cpu_offload_flag = False
    if pipeline.transformer.device == torch.device(type="meta"):
        pipeline.remove_all_hooks()
        sequential_cpu_offload_flag = True
        offload_device = pipeline._offload_device

    for layer, elems in updates.items():
        is_text_encoder = "lora_te" in layer
        if is_text_encoder and transformer_only:
            continue

        if is_text_encoder:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
            curr_layer = getattr(pipeline, sub_transformer_name)

        # "_" is both the path separator and a legal character inside attribute
        # names, so the module path is recovered by trial and error.
        try:
            curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
        except Exception:
            # Front search: grow the attribute name token by token from the left.
            temp_name = layer_infos.pop(0)
            try:
                # Ends via ``break`` or via the IndexError of pop() on an empty list.
                while True:
                    try:
                        curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
                        break
                    except Exception:
                        try:
                            curr_layer = curr_layer.__getattr__(temp_name)
                            if len(layer_infos) > 0:
                                temp_name = layer_infos.pop(0)
                            elif len(layer_infos) == 0:
                                break
                        except Exception:
                            if len(layer_infos) == 0:
                                print(f'Error loading layer in front search: {layer}. Try it in back search.')
                            if len(temp_name) > 0:
                                temp_name += "_" + layer_infos.pop(0)
                            else:
                                temp_name = layer_infos.pop(0)
            except Exception:
                # Back search: restart from the root and try the longest
                # remaining suffix first, shrinking it on every failure.
                if is_text_encoder:
                    layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
                    curr_layer = pipeline.text_encoder
                else:
                    layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
                    curr_layer = getattr(pipeline, sub_transformer_name)

                len_layer_infos = len(layer_infos)
                start_index = 0 if len_layer_infos >= 1 and len(layer_infos[0]) > 0 else 1
                end_indx = len_layer_infos

                error_flag = False if len_layer_infos >= 1 else True
                while start_index < len_layer_infos:
                    try:
                        if start_index >= end_indx:
                            print(f'Error loading layer in back search: {layer}')
                            error_flag = True
                            break
                        curr_layer = curr_layer.__getattr__("_".join(layer_infos[start_index:end_indx]))
                        start_index = end_indx
                        end_indx = len_layer_infos
                    except Exception:
                        end_indx -= 1
                if error_flag:
                    continue

        # Only modules that own a weight can be merged into.
        if not hasattr(curr_layer, "weight"):
            continue

        # Skip incomplete entries BEFORE touching the layer: casting it to
        # ``dtype`` and back would needlessly (and possibly lossily) rewrite
        # weights that are not being merged.
        if 'lora_up.weight' not in elems or 'lora_down.weight' not in elems:
            continue

        origin_dtype = curr_layer.weight.data.dtype
        origin_device = curr_layer.weight.data.device

        curr_layer = curr_layer.to(device, dtype)
        weight_up = elems['lora_up.weight'].to(device, dtype)
        weight_down = elems['lora_down.weight'].to(device, dtype)

        # weight_up.shape[1] is the LoRA rank.
        if 'alpha' in elems:
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0

        if len(weight_up.shape) == 4:
            # 4-D weights: the two trailing dims are squeezed away, so this
            # path assumes they have size 1 (1x1 kernels).
            curr_layer.weight.data += multiplier * alpha * torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
        curr_layer = curr_layer.to(origin_device, origin_dtype)

    if sequential_cpu_offload_flag:
        pipeline.enable_sequential_cpu_offload(device=offload_device)
    return pipeline
|
|
| |
def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32, sub_transformer_name="transformer", state_dict=None, transformer_only=False):
    """Unmerge state_dict in LoRANetwork from the pipeline in diffusers.

    Inverse of a merge: for every layer with both ``lora_up.weight`` and
    ``lora_down.weight`` the layer weight is decreased by
    ``multiplier * alpha * (up @ down)``, where ``alpha`` is the stored alpha
    divided by the rank (1.0 when none is stored).

    Args:
        pipeline: Object exposing ``transformer``, ``text_encoder`` and the
            attribute named by ``sub_transformer_name``.
        lora_path: safetensors file to load when ``state_dict`` is ``None``.
        multiplier: Strength that was used for the merge.
        device: Device on which the update is computed.
        dtype: Dtype in which the update is computed.
        sub_transformer_name: Pipeline attribute that holds non-text-encoder layers.
        state_dict: Already-loaded LoRA tensors; takes precedence over ``lora_path``.
        transformer_only: Skip every ``lora_te`` (text encoder) layer.

    Returns:
        The same ``pipeline`` object.
    """
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    if state_dict is None:
        state_dict = load_file(lora_path)

    # Group tensors per target layer: updates[layer_name][element_name] = tensor.
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        # Normalize "diffusion_model."-style keys to the "lora_unet__..." naming.
        if "diffusion_model" in key:
            key = key.replace("diffusion_model.", "lora_unet__")
            key = key.replace("blocks.", "blocks_")
            key = key.replace(".self_attn.", "_self_attn_")
            key = key.replace(".cross_attn.", "_cross_attn_")
            key = key.replace(".ffn.", "_ffn_")
        # Normalize "lora_A" / "lora_B" keys: A maps to lora_down, B to lora_up.
        if "lora_A" in key or "lora_B" in key:
            key = "lora_unet__" + key
            key = key.replace("blocks.", "blocks_")
            key = key.replace(".self_attn.", "_self_attn_")
            key = key.replace(".cross_attn.", "_cross_attn_")
            key = key.replace(".ffn.", "_ffn_")
            key = key.replace(".lora_A.default.", ".lora_down.")
            key = key.replace(".lora_B.default.", ".lora_up.")
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    # A transformer on the meta device is taken as the sign that sequential CPU
    # offload is active: hooks are removed for the update and restored afterwards.
    sequential_cpu_offload_flag = False
    offload_device = device
    if pipeline.transformer.device == torch.device(type="meta"):
        pipeline.remove_all_hooks()
        sequential_cpu_offload_flag = True
        # Restore offloading onto the pipeline's own offload device, not onto
        # the (by default CPU) device used for the weight arithmetic.
        offload_device = getattr(pipeline, "_offload_device", device)

    for layer, elems in updates.items():
        is_text_encoder = "lora_te" in layer
        if is_text_encoder and transformer_only:
            continue

        if is_text_encoder:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = getattr(pipeline, sub_transformer_name)

        # "_" is both the path separator and a legal character inside attribute
        # names, so the module path is recovered by trial and error.
        try:
            curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
        except Exception:
            # Front search: grow the attribute name token by token from the left.
            temp_name = layer_infos.pop(0)
            try:
                # Ends via ``break`` or via the IndexError of pop() on an empty list.
                while True:
                    try:
                        curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
                        break
                    except Exception:
                        try:
                            curr_layer = curr_layer.__getattr__(temp_name)
                            if len(layer_infos) > 0:
                                temp_name = layer_infos.pop(0)
                            elif len(layer_infos) == 0:
                                break
                        except Exception:
                            if len(layer_infos) == 0:
                                print(f'Error loading layer in front search: {layer}. Try it in back search.')
                            if len(temp_name) > 0:
                                temp_name += "_" + layer_infos.pop(0)
                            else:
                                temp_name = layer_infos.pop(0)
            except Exception:
                # Back search: restart from the root and try the longest
                # remaining suffix first, shrinking it on every failure.
                if is_text_encoder:
                    layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
                    curr_layer = pipeline.text_encoder
                else:
                    layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
                    curr_layer = getattr(pipeline, sub_transformer_name)
                len_layer_infos = len(layer_infos)

                start_index = 0 if len_layer_infos >= 1 and len(layer_infos[0]) > 0 else 1
                end_indx = len_layer_infos

                error_flag = False if len_layer_infos >= 1 else True
                while start_index < len_layer_infos:
                    try:
                        if start_index >= end_indx:
                            print(f'Error loading layer in back search: {layer}')
                            error_flag = True
                            break
                        curr_layer = curr_layer.__getattr__("_".join(layer_infos[start_index:end_indx]))
                        start_index = end_indx
                        end_indx = len_layer_infos
                    except Exception:
                        end_indx -= 1
                if error_flag:
                    continue

        # Only modules that own a weight can be updated.
        if not hasattr(curr_layer, "weight"):
            continue

        # Skip incomplete entries BEFORE touching the layer: casting it to
        # ``dtype`` and back would needlessly (and possibly lossily) rewrite
        # weights that are not being changed.
        if 'lora_up.weight' not in elems or 'lora_down.weight' not in elems:
            continue

        origin_dtype = curr_layer.weight.data.dtype
        origin_device = curr_layer.weight.data.device

        curr_layer = curr_layer.to(device, dtype)
        weight_up = elems['lora_up.weight'].to(device, dtype)
        weight_down = elems['lora_down.weight'].to(device, dtype)

        # weight_up.shape[1] is the LoRA rank.
        if 'alpha' in elems:
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0

        if len(weight_up.shape) == 4:
            # 4-D weights: the two trailing dims are squeezed away, so this
            # path assumes they have size 1 (1x1 kernels).
            curr_layer.weight.data -= multiplier * alpha * torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
        curr_layer = curr_layer.to(origin_device, origin_dtype)

    if sequential_cpu_offload_flag:
        pipeline.enable_sequential_cpu_offload(device=offload_device)
    return pipeline
|
|