import hashlib
import os
import re
from io import BytesIO

import safetensors.torch
from safetensors import safe_open
from safetensors.torch import save_file

re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    },
}
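
# Illustrative lookups (derived from the table above): resnet suffixes are
# renamed, anything else falls through unchanged.
#   suffix_conversion["resnets"].get("conv1", "conv1")        -> "in_layers_2"
#   suffix_conversion["attentions"].get("proj_in", "proj_in") -> "proj_in"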

def convert_diffusers_name_to_compvis(key, is_sd2):
    """Convert a diffusers-style LoRA key to the CompVis naming scheme
    used by Stable Diffusion checkpoints."""
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f"diffusion_model_input_blocks_0_0{m[0]}"

    if match(m, r"lora_unet_conv_out(.*)"):
        return f"diffusion_model_out_2{m[0]}"

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
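
# A minimal sketch of the mapping this function performs; the keys below are
# illustrative diffusers-style LoRA module names, not taken from a file:
#
#   convert_diffusers_name_to_compvis(
#       "lora_unet_down_blocks_0_attentions_0_proj_in", is_sd2=False)
#   -> "diffusion_model_input_blocks_1_1_proj_in"
#
#   convert_diffusers_name_to_compvis(
#       "lora_unet_up_blocks_1_resnets_0_conv1", is_sd2=False)
#   -> "diffusion_model_output_blocks_3_0_in_layers_2"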

def safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks
    to save time on indexing the model later."""
    # Keep only training metadata ("ss_" keys) so that user-added metadata
    # cannot change the resulting hashes.
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    data = safetensors.torch.save(tensors, metadata)
    b = BytesIO(data)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
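
# Hedged usage sketch: the returned hashes can be written back into the
# metadata before saving, so sd-webui-additional-networks can skip hashing
# at index time. The "sshs_*" key names follow the kohya-ss sd-scripts
# convention and are an assumption here, not defined in this file:
#
#   model_hash, legacy_hash = safetensors_hashes(tensors, metadata)
#   metadata["sshs_model_hash"] = model_hash
#   metadata["sshs_legacy_hash"] = legacy_hash
#   save_file(tensors, "out.safetensors", metadata)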

def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors files."""
    m = hashlib.sha256()

    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]
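
# Note: the legacy scheme hashes the 0x10000 bytes found at file offset
# 0x100000, matching the short hash historically shown by the webui. For
# files smaller than 0x100000 bytes the read returns nothing, so the
# 8-character digest is only meaningful for sufficiently large models.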

def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors files."""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")

    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()
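
# Layout being skipped (safetensors file format, for reference):
#   bytes 0..8   : header length n as a little-endian uint64
#   bytes 8..8+n : JSON header (tensor index and "__metadata__")
#   bytes 8+n..  : raw tensor data -- the only part hashed above,
#                  so edits to the metadata never change this hash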

def lbw_lora(input_, output, ratios):
    """Scale each block of a LoRA by its block weight (LBW) and save the result."""
    print("Apply LBW")

    assert isinstance(input_, str)
    assert isinstance(output, str)
    assert isinstance(ratios, str)
    assert os.path.exists(input_), f"{input_} does not exist"
    assert not os.path.exists(output), f"{output} already exists"

    LOAD_PATH = input_
    SAVE_PATH = output
    RATIOS = [float(x) for x in ratios.split(",")]
    LAYERS = len(RATIOS)
    assert LAYERS in [17, 26]

    BLOCKID17 = [
        "BASE", "IN01", "IN02", "IN04", "IN05", "IN07", "IN08", "M00",
        "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08", "OUT09", "OUT10", "OUT11"]
    BLOCKID26 = [
        "BASE", "IN00", "IN01", "IN02", "IN03", "IN04", "IN05", "IN06", "IN07", "IN08", "IN09", "IN10", "IN11", "M00",
        "OUT00", "OUT01", "OUT02", "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08", "OUT09", "OUT10", "OUT11"]

    if LAYERS == 17:
        RATIO_OF_ = dict(zip(BLOCKID17, RATIOS))
    if LAYERS == 26:
        RATIO_OF_ = dict(zip(BLOCKID26, RATIOS))
    print(RATIO_OF_)

    PATTERNS = [
        r"^transformer_text_model_(encoder)_layers_(\d+)_.*",
        r"^diffusion_model_(in)put_blocks_(\d+)_.*",
        r"^diffusion_model_(middle)_block_(\d+)_.*",
        r"^diffusion_model_(out)put_blocks_(\d+)_.*"]

    def replacement(match):
        g1 = str(match.group(1))
        g2 = int(match.group(2))
        assert g1 in ["encoder", "in", "middle", "out"]

        if g1 == "encoder":
            return "BASE"
        if g1 == "middle":
            return "M00"
        return f"{g1.upper()}{g2:02}"

    def compvis_name_to_blockid(compvis_name):
        # e.g. "diffusion_model_input_blocks_1_1_proj_in"        -> "IN01"
        #      "transformer_text_model_encoder_layers_0_mlp_fc1" -> "BASE"
        strings = compvis_name
        for pattern in PATTERNS:
            strings = re.sub(pattern, replacement, strings)
            if strings != compvis_name:
                break
        assert strings != compvis_name
        blockid = strings

        if LAYERS == 17:
            assert blockid in BLOCKID26, f"Incorrect layer {blockid}"
            assert blockid in BLOCKID17, f"{blockid} is not among the 17 layers; maybe this is a 26-layer model?"
        if LAYERS == 26:
            assert blockid in BLOCKID26, f"Incorrect layer {blockid}"
        return blockid

    with safe_open(LOAD_PATH, framework="pt", device="cpu") as f:
        tensors = {}
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
            compvis_name = convert_diffusers_name_to_compvis(key, is_sd2=False)
            blockid = compvis_name_to_blockid(compvis_name)
            # Scaling only lora_up is enough: the effective delta is
            # lora_up @ lora_down, so one factor per pair suffices.
            if compvis_name.endswith("lora_up.weight"):
                tensors[key] *= RATIO_OF_[blockid]
                print(f"({blockid}) {compvis_name} "
                      f"updated with factor {RATIO_OF_[blockid]}")

    save_file(tensors, SAVE_PATH)

    print("Done")