| import re |
| from sklearn.linear_model import PassiveAggressiveClassifier |
| import torch |
| import math |
| import os |
| import gc |
| import gradio as gr |
| from torchmetrics import Precision |
| import modules.shared as shared |
| import gc |
| from safetensors.torch import load_file, save_file |
| from typing import List |
| from tqdm import tqdm |
| from modules import sd_models,scripts |
| from scripts.mergers.model_util import load_models_from_stable_diffusion_checkpoint,filenamecutter,savemodel |
| from modules.ui import create_refresh_button |
|
|
def on_ui_tabs():
    """Build the Gradio controls of the LoRA tab and wire up their callbacks.

    The widgets are created inside a ``gr.Blocks`` context and the three main
    buttons are connected to ``lmerge``, ``makelora`` and ``pluslora``.
    The function returns nothing.

    NOTE(review): this block was restored from a copy that had lost its
    indentation; which widgets belong to which ``gr.Row`` is reconstructed and
    should be confirmed against the rendered layout.
    """
    import lora
    sml_path_root = scripts.basedir()
    # Built-in block-weight presets, one "NAME:w0,...,w16" entry per line.
    # Used only when the presets file below cannot be read.
    LWEIGHTSPRESETS="\
NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\
INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
IND:1,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
INALL:1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\
OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\
OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\
OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\
ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
    sml_filepath = os.path.join(sml_path_root,"scripts", "lbwpresets.txt")
    sml_lbwpresets=""
    # Prefer the presets file on disk; fall back to the defaults above.
    try:
        with open(sml_filepath,encoding="utf-8") as f:
            sml_lbwpresets = f.read()
    except OSError as e:
        sml_lbwpresets=LWEIGHTSPRESETS


    with gr.Blocks(analytics_enabled=False) :
        sml_submit_result = gr.Textbox(label="Message")
        # NOTE(review): several widgets below share the same elem_id
        # ("model_merger_merge", "model_converter_model_name",
        # "refresh_checkpoint_Z", "calcloras").
        with gr.Row().style(equal_height=False):
            sml_cpmerge = gr.Button(elem_id="model_merger_merge", value="Merge to Checkpoint",variant='primary')
            sml_makelora = gr.Button(elem_id="model_merger_merge", value="Make LoRA (alpha * A - beta * B)",variant='primary')
            sml_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint A",interactive=True)
            create_refresh_button(sml_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
            sml_model_b = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint B",interactive=True)
            create_refresh_button(sml_model_b, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z")
        with gr.Row().style(equal_height=False):
            sml_merge = gr.Button(elem_id="model_merger_merge", value="Merge LoRAs",variant='primary')
            alpha = gr.Slider(label="alpha", minimum=-1.0, maximum=2, step=0.001, value=1)
            beta = gr.Slider(label="beta", minimum=-1.0, maximum=2, step=0.001, value=1)
        with gr.Row().style(equal_height=False):
            sml_settings = gr.CheckboxGroup(["same to Strength", "overwrite"], label="settings")
            precision = gr.Radio(label = "save precision",choices=["float","fp16","bf16"],value = "fp16",type="value")
        with gr.Row().style(equal_height=False):
            # Choices are "no", "auto" and the powers of two 4 .. 1024.
            sml_dim = gr.Radio(label = "remake dimension",choices = ["no","auto",*[2**(x+2) for x in range(9)]],value = "no",type = "value")
        sml_filename = gr.Textbox(label="filename(option)",lines=1,visible =True,interactive = True)
        sml_loranames = gr.Textbox(label='LoRAname1:ratio1:Blocks1,LoRAname2:ratio2:Blocks2,...(":blocks" is option, not necessary)',lines=1,value="",visible =True)
        # Hidden until calculatedim() has collected the dimensions.
        sml_dims = gr.CheckboxGroup(label = "limit dimension",choices=[],value = [],type="value",interactive=True,visible = False)
        with gr.Row().style(equal_height=False):
            sml_calcdim = gr.Button(elem_id="calcloras", value="calculate dimension of LoRAs(It may take a few minutes if there are many LoRAs)",variant='primary')
            sml_update = gr.Button(elem_id="calcloras", value="update list",variant='primary')
        sml_loras = gr.CheckboxGroup(label = "Lora",choices=[x[0] for x in lora.available_loras.items()],type="value",interactive=True,visible = True)
        sml_loraratios = gr.TextArea(label="",value=sml_lbwpresets,visible =True,interactive = True)


        sml_merge.click(
            fn=lmerge,
            inputs=[sml_loranames,sml_loraratios,sml_settings,sml_filename,sml_dim,precision],
            outputs=[sml_submit_result]
        )


        sml_makelora.click(
            fn=makelora,
            inputs=[sml_model_a,sml_model_b,sml_dim,sml_filename,sml_settings,alpha,beta,precision],
            outputs=[sml_submit_result]
        )


        sml_cpmerge.click(
            fn=pluslora,
            inputs=[sml_loranames,sml_loraratios,sml_settings,sml_filename,sml_model_a,precision],
            outputs=[sml_submit_result]
        )
        # State shared by the nested callbacks below:
        #   llist: LoRA name -> dimension label ("" until it has been calculated)
        #   dlist: integer dimensions seen so far
        #   dn:    non-integer dimension labels seen so far
        llist ={}
        dlist =[]
        dn = []


        def updateloras():
            """Rescan the LoRAs and refresh the checkbox choices as "name(dim)"."""
            lora.list_available_loras()
            for n in lora.available_loras.items():
                if n[0] not in llist:llist[n[0]] = ""
            return gr.update(choices = [f"{x[0]}({x[1]})" for x in llist.items()])


        sml_update.click(fn = updateloras,outputs = [sml_loras])


        def calculatedim():
            """Read the dimension of every LoRA not yet measured and update both lists."""
            print("listing dimensions...")
            for n in tqdm(lora.available_loras.items()):
                # Skip LoRAs whose dimension is already known.
                if n[0] in llist:
                    if llist[n[0]] !="": continue
                c_lora = lora.available_loras.get(n[0], None)
                d,t = dimgetter(c_lora.filename)
                # LoCon files get a "<dim>:LoCon" label instead of a bare number.
                if t == "LoCon" : d = f"{d}:{t}"
                if d not in dlist:
                    if type(d) == int :dlist.append(d)
                    elif d not in dn: dn.append(d)
                llist[n[0]] = d
            dlist.sort()
            return gr.update(choices = [f"{x[0]}({x[1]})" for x in llist.items()],value =[]),gr.update(visible =True,choices = [x for x in (dlist+dn)])


        sml_calcdim.click(
            fn=calculatedim,
            inputs=[],
            outputs=[sml_loras,sml_dims]
        )


        def dimselector(dims):
            """Restrict the LoRA checkbox choices to the selected dimensions."""
            # Nothing selected: show every known LoRA again.
            if dims ==[]:return gr.update(choices = [f"{x[0]}({x[1]})" for x in llist.items()])
            rl=[]
            for d in dims:
                for i in llist.items():
                    if d == i[1]:rl.append(f"{i[0]}({i[1]})")
            return gr.update(choices = [l for l in rl],value =[])


        def llister(names):
            """Turn the ticked LoRAs into a "name:1.0,name:1.0" string."""
            if names ==[] : return ""
            else:
                # Strip the trailing "(dim)" suffix; this edits `names` in place.
                for i,n in enumerate(names):
                    if "(" in n:names[i] = n[:n.rfind("(")]
                return ":1.0,".join(names)+":1.0"
        sml_loras.change(fn=llister,inputs=[sml_loras],outputs=[sml_loranames])
        sml_dims.change(fn=dimselector,inputs=[sml_dims],outputs=[sml_loras])
|
|
def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,precision):
    """Extract a LoRA from ``alpha * A - beta * B`` and save it to the LoRA folder.

    Args:
        model_a, model_b: checkpoint titles as selected in the dropdowns.
        dim: rank of the LoRA; anything that is not an int (e.g. "no" or
            "auto") falls back to 128.
        saveto: output file name; generated from the model names when empty.
            ".safetensors" is appended when the name does not contain it.
        settings: selected option strings; "overwrite" allows replacing an
            existing file.
        alpha, beta: multipliers forwarded to ``svd``.
        precision: save precision name forwarded to ``svd``.

    Returns:
        A status message for the UI.
    """
    print("make LoRA start")
    # An untouched dropdown yields None / [] rather than "", so test for any
    # empty value instead of comparing against the empty string only.
    if not model_a or not model_b:
        return "ERROR: No model Selected"
    gc.collect()

    if not saveto:
        saveto = makeloraname(model_a, model_b)
    if ".safetensors" not in saveto:
        saveto += ".safetensors"
    saveto = os.path.join(shared.cmd_opts.lora_dir, saveto)

    dim = dim if isinstance(dim, int) else 128

    # Refuse before doing any heavy work when the target already exists.
    if os.path.isfile(saveto) and "overwrite" not in settings:
        _err_msg = f"Output file ({saveto}) existed and was not saved"
        print(_err_msg)
        return _err_msg

    svd(fullpathfromname(model_a), fullpathfromname(model_b), False, dim, precision, saveto, alpha, beta)
    return f"saved to {saveto}"
|
|
def lmerge(loranames,loraratioss,settings,filename,dim,precision):
    """Merge one or more LoRA files into a single LoRA file.

    Args:
        loranames: comma separated "name:ratio[:preset]" entries.
        loraratioss: preset text, one "PRESET:w0,w1,..." entry per line
            (17 or 26 comma separated weights).
        settings: selected option strings ("same to Strength", "overwrite").
        filename: output file name; derived from ``loranames`` when empty.
        dim: "no", "auto" or the target rank.
        precision: save precision name understood by ``str_to_dtype``.

    Returns:
        A status message for the UI.
    """
    if not loranames or not loranames.strip():
        return "ERROR: No LoRA Selected"

    import lora

    # Each entry becomes [name, ratio] or [name, ratio, preset].
    lnames = [entry.split(":") for entry in loranames.split(",")]

    # Rescan the LoRA folder only when one of the requested names is unknown.
    # (Iterating the raw string would test single characters instead.)
    if any(lora.available_loras.get(n[0], None) is None for n in lnames):
        lora.list_available_loras()

    ldict = {}
    for line in loraratioss.splitlines():
        if ":" not in line or line.count(",") not in (16, 25):
            continue
        parts = line.split(":")
        # Strip the preset name so it matches the stripped lookup below.
        ldict[parts[0].strip()] = parts[1]

    ln = []
    lr = []
    ld = []
    lt = []
    dmax = 1

    for n in lnames:
        if len(n) < 2:
            return f"ERROR: no ratio given for {n[0]}"
        strength = float(n[1])
        preset = n[2].strip() if len(n) == 3 else None
        if preset is not None and preset in ldict:
            ratio = [float(r) * strength for r in ldict[preset].split(",")]
        else:
            ratio = [strength] * 17
        c_lora = lora.available_loras.get(n[0], None)
        if c_lora is None:
            return f"ERROR: LoRA not found: {n[0]}"
        ln.append(c_lora.filename)
        lr.append(ratio)
        d, t = dimgetter(c_lora.filename)
        lt.append(t)
        ld.append(d)
        # dimgetter may return "LyCORIS" or "unknown" instead of a number.
        if isinstance(d, int) and d > dmax:
            dmax = d

    if not filename:
        filename = loranames.replace(",", "+").replace(":", "_")
    if ".safetensors" not in filename:
        filename += ".safetensors"
    filename = os.path.join(shared.cmd_opts.lora_dir, filename)

    # "auto" is a value of the dimension radio; the settings check is kept for
    # callers that passed it there.
    auto = dim == "auto" or "auto" in settings
    dim = 0 if dim in ("no", "auto") else int(dim)

    is_lyco = "LyCORIS" in ld or "LoCon" in lt
    if is_lyco and len(ld) != 1:
        return "multiple merge of LyCORIS is not supported"

    # Refuse before the expensive merge when the target already exists.
    if os.path.isfile(filename) and "overwrite" not in settings:
        _err_msg = f"Output file ({filename}) existed and was not saved"
        print(_err_msg)
        return _err_msg

    if is_lyco:
        sd = lycomerge(ln[0], lr[0])
    elif dim > 0:
        print("change demension to ", dim)
        sd = merge_lora_models_dim(ln, lr, dim, settings)
    elif auto and ld.count(ld[0]) != len(ld):
        print("change demension to ",dmax)
        sd = merge_lora_models_dim(ln, lr, dmax, settings)
    else:
        sd = merge_lora_models(ln, lr, settings)

    save_to_file(filename, sd, sd, str_to_dtype(precision))
    return "saved : "+filename
|
|
def pluslora(lnames,loraratios,settings,output,model,precision):
    """Bake one or more LoRAs into a checkpoint and save the result.

    Args:
        lnames: comma separated "name:ratio[:preset]" entries.
        loraratios: preset text, one "PRESET:w0,w1,..." entry per line
            (17 or 26 comma separated weights).
        settings: selected option strings; not modified by this function.
        output: output name forwarded to ``savemodel``.
        model: checkpoint title as selected in the dropdown.
        precision: save precision name, passed to ``savemodel`` together with
            the settings.

    Returns:
        The message returned by ``savemodel`` or an error message.
    """
    # An untouched dropdown yields None / "" / [], so test for any empty value.
    if not model:
        return "ERROR: No model Selected"
    if not lnames:
        return "ERROR: No LoRA Selected"

    print("plus LoRA start")
    import lora
    lnames = [entry.split(":") for entry in lnames.split(",")]

    ldict = {}
    for line in loraratios.splitlines():
        if ":" not in line or line.count(",") not in (16, 25):
            continue
        parts = line.split(":")
        ldict[parts[0].strip()] = parts[1]

    names = []
    filenames = []
    lweis = []

    for n in lnames:
        if len(n) < 2:
            return f"ERROR: no ratio given for {n[0]}"
        strength = float(n[1])
        # Look the preset up with the same stripped key used for the test.
        preset = n[2].strip() if len(n) == 3 else None
        if preset is not None and preset in ldict:
            ratio = [float(r) * strength for r in ldict[preset].split(",")]
        else:
            ratio = [strength] * 17
        c_lora = lora.available_loras.get(n[0], None)
        if c_lora is None:
            return f"ERROR: LoRA not found: {n[0]}"
        _, t = dimgetter(c_lora.filename)
        if "LyCORIS" in t:
            return "LyCORIS merge is not supported"
        names.append(n[0])
        filenames.append(c_lora.filename)
        lweis.append(ratio)

    modeln = filenamecutter(model, True)
    dname = modeln
    for n in names:
        dname = dname + "+" + n

    checkpoint_info = sd_models.get_closet_checkpoint_match(model)
    print(f"Loading {model}")
    theta_0 = sd_models.read_state_dict(checkpoint_info.filename, "cpu")

    # Map "<module path with '_' separators>" -> real checkpoint key so the
    # converted LoRA module names can be resolved.
    keychanger = {}
    for key in theta_0.keys():
        if "model" in key:
            skey = key.replace(".", "_").replace("_weight", "")
            keychanger[skey.split("model_", 1)[1]] = key

    for name, filename, lwei in zip(names, filenames, lweis):
        print(f"loading: {name}")
        lora_sd = load_state_dict(filename, torch.float)

        print(f"merging..." ,lwei)
        for key in lora_sd.keys():
            # Only the down weight drives a merge; up and alpha are fetched
            # from it below.
            if "lora_down" not in key:
                continue

            fullkey = convert_diffusers_name_to_compvis(key)

            # Weight of the block this module belongs to (1 when none matches).
            ratio = 1
            for i, block in enumerate(LORABLOCKS):
                if block in fullkey:
                    ratio = lwei[i]

            msd_key, lora_key = fullkey.split(".", 1)

            up_key = key.replace("lora_down", "lora_up")
            alpha_key = key[:key.index("lora_down")] + 'alpha'

            down_weight = lora_sd[key].to(device="cpu")
            up_weight = lora_sd[up_key].to(device="cpu")

            dim = down_weight.size()[0]
            alpha = lora_sd.get(alpha_key, dim)
            scale = alpha / dim

            target = keychanger[msd_key]
            weight = theta_0[target].to(device="cpu")

            if not len(down_weight.size()) == 4:
                weight = weight + ratio * (up_weight @ down_weight) * scale
            else:
                # 1x1 convolution: multiply as matrices, then restore the
                # two trailing kernel axes.
                weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
                                           ).unsqueeze(2).unsqueeze(3) * scale
            theta_0[target] = torch.nn.Parameter(weight)

    # Pass a new list so the caller's settings are left untouched.
    result = savemodel(theta_0, dname, output, list(settings) + [precision], model)
    del theta_0
    gc.collect()
    return result
|
|
def save_to_file(file_name, model, state_dict, dtype):
    """Cast the tensors of ``state_dict`` in place and write ``model`` to disk.

    Args:
        file_name: target path; a ".safetensors" extension selects the
            safetensors writer, anything else goes through ``torch.save``.
        model: the object that is actually written.
        state_dict: dict whose plain ``torch.Tensor`` values are converted to
            ``dtype`` (in place). Other values are left alone.
        dtype: target dtype, or None to skip the conversion.
    """
    if dtype is not None:
        for name, value in list(state_dict.items()):
            if type(value) is torch.Tensor:
                state_dict[name] = value.to(dtype)

    _, extension = os.path.splitext(file_name)
    if extension != '.safetensors':
        torch.save(model, file_name)
    else:
        save_file(model, file_name)
|
|
# Matches a leading run of digits; used to decide which captured groups are
# converted to int.
re_digits = re.compile(r"\d+")

# Attention modules of the U-Net (diffusers naming).
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")

# Resnet modules of the U-Net.
re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")

# Down/up sampling convolutions.
re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")

# Text encoder layers.
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")

# Resnet sub-layer prefix in the LoRA key -> replacement in the converted name.
# Checked in this order.
_RESNET_LAYER_NAMES = (
    ('conv1', 'in_layers_2'),
    ('conv2', 'out_layers_3'),
    ('time_emb_proj', 'emb_layers_1'),
    ('conv_shortcut', 'skip_connection'),
)


def convert_diffusers_name_to_compvis(key):
    """Translate a diffusers-style LoRA key into the CompVis-style module name.

    Keys that match none of the known patterns (or a resnet key with an
    unknown sub-layer) are returned unchanged.
    """

    def groups_of(pattern):
        # Captured groups of `pattern` on `key`, digit-leading ones as int;
        # None when the pattern does not match.
        found = pattern.match(key)
        if found is None:
            return None
        return [int(g) if re_digits.match(g) else g for g in found.groups()]

    def resnet_name(block, tail):
        # Swap the resnet sub-layer prefix of `tail`; None when it is unknown.
        for prefix, replacement in _RESNET_LAYER_NAMES:
            if tail.startswith(prefix):
                return f"{block}{replacement}{tail[len(prefix):]}"
        return None

    m = groups_of(re_unet_down_blocks)
    if m is not None:
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"

    m = groups_of(re_unet_mid_blocks)
    if m is not None:
        return f"diffusion_model_middle_block_1_{m[1]}"

    m = groups_of(re_unet_up_blocks)
    if m is not None:
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"

    m = groups_of(re_unet_down_blocks_res)
    if m is not None:
        converted = resnet_name(f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_", m[2])
        if converted is not None:
            return converted

    m = groups_of(re_unet_mid_blocks_res)
    if m is not None:
        converted = resnet_name(f"diffusion_model_middle_block_{m[0]*2}_", m[1])
        if converted is not None:
            return converted

    m = groups_of(re_unet_up_blocks_res)
    if m is not None:
        converted = resnet_name(f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_", m[2])
        if converted is not None:
            return converted

    m = groups_of(re_unet_downsample)
    if m is not None:
        return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}"

    m = groups_of(re_unet_upsample)
    if m is not None:
        return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}"

    m = groups_of(re_text_block)
    if m is not None:
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key
|
|
# Quantile of the combined SVD factor values that is used as the symmetric
# clamp bound (+/-) when a LoRA is extracted.
CLAMP_QUANTILE = 0.99
# Text-encoder weight differences whose largest absolute value stays at or
# below this are treated as "text encoder is the same" in svd().
MIN_DIFF = 1e-6
|
|
def str_to_dtype(p):
    """Map a precision name ('float', 'fp16', 'bf16') to its torch dtype.

    Returns None for any other value.
    """
    known = (
        ('float', torch.float),
        ('fp16', torch.float16),
        ('bf16', torch.bfloat16),
    )
    for label, dtype in known:
        if p == label:
            return dtype
    return None
|
|
def svd(model_a,model_b,v2,dim,save_precision,save_to,alpha,beta):
    """Extract a LoRA from the weighted difference of two checkpoints.

    For every module wrapped by the LoRA network the difference
    ``alpha * weight(model_a) - beta * weight(model_b)`` is factorised with an
    SVD, truncated to ``dim`` and stored as the up/down weights.

    Args:
        model_a, model_b: values passed to
            ``load_models_from_stable_diffusion_checkpoint`` (the caller in
            this file passes checkpoint file names).
        v2: forwarded unchanged to the checkpoint loader.
        dim: rank of the extracted LoRA; also written as network alpha.
        save_precision: precision name resolved by ``str_to_dtype``.
        save_to: output path; its directory is created when missing.
        alpha, beta: multipliers of the two weight sets.

    Returns:
        ``save_to``.

    NOTE(review): the nesting below was reconstructed from a copy without
    indentation; in particular both "loading SD model" steps are assumed to
    belong to the ``else`` branch - confirm.
    """
    save_dtype = str_to_dtype(save_precision)


    if model_a == model_b:
        # Same checkpoint on both sides: load once and share the modules.
        text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)
        text_encoder_o, _, unet_o = text_encoder_t, _, unet_t
    else:
        print(f"loading SD model : {model_b}")
        text_encoder_o, _, unet_o = load_models_from_stable_diffusion_checkpoint(v2, model_b)

        print(f"loading SD model : {model_a}")
        text_encoder_t, _, unet_t = load_models_from_stable_diffusion_checkpoint(v2, model_a)


    # One LoRA network per model; used here to enumerate the target modules.
    lora_network_o = create_network(1.0, dim, dim, None, text_encoder_o, unet_o)
    lora_network_t = create_network(1.0, dim, dim, None, text_encoder_t, unet_t)
    assert len(lora_network_o.text_encoder_loras) == len(
        lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "

    # lora_name -> float weight difference of the wrapped module
    diffs = {}
    text_encoder_different = False
    for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
        lora_name = lora_o.lora_name
        module_o = lora_o.org_module
        module_t = lora_t.org_module
        diff = alpha*module_t.weight - beta*module_o.weight


        # A single module above the threshold marks the text encoders as different.
        if torch.max(torch.abs(diff)) > MIN_DIFF:
            text_encoder_different = True


        diff = diff.float()
        diffs[lora_name] = diff


    if not text_encoder_different:
        # Drop the text-encoder part entirely and start over with the U-Net.
        print("Text encoder is same. Extract U-Net only.")
        lora_network_o.text_encoder_loras = []
        diffs = {}


    for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
        lora_name = lora_o.lora_name
        module_o = lora_o.org_module
        module_t = lora_t.org_module
        diff = alpha*module_t.weight - beta*module_o.weight
        diff = diff.float()


        diffs[lora_name] = diff


    # Factorise every difference and keep the first `rank` components.
    print("calculating by svd")
    rank = dim
    lora_weights = {}
    with torch.no_grad():
        for lora_name, mat in tqdm(list(diffs.items())):
            conv2d = (len(mat.size()) == 4)
            if conv2d:
                # Drops every size-1 axis of the 4-D conv weight.
                mat = mat.squeeze()


            U, S, Vh = torch.linalg.svd(mat)


            U = U[:, :rank]
            S = S[:rank]
            # Fold the singular values into the up factor.
            U = U @ torch.diag(S)


            Vh = Vh[:rank, :]


            # Clamp both factors to +/- the CLAMP_QUANTILE quantile of their values.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
            low_val = -hi_val


            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)


            lora_weights[lora_name] = (U, Vh)


    # NOTE(review): presumably apply_to() is called so that state_dict()
    # contains the LoRA modules - confirm against LoRANetwork.
    lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True)
    lora_sd = lora_network_o.state_dict()
    print(f"LoRA has {len(lora_sd)} weights.")


    # Replace every up/down entry of the state dict by the extracted factor.
    for key in list(lora_sd.keys()):
        if "alpha" in key:
            continue


        lora_name = key.split('.')[0]
        # Index 0 is the up factor (U), index 1 the down factor (Vh).
        i = 0 if "lora_up" in key else 1


        weights = lora_weights[lora_name][i]

        if len(lora_sd[key].size()) == 4:
            # Conv entries need the two trailing kernel axes back.
            weights = weights.unsqueeze(2).unsqueeze(3)


        assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
        lora_sd[key] = weights



    info = lora_network_o.load_state_dict(lora_sd)
    print(f"Loading extracted LoRA weights: {info}")


    dir_name = os.path.dirname(save_to)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name, exist_ok=True)


    # Metadata records dim for both the network dimension and alpha.
    metadata = {"ss_network_dim": str(dim), "ss_network_alpha": str(dim)}


    lora_network_o.save_weights(save_to, save_dtype, metadata)
    print(f"LoRA weights are saved to: {save_to}")
    return save_to
|
|
def load_state_dict(file_name, dtype):
    """Load a state dict from disk and cast its tensors to ``dtype``.

    ".safetensors" files are read with ``load_file``; every other extension
    goes through ``torch.load`` on the CPU. Only values that are plain
    ``torch.Tensor`` objects are converted.
    """
    extension = os.path.splitext(file_name)[1]
    if extension != '.safetensors':
        # NOTE: torch.load unpickles the file - only use it on trusted files.
        sd = torch.load(file_name, map_location='cpu')
    else:
        sd = load_file(file_name)

    for name in list(sd.keys()):
        value = sd[name]
        if type(value) is torch.Tensor:
            sd[name] = value.to(dtype)
    return sd
|
|
def dimgetter(filename):
    """Return ``(dimension, kind)`` of the LoRA file ``filename``.

    ``kind`` is "LoCon" when the file contains the first down-block resnet
    conv1 weight, "LyCORIS" when a "hada_" key is met during the scan, and
    "LoRA" otherwise. The dimension is the first axis of the first 2-D
    ``lora_down`` weight, or "LyCORIS" for LyCORIS files.
    ``("unknown", "unknown")`` is returned when no dimension was found.
    """
    weights = load_state_dict(filename, torch.float)
    alpha = None
    dim = None
    kind = None

    if "lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight" in weights.keys():
        kind = "LoCon"

    for name, tensor in weights.items():
        if alpha is None and 'alpha' in name:
            alpha = tensor
        if dim is None and 'lora_down' in name and len(tensor.size()) == 2:
            dim = tensor.size()[0]
        if "hada_" in name:
            dim, kind = "LyCORIS", "LyCORIS"
        # Stop scanning as soon as both an alpha and a dimension were seen.
        if not (alpha is None or dim is None):
            break

    if kind is None:
        kind = "LoRA"
    if not dim:
        return "unknown","unknown"
    return dim, kind
|
|
def blockfromkey(key):
    """Return the index of the first LORABLOCKS entry found in the converted key.

    Falls back to 0 when no entry matches.
    """
    converted = convert_diffusers_name_to_compvis(key)
    matches = (idx for idx, marker in enumerate(LORABLOCKS) if marker in converted)
    return next(matches, 0)
|
|
def merge_lora_models_dim(models, ratios, new_rank,sets):
    """Merge several LoRAs and re-extract the result at a new rank.

    Each LoRA is expanded to its full weight delta (``up @ down`` scaled by
    ``alpha / dim`` and the per-block ratio), the deltas are summed per
    module, and every sum is factorised again with an SVD.

    Args:
        models: LoRA file names.
        ratios: one list of per-block ratios per model (indexed through
            ``blockfromkey``).
        new_rank: requested rank of the merged LoRA. A module whose weight
            matrix has fewer singular values than this keeps only as many
            components as exist, and its alpha is set to that smaller rank so
            that ``alpha / rank`` stays 1.
        sets: selected option strings; "same to Strength" applies the square
            root of the ratio (with its sign) instead of the ratio itself.

    Returns:
        State dict with "<module>.lora_up.weight", "<module>.lora_down.weight"
        and "<module>.alpha" entries.
    """
    merged_sd = {}
    fugou = 1  # sign of the current ratio (only changed by "same to Strength")
    merge_dtype = torch.float
    for model, model_ratios in zip(models, ratios):
        lora_sd = load_state_dict(model, merge_dtype)

        print(f"merging {model}: {model_ratios}")
        for key in tqdm(list(lora_sd.keys())):
            if 'lora_down' not in key:
                continue
            lora_module_name = key[:key.rfind(".lora_down")]

            down_weight = lora_sd[key]
            network_dim = down_weight.size()[0]

            up_weight = lora_sd[lora_module_name + '.lora_up.weight']
            alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)

            in_dim = down_weight.size()[1]
            out_dim = up_weight.size()[0]
            conv2d = len(down_weight.size()) == 4

            if lora_module_name not in merged_sd:
                weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
            else:
                weight = merged_sd[lora_module_name]

            ratio = model_ratios[blockfromkey(key)]
            if "same to Strength" in sets:
                ratio, fugou = (ratio**0.5,1) if ratio > 0 else (abs(ratio)**0.5,-1)

            scale = (alpha / network_dim)
            if not conv2d:
                weight = weight + ratio * (up_weight @ down_weight) * scale * fugou
            else:
                # 1x1 convolution: multiply as matrices, then restore the
                # two trailing kernel axes.
                weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
                                           ).unsqueeze(2).unsqueeze(3) * scale * fugou

            merged_sd[lora_module_name] = weight

    print("extract new lora...")
    merged_lora_sd = {}
    with torch.no_grad():
        for lora_module_name, mat in tqdm(list(merged_sd.items())):
            conv2d = (len(mat.size()) == 4)
            if conv2d:
                # Only the two trailing 1x1 kernel axes are removed so that a
                # channel count of 1 is not squeezed away as well.
                mat = mat.squeeze(3).squeeze(2)

            U, S, Vh = torch.linalg.svd(mat)

            # A matrix has only min(out, in) singular values; asking for more
            # would give up/down factors of different rank.
            rank = min(new_rank, S.size()[0])

            U = U[:, :rank]
            S = S[:rank]
            U = U @ torch.diag(S)

            Vh = Vh[:rank, :]

            # Clamp both factors to +/- the CLAMP_QUANTILE quantile of their values.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
            low_val = -hi_val

            up_weight = U.clamp(low_val, hi_val)
            down_weight = Vh.clamp(low_val, hi_val)

            if conv2d:
                up_weight = up_weight.unsqueeze(2).unsqueeze(3)
                down_weight = down_weight.unsqueeze(2).unsqueeze(3)

            merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
            merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
            merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(rank)

    return merged_lora_sd
|
|
def merge_lora_models(models, ratios,sets):
    """Merge several LoRAs of equal shape by adding their scaled weights.

    The alpha and dimension of the first model that provides a module become
    the reference ("base") values; weights of later models are rescaled by
    ``sqrt(alpha / base_alpha)`` before they are added.

    Args:
        models: LoRA file names.
        ratios: one list of per-block ratios per model (indexed through
            ``blockfromkey``).
        sets: selected option strings; "same to Strength" replaces the ratio
            by the square root of its absolute value and applies the sign to
            the down weights only.

    Returns:
        The merged state dict, including one ".alpha" entry per module.
    """
    base_alphas = {}
    base_dims = {}
    merge_dtype = torch.float
    merged_sd = {}
    fugou = 1  # sign of the current ratio (only changed by "same to Strength")

    for model, model_ratios in zip(models, ratios):
        print(f"merging {model}: {model_ratios}")
        lora_sd = load_state_dict(model, merge_dtype)

        # Collect alpha and dim of every module of this model.
        alphas = {}
        dims = {}
        for key in lora_sd.keys():
            if 'alpha' in key:
                module = key[:key.rfind(".alpha")]
                alpha = float(lora_sd[key].detach().numpy())
                alphas[module] = alpha
                if module not in base_alphas:
                    base_alphas[module] = alpha
            elif "lora_down" in key:
                module = key[:key.rfind(".lora_down")]
                dim = lora_sd[key].size()[0]
                dims[module] = dim
                if module not in base_dims:
                    base_dims[module] = dim

        # Modules without a stored alpha use their dim as alpha.
        for module, dim in dims.items():
            if module in alphas:
                continue
            alphas[module] = dim
            if module not in base_alphas:
                base_alphas[module] = dim

        print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

        print(f"merging...")
        for key in lora_sd.keys():
            if 'alpha' in key:
                continue
            module = key[:key.rfind(".lora_")]

            base_alpha = base_alphas[module]
            alpha = alphas[module]

            ratio = model_ratios[blockfromkey(key)]
            if "same to Strength" in sets:
                ratio, fugou = (ratio**0.5,1) if ratio > 0 else (abs(ratio)**0.5,-1)

            if "lora_down" in key:
                ratio = ratio * fugou

            scale = math.sqrt(alpha / base_alpha) * ratio
            scaled = lora_sd[key] * scale

            if key not in merged_sd:
                merged_sd[key] = scaled
                continue

            assert merged_sd[key].size() == lora_sd[key].size(
            ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
            merged_sd[key] = merged_sd[key] + scaled

    # The merged file carries the reference alpha of every module.
    for module, alpha in base_alphas.items():
        merged_sd[module + ".alpha"] = torch.tensor(alpha)

    print("merged model")
    print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")

    return merged_sd
|
|
def fullpathfromname(name):
    """Return the file name of the checkpoint that best matches ``name``.

    An empty selection ("", [] or None) yields "" without a lookup.
    """
    # The guard has to test the argument; comparing the builtin ``hash``
    # against "" / [] is always False.
    if not name:
        return ""
    checkpoint_info = sd_models.get_closet_checkpoint_match(name)
    return checkpoint_info.filename
|
|
def makeloraname(model_a,model_b):
    """Build the default LoRA name "lora_<A>-<B>" from two checkpoint names.

    Both names are passed through ``filenamecutter`` first.
    """
    short_names = (filenamecutter(model_a), filenamecutter(model_b))
    return "lora_" + "-".join(short_names)
|
|
def lycomerge(filename,ratios):
    """Apply per-block ratios to a single LyCORIS / LoCon file.

    Every non-alpha weight is multiplied by ``sqrt(abs(ratio))`` of the block
    it belongs to; for a negative ratio the weights whose key contains "down"
    are additionally negated.

    Args:
        filename: LyCORIS / LoCon file to load.
        ratios: list of 26 per-block ratios (one per LYCOBLOCKS entry), or
            17 ratios which are expanded to 26 by inserting 1 for the blocks
            the 17-entry layout does not cover.

    Returns:
        The modified state dict.
    """
    sd = load_state_dict(filename, torch.float)

    if len(ratios) == 17:
        r0 = 1
        ratios = [ratios[0]] + [r0] + ratios[1:3]+ [r0] + ratios[3:5]+[r0] + ratios[5:7]+[r0,r0,r0] + [ratios[7]] + [r0,r0,r0] + ratios[8:]

    print("LyCORIS: " , ratios)

    keys_failed_to_match = []

    # Iterate over a snapshot because entries are reassigned inside the loop.
    for lkey, weight in list(sd.items()):
        if 'alpha' in lkey:
            continue

        fullkey = convert_diffusers_name_to_compvis(lkey)
        key, lora_key = fullkey.split(".", 1)

        ratio = 1
        picked = False
        for i, block in enumerate(LYCOBLOCKS):
            if block in key:
                ratio = ratios[i]
                picked = True
        if not picked:
            keys_failed_to_match.append(key)

        scaled = weight * math.sqrt(abs(float(ratio)))

        # The sign goes onto the entry being processed (lkey); the converted
        # module name `key` is not a key of sd.
        if "down" in lkey and ratio < 0:
            scaled = scaled * -1

        sd[lkey] = scaled

    if len(keys_failed_to_match) > 0:
        print(keys_failed_to_match)

    return sd
|
|
# Name fragments of the 17 blocks that take a per-block LoRA weight:
# the text encoder, six input blocks, the middle block and output blocks 3-11.
# The position in this list is the index into a 17-entry ratio list.
LORABLOCKS = (
    ["encoder"]
    + [f"diffusion_model_input_blocks_{index}_" for index in (1, 2, 4, 5, 7, 8)]
    + ["diffusion_model_middle_block_"]
    + [f"diffusion_model_output_blocks_{index}_" for index in range(3, 12)]
)
|
|
# Name fragments of the 26 blocks used for LyCORIS files: the text encoder,
# input blocks 0-11, the middle block and output blocks 0-11.
# The position in this list is the index into a 26-entry ratio list.
LYCOBLOCKS = (
    ["encoder"]
    + [f"diffusion_model_input_blocks_{index}_" for index in range(12)]
    + ["diffusion_model_middle_block_"]
    + [f"diffusion_model_output_blocks_{index}_" for index in range(12)]
)
|
|
class LoRAModule(torch.nn.Module):
    """
    Low-rank adapter for one Linear or Conv2d module.

    Replaces the forward method of the original module instead of replacing
    the original module itself: after ``apply_to()`` the original module's
    ``forward`` is this object's ``forward``, which adds
    ``lora_up(lora_down(x)) * multiplier * scale`` to the original output.
    """


    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
        """Create the down/up pair for ``org_module``.

        If alpha == 0 or None, alpha is rank (no scaling).
        ``alpha`` may also be given as a tensor.
        """
        super().__init__()
        self.lora_name = lora_name


        # Input/output width comes from the channel count (Conv2d) or the
        # feature count (anything else).
        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features






        self.lora_dim = lora_dim


        if org_module.__class__.__name__ == "Conv2d":
            # The down conv copies kernel size, stride and padding of the
            # original; the up conv is always 1x1.
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)


        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        # Stored as a buffer so that it becomes part of the state dict.
        self.register_buffer("alpha", torch.tensor(alpha))



        # The up weight starts at zero, so a fresh module adds nothing to the
        # original output.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)


        self.multiplier = multiplier
        self.org_module = org_module
        self.region = None
        self.region_mask = None


    def apply_to(self):
        """Hook this module into the original module's forward call.

        The reference to the original module is deleted afterwards, so
        ``merge_to`` cannot be used once this has been called.
        """
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module


    def merge_to(self, sd, dtype, device):
        """Add the LoRA weights from ``sd`` into the original module's weight.

        ``sd`` must provide "lora_up.weight" and "lora_down.weight". The
        arithmetic runs in float; the result is cast to ``dtype`` and loaded
        back into the original module.
        """

        up_weight = sd["lora_up.weight"].to(torch.float).to(device)
        down_weight = sd["lora_down.weight"].to(torch.float).to(device)



        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"].to(torch.float)



        if len(weight.size()) == 2:
            # 2-D weight: plain matrix product.
            weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
        elif down_weight.size()[2:4] == (1, 1):
            # 1x1 convolution: multiply as matrices, then restore the two
            # trailing kernel axes.
            weight = (
                weight
                + self.multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * self.scale
            )
        else:
            # Larger kernel: combine down and up through a convolution.
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)

            weight = weight + self.multiplier * conved * self.scale



        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)


    def set_region(self, region):
        """Set the region used to mask the LoRA output and drop the cached mask."""
        self.region = region
        self.region_mask = None


    def forward(self, x):
        """Original output plus the scaled LoRA output (masked when a region is set)."""
        if self.region is None:
            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale



        # NOTE(review): a second axis divisible by 77 is presumably a
        # text-context input - confirm. The region is dropped for such input.
        if x.size()[1] % 77 == 0:

            self.region = None
            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale



        # Build the mask on first use and cache it in self.region_mask.
        if self.region_mask is None:
            if len(x.size()) == 4:
                h, w = x.size()[2:4]
            else:
                # Derive height and width from the sequence length and the
                # proportions of the region.
                seq_len = x.size()[1]
                ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
                h = int(self.region.size()[0] / ratio + 0.5)
                w = seq_len // h


            r = self.region.to(x.device)
            if r.dtype == torch.bfloat16:
                r = r.to(torch.float)
            r = r.unsqueeze(0).unsqueeze(1)

            r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
            r = r.to(x.dtype)


            if len(x.size()) == 3:
                r = torch.reshape(r, (1, x.size()[1], -1))


            self.region_mask = r


        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
|
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
    """Build a ``LoRANetwork`` for the given text encoder and U-Net.

    Args:
        multiplier: multiplier handed to the network.
        network_dim: LoRA rank; ``None`` falls back to 4.
        network_alpha: alpha value handed to the network unchanged.
        vae: not used by this function.
        text_encoder: text encoder handed to the network.
        unet: U-Net handed to the network.
        **kwargs: optional ``conv_dim`` and ``conv_alpha``. ``conv_dim`` is
            converted with ``int``; when it is given, ``conv_alpha`` is
            converted with ``float`` and defaults to 1.0.

    Returns:
        The newly created ``LoRANetwork``.
    """
    lora_rank = 4 if network_dim is None else network_dim

    conv_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_dim is not None:
        conv_dim = int(conv_dim)
        conv_alpha = 1.0 if conv_alpha is None else float(conv_alpha)

    # NOTE: per-block dims/alphas (block_dims, conv_block_dims, ...) are not
    # handled here; only a single rank/alpha pair per layer kind is supported.

    return LoRANetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=lora_rank,
        alpha=network_alpha,
        conv_lora_dim=conv_dim,
        conv_alpha=conv_alpha,
    )
|
|
|
|
|
|
class LoRANetwork(torch.nn.Module):
    """Container that creates, loads, applies, merges and saves LoRA modules.

    A LoRA module is created for each ``Linear`` / ``Conv2d`` layer found
    inside the target module classes of the text encoder and the U-Net.
    """

    # Class names whose Linear / Conv2d children receive a LoRA module.
    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    # Extra U-Net classes, searched only when conv LoRA or loaded dims are used.
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    # Name prefixes; they also identify which model a weight key belongs to.
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(
        self,
        text_encoder,
        unet,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        conv_lora_dim=None,
        conv_alpha=None,
        modules_dim=None,
        modules_alpha=None,
    ) -> None:
        """Create the LoRA modules for ``text_encoder`` and ``unet``.

        Args:
            text_encoder: module searched for text-encoder target classes.
            unet: module searched for U-Net target classes.
            multiplier: multiplier passed to every LoRA module.
            lora_dim: rank used for Linear and 1x1 Conv2d layers.
            alpha: alpha used for Linear and 1x1 Conv2d layers.
            conv_lora_dim: rank for the remaining Conv2d layers; ``None``
                skips those layers.
            conv_alpha: alpha for the remaining Conv2d layers; falls back to
                ``alpha`` when ``conv_lora_dim`` is set and this is ``None``.
            modules_dim: optional mapping of LoRA name to rank. When given,
                only the listed modules are created.
            modules_alpha: mapping of LoRA name to alpha, read together with
                ``modules_dim``.

        Raises:
            AssertionError: if two created modules share a LoRA name.
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha

        if modules_dim is not None:
            print("create LoRA network from weights")
        else:
            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")

        self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
        if self.apply_to_conv2d_3x3:
            if self.conv_alpha is None:
                self.conv_alpha = self.alpha
            print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List["LoRAModule"]:
            """Create a LoRA module for each eligible layer under the target classes."""
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue

                for child_name, child_module in module.named_modules():
                    child_class = child_module.__class__.__name__
                    is_linear = child_class == "Linear"
                    is_conv2d = child_class == "Conv2d"
                    if not (is_linear or is_conv2d):
                        continue
                    is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                    lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")

                    if modules_dim is not None:
                        # Dims come from loaded weights: skip unknown modules.
                        if lora_name not in modules_dim:
                            continue
                        rank = modules_dim[lora_name]
                        module_alpha = modules_alpha[lora_name]
                    elif is_linear or is_conv2d_1x1:
                        rank = self.lora_dim
                        module_alpha = self.alpha
                    elif self.apply_to_conv2d_3x3:
                        rank = self.conv_lora_dim
                        module_alpha = self.conv_alpha
                    else:
                        continue

                    loras.append(LoRAModule(lora_name, child_module, self.multiplier, rank, module_alpha))
            return loras

        self.text_encoder_loras = create_modules(
            LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
        )
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # Work on a copy: extending the class attribute in place would leak the
        # conv targets into every network created afterwards.
        target_modules = list(LoRANetwork.UNET_TARGET_REPLACE_MODULE)
        if modules_dim is not None or self.conv_lora_dim is not None:
            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        self.weights_sd = None

        # Every LoRA name must be unique; names are used as submodule names
        # and as weight-key prefixes.
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def set_multiplier(self, multiplier):
        """Set the multiplier on the network and on every LoRA module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Read a weight file into ``self.weights_sd`` without applying it."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            self.weights_sd = load_file(file)
        else:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            self.weights_sd = torch.load(file, map_location="cpu")

    def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
        """Hook the LoRA modules into the models and register them here.

        When weights are loaded, flags left as ``None`` are derived from the
        key prefixes, and explicit flags must agree with the weights. Without
        loaded weights both flags are required. Loaded weights are copied
        into the registered modules at the end (non-strict).

        Raises:
            AssertionError: if a flag contradicts the loaded weights, or if a
                flag is missing while no weights are loaded.
        """
        if self.weights_sd:
            weights_has_text_encoder = weights_has_unet = False
            for key in self.weights_sd.keys():
                if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                    weights_has_text_encoder = True
                elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
                    weights_has_unet = True

            if apply_text_encoder is None:
                apply_text_encoder = weights_has_text_encoder
            else:
                assert (
                    apply_text_encoder == weights_has_text_encoder
                ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"

            if apply_unet is None:
                apply_unet = weights_has_unet
            else:
                assert (
                    apply_unet == weights_has_unet
                ), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
        else:
            assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

        if self.weights_sd:
            # Non-strict: keys without a matching module are reported, not fatal.
            info = self.load_state_dict(self.weights_sd, False)
            print(f"weights are loaded: {info}")

    def merge_to(self, text_encoder, unet, dtype, device):
        """Merge the loaded LoRA weights into the wrapped modules' weights.

        Raises:
            AssertionError: if no weights have been loaded.
        """
        assert self.weights_sd is not None, "weights are not loaded"

        apply_text_encoder = apply_unet = False
        for key in self.weights_sd.keys():
            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                apply_text_encoder = True
            elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
                apply_unet = True

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            # Match on "<name>." so a module whose name is a prefix of another
            # module's name does not pick up the other module's keys.
            prefix = lora.lora_name + "."
            sd_for_lora = {
                key[len(prefix):]: value for key, value in self.weights_sd.items() if key.startswith(prefix)
            }
            lora.merge_to(sd_for_lora, dtype, device)
        print(f"weights are merged")

    def enable_gradient_checkpointing(self):
        """Do nothing; present so callers can invoke it unconditionally."""
        pass

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
        """Return optimizer parameter groups, one per non-empty model part.

        Each group is a dict with ``"params"`` and, when the matching learning
        rate is not ``None``, ``"lr"``.
        """

        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        self.requires_grad_(True)
        all_params = []

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)

        return all_params

    def prepare_grad_etc(self, text_encoder, unet):
        """Enable gradients on all parameters of this network."""
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        """Switch the network to training mode."""
        self.train()

    def get_trainable_params(self):
        """Return an iterator over this network's parameters."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Write the network's state dict to ``file``.

        Args:
            file: destination path; a ``.safetensors`` extension selects the
                safetensors writer, anything else uses ``torch.save``.
            dtype: when not ``None``, tensors are detached, moved to CPU and
                cast to this dtype before saving.
            metadata: optional dict stored in safetensors files. An empty
                dict is treated as ``None``. For safetensors output the two
                hash entries are added to this dict.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            # Hashes are computed up front and stored in the file's metadata.
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    @staticmethod
    def set_regions(networks, image):
        """Assign one channel of ``image`` as the region of each network.

        Only the first three networks are considered; network ``i`` gets
        channel ``i`` divided by 255. A channel that is all zero leaves its
        network untouched.
        """
        # Imported locally so this method does not depend on a module-level
        # ``np`` being defined elsewhere in the file.
        import numpy as np

        image = image.astype(np.float32) / 255.0
        for i, network in enumerate(networks[:3]):
            region = image[:, :, i]
            if region.max() == 0:
                continue
            region = torch.tensor(region)
            network.set_region(region)

    def set_region(self, region):
        """Forward ``region`` to every U-Net LoRA module."""
        for lora in self.unet_loras:
            lora.set_region(region)
|
|
| from io import BytesIO |
| import safetensors.torch |
| import hashlib |
|
|
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Returns a ``(model_hash, legacy_hash)`` tuple computed from an in-memory
    safetensors serialization of ``tensors``.
    """
    # Only metadata entries whose key starts with "ss_" take part in the
    # hash; entries with any other key are left out of the serialization.
    hashed_metadata = {}
    for key, value in metadata.items():
        if key.startswith("ss_"):
            hashed_metadata[key] = value

    serialized = safetensors.torch.save(tensors, hashed_metadata)
    buffer = BytesIO(serialized)

    # Both hash functions seek explicitly, so the buffer can be shared.
    return addnet_hash_safetensors(buffer), addnet_hash_legacy(buffer)
|
|
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files.

    Reads the first 8 bytes of ``b`` as a little-endian length ``n``, skips
    the following ``n`` bytes and returns the SHA-256 hex digest of the rest.
    """
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()

    b.seek(0)
    header_len = int.from_bytes(b.read(8), "little")

    # Skip the 8-byte length field and the header it describes.
    b.seek(header_len + 8)
    while True:
        chunk = b.read(chunk_size)
        if chunk == b"":
            break
        digest.update(chunk)

    return digest.hexdigest()
|
|
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files.

    Hashes up to 0x10000 bytes starting at offset 0x100000 of ``b`` with
    SHA-256 and returns the first 8 hex characters of the digest.
    """
    b.seek(0x100000)
    window = b.read(0x10000)
    return hashlib.sha256(window).hexdigest()[:8]
|
|