| from modules.patch import patch_all |
|
|
# NOTE(review): this call sits ahead of the remaining imports in this file;
# presumably the patches must be installed before those modules are loaded.
# Confirm before moving it.
patch_all()
|
|
|
|
| import os |
| import einops |
| import torch |
| import numpy as np |
|
|
| import fcbh.model_management |
| import fcbh.model_detection |
| import fcbh.model_patcher |
| import fcbh.utils |
| import fcbh.controlnet |
| import modules.sample_hijack |
| import fcbh.samplers |
| import fcbh.latent_formats |
| import modules.advanced_parameters |
|
|
| from fcbh.sd import load_checkpoint_guess_config |
| from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, \ |
| ControlNetApplyAdvanced |
| from fcbh_extras.nodes_freelunch import FreeU_V2 |
| from fcbh.sample import prepare_mask |
| from modules.patch import patched_sampler_cfg_function |
| from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora |
| from modules.config import path_embeddings |
| from modules.lora import load_dangerous_lora |
| from fcbh_extras.nodes_model_advanced import ModelSamplingDiscrete |
|
|
|
|
# Shared node instances, created once when this module is imported and reused
# by the wrapper functions defined in this file.

# Latent creation and VAE encode/decode (plain and tiled variants).
opEmptyLatentImage = EmptyLatentImage()
opVAEDecode = VAEDecode()
opVAEEncode = VAEEncode()
opVAEDecodeTiled = VAEDecodeTiled()
opVAEEncodeTiled = VAEEncodeTiled()

# Conditioning / model patching nodes.
opControlNetApplyAdvanced = ControlNetApplyAdvanced()
opFreeU = FreeU_V2()
opModelSamplingDiscrete = ModelSamplingDiscrete()
|
|
|
|
class StableDiffusionModel:
    """Bundle the parts of one loaded checkpoint plus their LoRA-patched clones.

    Attributes:
        unet, vae, clip, clip_vision: The components exactly as passed to the
            constructor; any of them may be ``None``.
        filename: Label used in the log messages printed by this class.
        unet_with_lora, clip_with_lora: Initially the very same objects as
            ``unet`` / ``clip``; ``refresh_loras`` replaces them with patched
            clones.
        visited_loras: ``str()`` of the LoRA list most recently handed to
            ``refresh_loras`` (``''`` before the first call).
        lora_key_map: Key map produced by ``model_lora_keys_unet`` /
            ``model_lora_keys_clip`` and extended with an identity entry for
            every ``state_dict`` key of the unet and clip models.
    """

    def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None):
        self.unet = unet
        self.vae = vae
        self.clip = clip
        self.clip_vision = clip_vision
        self.filename = filename
        self.unet_with_lora = unet
        self.clip_with_lora = clip
        self.visited_loras = ''
        self.lora_key_map = {}

        if self.unet is not None:
            self.lora_key_map = model_lora_keys_unet(self.unet.model, self.lora_key_map)
            # Identity entries let a patch address a weight by its raw state_dict key.
            self.lora_key_map.update({x: x for x in self.unet.model.state_dict().keys()})

        if self.clip is not None:
            self.lora_key_map = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map)
            self.lora_key_map.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})

    @torch.no_grad()
    @torch.inference_mode()
    def refresh_loras(self, loras):
        """Rebuild ``unet_with_lora`` / ``clip_with_lora`` for a list of LoRAs.

        Args:
            loras: A ``list`` of ``(name, weight)`` pairs. A name equal to the
                string ``'None'`` is skipped. Any other name is used as a path
                if it exists on disk, otherwise it is joined onto
                ``modules.config.path_loras``; files that still cannot be
                found are reported and skipped.

        Returns:
            None. The call is a no-op when ``str(loras)`` equals the value
            remembered from the previous call, and it stops right after
            remembering the request when this model has no unet.

        Raises:
            AssertionError: If ``loras`` is not a list.
        """
        assert isinstance(loras, list)

        requested = str(loras)
        if self.visited_loras == requested:
            return

        # Remembered before any loading happens, so an identical request is
        # skipped next time even if this one ends early.
        self.visited_loras = requested

        if self.unet is None:
            return

        print(f'Request to load LoRAs {requested} for model [{self.filename}].')

        loras_to_load = []
        for name, weight in loras:
            if name == 'None':
                continue

            if os.path.exists(name):
                lora_filename = name
            else:
                lora_filename = os.path.join(modules.config.path_loras, name)

            if not os.path.exists(lora_filename):
                print(f'Lora file not found: {lora_filename}')
                continue

            loras_to_load.append((lora_filename, weight))

        # Always patch fresh clones of the base models. self.unet cannot be
        # None here because of the early return above.
        self.unet_with_lora = self.unet.clone()
        self.clip_with_lora = self.clip.clone() if self.clip is not None else None

        for lora_filename, weight in loras_to_load:
            # NOTE(review): safe_load=False presumably allows unrestricted
            # deserialisation of the file - confirm, and only load LoRA files
            # from trusted sources.
            lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
            lora_items = load_dangerous_lora(lora, self.lora_key_map)

            if len(lora_items) == 0:
                continue

            print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] with {len(lora_items)} keys at weight {weight}.')

            # Collect the keys accepted by either model once per LoRA file
            # instead of rebuilding the set for every key checked below.
            loaded_keys = set(self.unet_with_lora.add_patches(lora_items, weight))
            if self.clip_with_lora is not None:
                loaded_keys.update(self.clip_with_lora.add_patches(lora_items, weight))

            for item in lora_items:
                if item not in loaded_keys:
                    print("LoRA key skipped: ", item)
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def apply_freeu(model, b1, b2, s1, s2):
    """Run ``model`` through the shared FreeU_V2 node.

    The four factors are passed to ``opFreeU.patch`` unchanged; the first
    element of the node's result is returned.
    """
    patched_outputs = opFreeU.patch(model=model, b1=b1, b2=b2, s1=s1, s2=s2)
    return patched_outputs[0]
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def load_controlnet(ckpt_filename):
    """Load a ControlNet from ``ckpt_filename`` via ``fcbh.controlnet``.

    Returns whatever ``fcbh.controlnet.load_controlnet`` returns, unmodified.
    """
    control_net = fcbh.controlnet.load_controlnet(ckpt_filename)
    return control_net
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent):
    """Forward every argument to the shared ControlNetApplyAdvanced node.

    Returns the node's result unmodified.
    """
    node_kwargs = dict(
        positive=positive,
        negative=negative,
        control_net=control_net,
        image=image,
        strength=strength,
        start_percent=start_percent,
        end_percent=end_percent,
    )
    return opControlNetApplyAdvanced.apply_controlnet(**node_kwargs)
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def load_model(ckpt_filename):
    """Load a checkpoint and wrap its parts in a ``StableDiffusionModel``.

    The checkpoint is read with ``load_checkpoint_guess_config`` using
    ``path_embeddings`` as the embedding directory. Before wrapping, the
    unet's ``model_options['sampler_cfg_function']`` is set to
    ``patched_sampler_cfg_function``. ``ckpt_filename`` is stored as the
    model's ``filename``.
    """
    components = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings)
    unet, clip, vae, clip_vision = components
    unet.model_options['sampler_cfg_function'] = patched_sampler_cfg_function
    return StableDiffusionModel(
        unet=unet,
        clip=clip,
        vae=vae,
        clip_vision=clip_vision,
        filename=ckpt_filename,
    )
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
    """Return a new ``StableDiffusionModel`` with one LoRA file patched in.

    Args:
        model: Source model; its ``unet`` and ``clip`` are cloned and the
            clones are patched, its ``vae`` and ``clip_vision`` are reused.
        lora_filename: File to load. A name ending in ``.fooocus.patch``
            (case-insensitive) is used as loaded; anything else is converted
            with ``load_lora`` against the model's unet/clip key map.
        strength_model: Patch strength applied to the unet clone.
        strength_clip: Patch strength applied to the clip clone.

    Returns:
        ``model`` itself when both strengths are zero, otherwise a new
        ``StableDiffusionModel`` (constructed without a ``filename``).
    """
    if strength_model == 0 and strength_clip == 0:
        # Nothing would change, so skip reading the file altogether.
        return model

    raw_lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)

    if lora_filename.lower().endswith('.fooocus.patch'):
        patches = raw_lora
    else:
        unet_keys = model_lora_keys_unet(model.unet.model)
        combined_keys = model_lora_keys_clip(model.clip.cond_stage_model, unet_keys)
        patches = load_lora(raw_lora, combined_keys)

    patched_unet = model.unet.clone()
    unet_hits = patched_unet.add_patches(patches, strength_model)

    patched_clip = model.clip.clone()
    clip_hits = patched_clip.add_patches(patches, strength_clip)

    accepted = set(unet_hits) | set(clip_hits)

    # Report every patch key that neither the unet nor the clip accepted.
    for missed in (key for key in patches if key not in accepted):
        print("Lora key not loaded: ", missed)

    return StableDiffusionModel(
        unet=patched_unet,
        clip=patched_clip,
        vae=model.vae,
        clip_vision=model.clip_vision,
    )
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def generate_empty_latent(width=1024, height=1024, batch_size=1):
    """Create a latent with the shared EmptyLatentImage node.

    Returns the first element of ``opEmptyLatentImage.generate(...)``.
    """
    generated = opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)
    return generated[0]
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def decode_vae(vae, latent_image, tiled=False):
    """Decode ``latent_image`` with ``vae`` through the shared decode nodes.

    When ``tiled`` is truthy the tiled node is used with ``tile_size=512``;
    otherwise the plain decode node is used. The first element of the node's
    result is returned.
    """
    if not tiled:
        return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
    return opVAEDecodeTiled.decode(samples=latent_image, vae=vae, tile_size=512)[0]
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def encode_vae(vae, pixels, tiled=False):
    """Encode ``pixels`` with ``vae`` through the shared encode nodes.

    When ``tiled`` is truthy the tiled node is used with ``tile_size=512``;
    otherwise the plain encode node is used. The first element of the node's
    result is returned.
    """
    if not tiled:
        return opVAEEncode.encode(pixels=pixels, vae=vae)[0]
    return opVAEEncodeTiled.encode(pixels=pixels, vae=vae, tile_size=512)[0]
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def encode_vae_inpaint(vae, pixels, mask):
    """Encode an image for inpainting and build the matching latent mask.

    Args:
        vae: Object whose ``encode(pixels)`` returns a 4-D latent.
        pixels: 4-D tensor; its dimensions -3 and -2 must equal the last two
            dimensions of ``mask`` (so presumably batch, height, width,
            channels - confirm against callers).
        mask: 3-D tensor; it is rounded before use, and the rounded value is
            the blend weight for the fill.

    Returns:
        ``(latent, latent_mask)``: the latent of the image with the masked
        area filled with 0.5, and a mask at the latent's spatial size.

    Raises:
        AssertionError: If the ranks or spatial sizes do not line up.
    """
    assert mask.ndim == 3 and pixels.ndim == 4
    assert mask.shape[-1] == pixels.shape[-2]
    assert mask.shape[-2] == pixels.shape[-3]

    functional = torch.nn.functional

    # Fill the (rounded) masked area with 0.5 before encoding.
    hole = mask.round().unsqueeze(-1)
    filled = pixels * (1 - hole) + 0.5 * hole

    latent = vae.encode(filled)
    _, _, latent_h, latent_w = latent.shape

    # Resize the mask to 8x the latent size, then max-pool by 8 so a latent
    # cell is masked as soon as any pixel of its 8x8 patch is masked.
    # The factor 8 is hard-coded here.
    upscaled = functional.interpolate(
        mask.unsqueeze(1), size=(latent_h * 8, latent_w * 8), mode="bilinear"
    ).round()
    latent_mask = functional.max_pool2d(upscaled, (8, 8)).round()

    return latent, latent_mask
|
|
|
|
class VAEApprox(torch.nn.Module):
    """Eight-layer convolutional network mapping 4-channel input to 3 channels.

    The layers are registered as ``conv1`` ... ``conv8`` so that saved state
    dicts keyed by those names load directly. ``current_type`` starts as
    ``None``; code outside this class assigns the dtype in use to it.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, square kernel size) for conv1..conv8.
        layer_specs = (
            (4, 8, 7),
            (8, 16, 5),
            (16, 32, 3),
            (32, 64, 3),
            (64, 32, 3),
            (32, 16, 3),
            (16, 8, 3),
            (8, 3, 3),
        )
        for index, (in_ch, out_ch, kernel) in enumerate(layer_specs, start=1):
            setattr(self, f'conv{index}', torch.nn.Conv2d(in_ch, out_ch, (kernel, kernel)))
        self.current_type = None

    def forward(self, x):
        """Return the network output at twice the spatial size of ``x``."""
        functional = torch.nn.functional
        # The unpadded convolutions shrink each spatial dimension by
        # 6 + 4 + 6 * 2 = 22 pixels in total; padding 11 on every side
        # beforehand cancels that, leaving exactly the doubled size.
        border = 11
        doubled_size = (x.shape[2] * 2, x.shape[3] * 2)
        x = functional.interpolate(x, doubled_size)
        x = functional.pad(x, (border,) * 4)
        for index in range(1, 9):
            conv = getattr(self, f'conv{index}')
            x = functional.leaky_relu(conv(x), 0.1)
        return x
|
|
|
|
# Cache of ready-to-use VAEApprox instances, keyed by their weights file path.
VAE_approx_models = {}


@torch.no_grad()
@torch.inference_mode()
def get_previewer(model):
    """Return a function that turns a latent into a uint8 preview image.

    The weights file is chosen from ``path_vae_approx``: ``xlvaeapp.pth`` when
    ``model.model.latent_format`` is an SDXL latent format, otherwise
    ``vaeapp_sd15.pth``. The loaded network is cached in
    ``VAE_approx_models`` so each file is read at most once.

    The returned ``preview_function(x0, step, total_steps)`` ignores ``step``
    and ``total_steps`` and returns the first batch item as an ``(h, w, c)``
    ``numpy.uint8`` array.
    """
    from modules.config import path_vae_approx

    is_sdxl = isinstance(model.model.latent_format, fcbh.latent_formats.SDXL)
    weights_name = 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth'
    vae_approx_filename = os.path.join(path_vae_approx, weights_name)

    approx = VAE_approx_models.get(vae_approx_filename)
    if approx is None:
        # NOTE(review): torch.load unpickles the file; consider passing
        # weights_only=True if every supported torch version accepts it.
        state = torch.load(vae_approx_filename, map_location='cpu')
        approx = VAEApprox()
        approx.load_state_dict(state)
        del state
        approx.eval()

        # Remember the dtype so inputs can be cast to match the weights.
        if fcbh.model_management.should_use_fp16():
            approx.half()
            approx.current_type = torch.float16
        else:
            approx.float()
            approx.current_type = torch.float32

        approx.to(fcbh.model_management.get_torch_device())
        VAE_approx_models[vae_approx_filename] = approx

    @torch.no_grad()
    @torch.inference_mode()
    def preview_function(x0, step, total_steps):
        decoded = approx(x0.to(approx.current_type)) * 127.5 + 127.5
        first_image = einops.rearrange(decoded, 'b c h w -> b h w c')[0]
        return first_image.cpu().numpy().clip(0, 255).astype(np.uint8)

    return preview_function
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
             scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
             force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1,
             previewer_start=None, previewer_end=None, sigmas=None, noise_mean=None):
    """Sample ``latent`` with ``fcbh.sample.sample`` and return the new latent.

    Args:
        latent: Dict holding ``"samples"`` and optionally ``"batch_index"``
            and ``"noise_mask"``.
        disable_noise: When truthy, an all-zero CPU noise tensor is used
            instead of ``fcbh.sample.prepare_noise``.
        noise_mean: If a tensor, the noise is shifted to
            ``noise + noise_mean - mean(noise, dim=1)``.
        callback_function: Optional; called once per step as
            ``callback_function(step, x0, x, total, preview)`` where
            ``preview`` is ``None`` when previews are disabled.
        refiner, refiner_switch: Stored on ``modules.sample_hijack`` for the
            duration of the call (``current_refiner`` is reset afterwards).
        previewer_start, previewer_end: Step offset and total reported to the
            previewer and ``callback_function``; default to ``0`` and
            ``steps``.
        sigmas: If given, a clone is moved to the torch device.

        The remaining arguments are forwarded to ``fcbh.sample.sample``.

    Returns:
        A shallow copy of ``latent`` with ``"samples"`` replaced by the result.

    Side effects:
        ``fcbh.samplers.sample`` is replaced with
        ``modules.sample_hijack.sample_hacked`` and is not restored.
    """
    if sigmas is not None:
        sigmas = sigmas.clone().to(fcbh.model_management.get_torch_device())

    latent_image = latent["samples"]

    if disable_noise:
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
        noise = fcbh.sample.prepare_noise(latent_image, seed, latent.get("batch_index"))

    if isinstance(noise_mean, torch.Tensor):
        noise = noise + noise_mean - torch.mean(noise, dim=1, keepdim=True)

    noise_mask = latent.get("noise_mask")

    previewer = get_previewer(model)
    preview_offset = 0 if previewer_start is None else previewer_start
    preview_total = steps if previewer_end is None else previewer_end

    def callback(step, x0, x, total_steps):
        # Abort promptly if the user interrupted processing.
        fcbh.model_management.throw_exception_if_processing_interrupted()
        reported_step = preview_offset + step
        preview_image = None
        if previewer is not None and not modules.advanced_parameters.disable_preview:
            preview_image = previewer(x0, reported_step, preview_total)
        if callback_function is not None:
            callback_function(reported_step, x0, x, preview_total, preview_image)

    # Hand the refiner settings to the hijacked sampler, then install it.
    modules.sample_hijack.current_refiner = refiner
    modules.sample_hijack.refiner_switch_step = refiner_switch
    fcbh.samplers.sample = modules.sample_hijack.sample_hacked

    try:
        samples = fcbh.sample.sample(
            model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
            denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
            force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback,
            disable_pbar=False, seed=seed, sigmas=sigmas,
        )
        result = latent.copy()
        result["samples"] = samples
        return result
    finally:
        # Drop the refiner reference whether sampling succeeded or raised.
        modules.sample_hijack.current_refiner = None
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def pytorch_to_numpy(x):
    """Convert each tensor in ``x`` to a ``uint8`` numpy array.

    Every item is moved to the CPU, multiplied by 255, clipped to
    ``[0, 255]`` and cast to ``numpy.uint8`` (the cast truncates fractions).
    A list with one array per item is returned.
    """
    images = []
    for tensor in x:
        scaled = 255. * tensor.cpu().numpy()
        images.append(np.clip(scaled, 0, 255).astype(np.uint8))
    return images
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def numpy_to_pytorch(x):
    """Convert a numpy array to a float32 tensor with a leading batch axis.

    The values are cast to ``float32`` and divided by 255, a new first axis
    of length 1 is added, and the result is made C-contiguous before being
    wrapped as a tensor.
    """
    scaled = x.astype(np.float32) / 255.0
    # `scaled` is already a fresh array, so a copy is only made when the
    # batched view is not C-contiguous.
    batched = np.ascontiguousarray(scaled[None])
    return torch.from_numpy(batched).float()
|
|