| import argparse |
| import os |
| import json |
| import copy |
| import os.path as osp |
|
|
| import torch |
| from diffusers import UNet2DConditionModel, AutoencoderKL |
| from diffusers.models.attention import BasicTransformerBlock |
| from peft import LoraConfig |
| from peft.utils import set_peft_model_state_dict |
| from transformers import PretrainedConfig |
|
|
| from diffusers import DPMSolverMultistepScheduler |
|
|
| from glyph_sdxl.utils import ( |
| parse_config, |
| UNET_CKPT_NAME, |
| huggingface_cache_dir, |
| load_byt5_and_byt5_tokenizer, |
| BYT5_MAPPER_CKPT_NAME, |
| INSERTED_ATTN_CKPT_NAME, |
| BYT5_CKPT_NAME, |
| MultilingualPromptFormat, |
| ) |
| from glyph_sdxl.custom_diffusers import ( |
| StableDiffusionGlyphXLPipeline, |
| CrossAttnInsertBasicTransformerBlock, |
| ) |
| from glyph_sdxl.modules import T5EncoderBlockByT5Mapper |
|
|
# Registry resolving the config's `byt5_mapper_type` string to a mapper class.
byt5_mapper_dict = {
    mapper_cls.__name__: mapper_cls
    for mapper_cls in (T5EncoderBlockByT5Mapper,)
}
|
|
|
|
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder",
):
    """Resolve the text-encoder class recorded in a pretrained model's config.

    Reads the first entry of the ``architectures`` field from the given
    subfolder's config and returns the matching ``transformers`` class.
    Only the two CLIP text encoders used by SDXL are supported.

    Raises:
        ValueError: if the recorded architecture is neither ``CLIPTextModel``
            nor ``CLIPTextModelWithProjection``.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder=subfolder,
        revision=revision,
    )
    architecture = encoder_config.architectures[0]

    if architecture == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection
        return CLIPTextModelWithProjection
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    raise ValueError(f"{architecture} is not supported.")
|
|
if __name__ == "__main__":
    # CLI: positional config/ckpt/annotation paths plus sampling options.
    parser = argparse.ArgumentParser()
    parser.add_argument("config_dir", type=str)
    parser.add_argument("ckpt_dir", type=str)
    # NOTE(review): `default` on a positional argument is ignored unless
    # nargs='?' is also given, so ann_path is effectively required here.
    parser.add_argument("ann_path", type=str, default='examples/shower.json')
    # NOTE(review): the default is the *string* 'None', not the None object,
    # so with no --out_folder the script creates a directory literally
    # named "None" below.
    parser.add_argument("--out_folder", type=str, default='None')
    parser.add_argument("--device", type=str, default='cuda')
    # No default: when --sampler is omitted, the scheduler bundled with the
    # pretrained pipeline is kept (the 'dpm' branch below is skipped).
    parser.add_argument("--sampler", type=str, choices=['euler', 'dpm'])
    parser.add_argument("--cfg", type=float, default=5.0)  # classifier-free guidance scale
    args = parser.parse_args()

    config = parse_config(args.config_dir)

    # Resolve and instantiate SDXL's two CLIP text encoders from the base
    # checkpoint ("text_encoder" and "text_encoder_2" subfolders).
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        config.pretrained_model_name_or_path, config.revision,
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        config.pretrained_model_name_or_path, config.revision, subfolder="text_encoder_2",
    )
    text_encoder_one = text_encoder_cls_one.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )
|
|
    # Base SDXL UNet; selected cross-attention blocks are swapped for
    # glyph-aware variants further below.
    unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="unet",
        revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )

    # Prefer a standalone VAE checkpoint when configured; otherwise load the
    # "vae" subfolder of the base model (hence the subfolder=None switch).
    vae_path = (
        config.pretrained_model_name_or_path
        if config.pretrained_vae_model_name_or_path is None
        else config.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path, subfolder="vae" if config.pretrained_vae_model_name_or_path is None else None,
        revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )

    # ByT5 character-level encoder + tokenizer used for the glyph text prompt.
    byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer(
        **config.byt5_config,
        huggingface_cache_dir=huggingface_cache_dir,
    )

    # Map the config's dtype string to a torch dtype; fp32 is the fallback.
    inference_dtype = torch.float32
    if config.inference_dtype == "fp16":
        inference_dtype = torch.float16
    elif config.inference_dtype == "bf16":
        inference_dtype = torch.bfloat16

    # The base model's bundled VAE is kept in fp32; a dedicated VAE checkpoint
    # follows the inference dtype (presumably a half-precision-stability
    # workaround — confirm against the training setup).
    if config.pretrained_vae_model_name_or_path is None:
        vae.to(args.device, dtype=torch.float32)
    else:
        vae.to(args.device, dtype=inference_dtype)
    text_encoder_one.to(args.device, dtype=inference_dtype)
    text_encoder_two.to(args.device, dtype=inference_dtype)
    byt5_model.to(args.device)  # dtype left at its default
    unet.to(args.device, dtype=inference_dtype)
|
|
    # Replace the configured BasicTransformerBlocks in the UNet with blocks
    # that insert an extra cross-attention over the ByT5 glyph features.
    inserted_new_modules_para_set = set()
    for name, module in unet.named_modules():
        if isinstance(module, BasicTransformerBlock) and name in config.attn_block_to_modify:
            # Walk down the dotted module path to the block's direct parent
            # so the replacement can be registered under the same attribute.
            parent_module = unet
            for n in name.split(".")[:-1]:
                parent_module = getattr(parent_module, n)
            new_block = CrossAttnInsertBasicTransformerBlock.from_transformer_block(
                module,
                # Channel width of the glyph embedding fed to the inserted
                # attention: raw ByT5 hidden size unless the mapper projects
                # to a dedicated SDXL channel count.
                byt5_model.config.d_model if config.byt5_mapper_config.sdxl_channels is None else config.byt5_mapper_config.sdxl_channels,
            )
            new_block.requires_grad_(False)
            # Only the newly inserted sub-modules stay trainable; record each
            # parameter key and assert it is seen exactly once.
            for inserted_module_name, inserted_module in zip(
                new_block.get_inserted_modules_names(),
                new_block.get_inserted_modules()
            ):
                inserted_module.requires_grad_(True)
                for para_name, para in inserted_module.named_parameters():
                    para_key = name + '.' + inserted_module_name + '.' + para_name
                    assert para_key not in inserted_new_modules_para_set
                    inserted_new_modules_para_set.add(para_key)
            # Modules carried over from the original block follow the global
            # device/dtype placement.
            for origin_module in new_block.get_origin_modules():
                origin_module.to(args.device, dtype=inference_dtype)
            # Swap the new block in under the original attribute name.
            parent_module.register_module(name.split(".")[-1], new_block)
            print(f"inserted cross attn block to {name}")
|
|
    # Instantiate the ByT5 -> SDXL feature mapper named in the config.
    byt5_mapper = byt5_mapper_dict[config.byt5_mapper_type](
        byt5_model.config,
        **config.byt5_mapper_config,
    )

    # LoRA over the UNet's self- (attn1) and cross- (attn2) attention
    # projections, matching the adapter layout used at training time.
    unet_lora_target_modules = [
        "attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0",
        "attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0",
    ]
    unet_lora_config = LoraConfig(
        r=config.unet_lora_rank,
        lora_alpha=config.unet_lora_rank,  # alpha == rank => LoRA scaling of 1
        init_lora_weights="gaussian",
        target_modules=unet_lora_target_modules,
    )
    unet.add_adapter(unet_lora_config)

    # Load the trained LoRA weights; only unexpected keys are reported as a
    # problem (missing keys are normal when loading into the full UNet).
    unet_lora_layers_para = torch.load(osp.join(args.ckpt_dir, UNET_CKPT_NAME), map_location='cpu')
    incompatible_keys = set_peft_model_state_dict(unet, unet_lora_layers_para, adapter_name="default")
    if getattr(incompatible_keys, 'unexpected_keys', []) == []:
        print(f"loaded unet_lora_layers_para")
    else:
        print(f"unet_lora_layers has unexpected_keys: {getattr(incompatible_keys, 'unexpected_keys', None)}")

    # Weights for the inserted cross-attn modules; strict=False because the
    # checkpoint covers only those modules, but no key may go unmatched.
    inserted_attn_module_paras = torch.load(osp.join(args.ckpt_dir, INSERTED_ATTN_CKPT_NAME), map_location='cpu')
    missing_keys, unexpected_keys = unet.load_state_dict(inserted_attn_module_paras, strict=False)
    assert len(unexpected_keys) == 0, unexpected_keys

    # Mapper and (fine-tuned) ByT5 encoder weights are loaded strictly.
    byt5_mapper_para = torch.load(osp.join(args.ckpt_dir, BYT5_MAPPER_CKPT_NAME), map_location='cpu')
    byt5_mapper.load_state_dict(byt5_mapper_para)

    byt5_model_para = torch.load(osp.join(args.ckpt_dir, BYT5_CKPT_NAME), map_location='cpu')
    byt5_model.load_state_dict(byt5_model_para)
|
|
    # Assemble the glyph-aware SDXL pipeline around all modules prepared above.
    pipeline = StableDiffusionGlyphXLPipeline.from_pretrained(
        config.pretrained_model_name_or_path,
        vae=vae,
        text_encoder=text_encoder_one,
        text_encoder_2=text_encoder_two,
        byt5_text_encoder=byt5_model,
        byt5_tokenizer=byt5_tokenizer,
        byt5_mapper=byt5_mapper,
        unet=unet,
        byt5_max_length=config.byt5_max_length,
        revision=config.revision,
        torch_dtype=inference_dtype,
        safety_checker=None,
        cache_dir=huggingface_cache_dir,
    )

    # Optional scheduler swap: DPM-Solver multistep with Karras sigmas.
    if args.sampler == 'dpm':
        pipeline.scheduler = DPMSolverMultistepScheduler.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="scheduler",
            use_karras_sigmas=True,
        )

    pipeline = pipeline.to(args.device)

    # Annotation JSON: expects 'bg_prompt', 'texts', 'bbox', 'styles' and an
    # optional 'seed' (see the example file referenced by the CLI default).
    with open(args.ann_path, 'r') as f:
        ann = json.load(f)

    os.makedirs(args.out_folder, exist_ok=True)

    prompt_format = MultilingualPromptFormat()

    # Deep-copy so prompt formatting cannot mutate the loaded annotation.
    texts = copy.deepcopy(ann['texts'])
    bboxes = copy.deepcopy(ann['bbox'])
    styles = copy.deepcopy(ann['styles'])

    text_prompt = prompt_format.format_prompt(texts, styles)

    # Seed the generator only when the annotation pins one; otherwise sampling
    # is nondeterministic.
    if 'seed' not in ann:
        generator = torch.Generator(device=args.device)
    else:
        generator = torch.Generator(device=args.device).manual_seed(ann['seed'])

    # NOTE(review): torch.cuda.amp.autocast is deprecated in favor of
    # torch.amp.autocast('cuda') and assumes a CUDA device even though
    # --device is configurable — confirm before running on CPU.
    with torch.cuda.amp.autocast():
        image = pipeline(
            prompt=ann['bg_prompt'],
            text_prompt=text_prompt,
            texts=texts,
            bboxes=bboxes,
            num_inference_steps=50,
            generator=generator,
            text_attn_mask=None,
            guidance_scale=args.cfg,
        ).images[0]
    image.save(f'{args.out_folder}/result.png')
|