"""CRS-Diff modular loading utilities for custom diffusers pipeline."""

import importlib
import json
import sys
from pathlib import Path
from typing import Dict, Optional, Union

import torch
from diffusers import DDIMScheduler

# Put this file's directory on sys.path so the dotted targets in _TARGET_MAP
# can be imported by _get_class.
# NOTE(review): presumably the ``crs_core`` package sits next to this file --
# confirm against the repository layout.
_PIPELINE_DIR = Path(__file__).resolve().parent
if str(_PIPELINE_DIR) not in sys.path:
    sys.path.insert(0, str(_PIPELINE_DIR))

# Sub-directory names that load_components requires under the model root.
# Each one must contain a ``config.json``.
_COMPONENT_NAMES = (
    "unet",
    "vae",
    "text_encoder",
    "local_adapter",
    "global_content_adapter",
    "global_text_adapter",
    "metadata_encoder",
)

# Maps the ``_target`` string stored in a component config to the dotted path
# that is actually imported. Every entry currently maps to itself; targets not
# listed here are passed through unchanged by load_component.
_TARGET_MAP = {
    "crs_core.local_adapter.LocalControlUNetModel": "crs_core.local_adapter.LocalControlUNetModel",
    "crs_core.autoencoder.AutoencoderKL": "crs_core.autoencoder.AutoencoderKL",
    "crs_core.text_encoder.FrozenCLIPEmbedder": "crs_core.text_encoder.FrozenCLIPEmbedder",
    "crs_core.local_adapter.LocalAdapter": "crs_core.local_adapter.LocalAdapter",
    "crs_core.global_adapter.GlobalContentAdapter": "crs_core.global_adapter.GlobalContentAdapter",
    "crs_core.global_adapter.GlobalTextAdapter": "crs_core.global_adapter.GlobalTextAdapter",
    "crs_core.metadata_embedding.metadata_embeddings": "crs_core.metadata_embedding.metadata_embeddings",
}


def ensure_model_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Resolve local path or download HF repo snapshot.

    If the argument does not exist on the local filesystem it is treated as a
    Hugging Face Hub repo id and fetched with ``snapshot_download``.

    Side effect: the resolved directory is prepended to ``sys.path`` (once).

    Returns:
        The absolute, resolved path of the model directory.
    """
    path = Path(pretrained_model_name_or_path)
    if not path.exists():
        # Imported lazily so purely local loading does not need the import.
        from huggingface_hub import snapshot_download

        path = Path(snapshot_download(str(pretrained_model_name_or_path)))
    path = path.resolve()
    if str(path) not in sys.path:
        sys.path.insert(0, str(path))
    return path


def resolve_model_root(candidate: Optional[Union[str, Path]]) -> Optional[Path]:
    """Resolve to folder containing model_index.json.

    The candidate is first passed through ``ensure_model_path`` (which may
    download a snapshot). If the resulting directory does not itself contain
    ``model_index.json``, up to five parent directories are checked.

    Returns:
        The first directory found that contains ``model_index.json``, or
        ``None`` when ``candidate`` is falsy or no such directory is found.
    """
    if not candidate:
        return None
    path = ensure_model_path(candidate)
    if (path / "model_index.json").exists():
        return path
    cur = path
    for _ in range(5):
        parent = cur.parent
        if parent == cur:
            # Reached the filesystem root; nothing further to climb.
            break
        if (parent / "model_index.json").exists():
            return parent
        cur = parent
    return None


def _get_class(target: str):
    """Import and return the attribute named by the dotted path ``target``.

    ``target`` is split on its last dot into a module path and an attribute
    name; it therefore must contain at least one dot.
    """
    module_path, cls_name = target.rsplit(".", 1)
    mod = importlib.import_module(module_path)
    return getattr(mod, cls_name)


def load_component(model_root: Path, name: str):
    """Load single split component from ``<model_root>/<name>/``.

    Reads ``<model_root>/<name>/config.json``, instantiates the object named by
    its ``_target`` entry with the remaining config entries as keyword
    arguments, and loads weights from ``diffusion_pytorch_model.safetensors``
    in the same folder when that file exists.

    Args:
        model_root: Directory holding the per-component sub-folders.
        name: Sub-folder name of the component to load.

    Returns:
        The instantiated component, switched to eval mode.

    Raises:
        ValueError: If the config has no (or an empty) ``_target`` entry.
    """
    root = Path(model_root)
    comp_path = root / name
    with (comp_path / "config.json").open("r", encoding="utf-8") as f:
        cfg = json.load(f)
    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {comp_path / 'config.json'}")
    target = _TARGET_MAP.get(target, target)
    cls_ref = _get_class(target)
    # Keys with a leading underscore are treated as metadata, not as
    # constructor arguments.
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    module = cls_ref(**params)
    weight_file = comp_path / "diffusion_pytorch_model.safetensors"
    if weight_file.exists():
        from safetensors.torch import load_file

        state = load_file(str(weight_file))
        # strict=True: any missing or unexpected key is an error.
        module.load_state_dict(state, strict=True)
    module.eval()
    return module


class CRSModelWrapper(torch.nn.Module):
    """Wrap split components to mimic CRSControlNet APIs used by pipeline."""

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        local_adapter,
        global_content_adapter,
        global_text_adapter,
        metadata_encoder,
        channels: int = 4,
    ):
        """Store the separately loaded components under the attribute names
        that the methods of this class read them from."""
        super().__init__()
        # The UNet is exposed as ``self.model.diffusion_model``.
        self.model = torch.nn.Module()
        self.model.add_module("diffusion_model", unet)
        self.first_stage_model = vae
        self.cond_stage_model = text_encoder
        self.local_adapter = local_adapter
        self.global_content_adapter = global_content_adapter
        self.global_text_adapter = global_text_adapter
        self.metadata_emb = metadata_encoder
        # Per-output multipliers applied to the local adapter's outputs in
        # apply_model.
        # NOTE(review): 13 presumably matches the number of feature maps the
        # local adapter returns -- confirm.
        self.local_control_scales = [1.0] * 13
        self.channels = channels

    @torch.no_grad()
    def get_learned_conditioning(self, prompts):
        """Encode ``prompts`` with the text encoder, without gradients.

        If the text encoder has a ``device`` attribute, it is first overwritten
        with the device (as a string) of this wrapper's first parameter.
        """
        if hasattr(self.cond_stage_model, "device"):
            self.cond_stage_model.device = str(next(self.parameters()).device)
        return self.cond_stage_model.encode(prompts)

    def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, **kwargs):
        """Run the UNet on ``x_noisy`` at timestep ``t`` with the given conditioning.

        Args:
            x_noisy: Noisy input passed to the UNet (and to the local adapter).
            t: Timesteps passed to the UNet (and to the local adapter).
            cond: Dict that must contain ``"c_crossattn"`` (a sequence of
                tensors, concatenated on dim 1). ``"global_control"`` and
                ``"local_control"`` are optional; each is used only when
                present and its first element is not ``None``. ``"metadata"``
                is read when the ``metadata`` argument is ``None``.
            metadata: Metadata input; falls back to ``cond["metadata"]``.
            global_strength: Multiplier applied to the global content features
                before they are appended to the text context.
            **kwargs: Accepted and discarded.

        Returns:
            Whatever the wrapped UNet returns.
        """
        del kwargs
        if metadata is None:
            metadata = cond["metadata"]
        cond_txt = torch.cat(cond["c_crossattn"], 1)

        if cond.get("global_control") is not None and cond["global_control"][0] is not None:
            # NOTE(review): metadata is only passed through the metadata
            # encoder in this branch; without global control the raw metadata
            # reaches the UNet. Confirm this is intended.
            metadata = self.metadata_emb(metadata)
            # Only the first half (along dim 1) of the global control is used.
            content_t, _ = cond["global_control"][0].chunk(2, dim=1)
            global_control = self.global_content_adapter(content_t)
            cond_txt = self.global_text_adapter(cond_txt)
            cond_txt = torch.cat([cond_txt, global_strength * global_control], dim=1)

        local_control = None
        if cond.get("local_control") is not None and cond["local_control"][0] is not None:
            local_control = torch.cat(cond["local_control"], 1)
            local_control = self.local_adapter(
                x=x_noisy, timesteps=t, context=cond_txt, local_conditions=local_control
            )
            # zip() stops at the shorter sequence, so outputs beyond
            # len(local_control_scales) would be dropped silently.
            local_control = [c * s for c, s in zip(local_control, self.local_control_scales)]

        return self.model.diffusion_model(
            x=x_noisy,
            timesteps=t,
            metadata=metadata,
            context=cond_txt,
            local_control=local_control,
            meta=True,
        )

    def decode_first_stage(self, z):
        """Decode ``z`` with the wrapped VAE and return the result."""
        return self.first_stage_model.decode(z)


def load_components(model_root: Union[str, Path]) -> Dict[str, object]:
    """Load pipeline components from split directories.

    Args:
        model_root: Local directory or Hugging Face repo id, resolved through
            ``ensure_model_path``.

    Returns:
        A dict with keys ``"crs_model"`` (a ``CRSModelWrapper``),
        ``"scheduler"`` (a ``DDIMScheduler`` loaded from the ``scheduler``
        sub-folder) and ``"scale_factor"`` (float).

    Raises:
        FileNotFoundError: If any folder in ``_COMPONENT_NAMES`` lacks a
            ``config.json``.
    """
    root = ensure_model_path(model_root)
    scheduler = DDIMScheduler.from_pretrained(root, subfolder="scheduler")

    # Defaults, overridden by model_index.json when that file provides them.
    scale_factor = 0.18215
    channels = 4
    if (root / "model_index.json").exists():
        with (root / "model_index.json").open("r", encoding="utf-8") as f:
            idx = json.load(f)
        scale_factor = float(idx.get("scale_factor", scale_factor))
        channels = int(idx.get("channels", channels))

    # Scan the component folders once; an empty list means the export is complete.
    missing = [name for name in _COMPONENT_NAMES if not (root / name / "config.json").exists()]
    if missing:
        raise FileNotFoundError(
            f"CRS-Diff split component export incomplete. Missing: {missing}. "
            "Expected split folders with config.json and weights."
        )

    loaded = {name: load_component(root, name) for name in _COMPONENT_NAMES}
    crs_model = CRSModelWrapper(
        unet=loaded["unet"],
        vae=loaded["vae"],
        text_encoder=loaded["text_encoder"],
        local_adapter=loaded["local_adapter"],
        global_content_adapter=loaded["global_content_adapter"],
        global_text_adapter=loaded["global_text_adapter"],
        metadata_encoder=loaded["metadata_encoder"],
        channels=channels,
    )
    return {"crs_model": crs_model, "scheduler": scheduler, "scale_factor": scale_factor}