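"""Z-Image-i2L demo Space.

Upload a few reference images, encode them into a Z-Image LoRA with
DiffSynth's Image2LoRA units, then generate new images with the resulting
LoRA through the diffusers ZImagePipeline.
"""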
import glob
import json
import random
from pathlib import Path

import gradio as gr
import numpy as np
import spaces
import torch
from diffsynth.pipelines.z_image import (
    ModelConfig, ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
)
from diffsynth.pipelines.z_image import ZImagePipeline as ZImagePipelineDs
from diffusers import ZImagePipeline
from huggingface_hub import snapshot_download
from PIL import Image
from safetensors.torch import save_file
from ulid import ULID

DTYPE = torch.bfloat16
MAX_SEED = np.iinfo(np.int32).max
MODELS_DIR = Path("./models")

def download_hf_models(output_dir: Path) -> dict:
    """Download the required model repos from Hugging Face into output_dir.

    Returns a dict mapping repo_id -> local directory path.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    models = [
        {
            "repo_id": "DiffSynth-Studio/General-Image-Encoders",
            "description": "General Image Encoders (SigLIP2-G384, DINOv3-7B)",
            "allow_patterns": None,
        },
        {
            "repo_id": "Tongyi-MAI/Z-Image-Turbo",
            "description": "Z-Image Turbo",
            # "allow_patterns": [
            #     "text_encoder/*.safetensors",
            #     "vae/*.safetensors",
            #     "tokenizer/*",
            # ],
            "allow_patterns": None,
        },
        {
            "repo_id": "Tongyi-MAI/Z-Image",
            "description": "Z-Image base model (transformer)",
            "allow_patterns": ["transformer/*.safetensors"],
            # "allow_patterns": None,
        },
        {
            "repo_id": "DiffSynth-Studio/Z-Image-i2L",
            "description": "Z-Image-i2L (Image to LoRA model)",
            "allow_patterns": ["*.safetensors"],
        },
    ]
    downloaded_paths = {}
    for model in models:
        repo_id = model["repo_id"]
        local_dir = output_dir / repo_id
        # Skip repos whose local directory already exists. Note that a partial
        # download would also pass this check; delete the directory to re-fetch.
        if local_dir.exists():
            print(f"  ✓ {repo_id} (already downloaded)")
            downloaded_paths[repo_id] = local_dir
            continue
        print(f"  📥 Downloading {repo_id}...")
        print(f"     {model['description']}")
        try:
            # local_dir_use_symlinks / resume_download are deprecated no-ops
            # in recent huggingface_hub, so they are omitted here.
            result_path = snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                allow_patterns=model["allow_patterns"],
            )
            downloaded_paths[repo_id] = Path(result_path)
            print(f"  ✓ {repo_id}")
        except Exception as e:
            print(f"  ❌ Error downloading {repo_id}: {e}")
            raise
    return downloaded_paths

def get_model_files(base_path: Path, pattern: str) -> list:
    """Get list of files matching a glob pattern."""
    full_pattern = str(base_path / pattern)
    files = sorted(glob.glob(full_pattern))
    return files

downloaded_paths = download_hf_models(MODELS_DIR)

# Z-Image base (transformer only)
zimage_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image"
zimage_transformer_files = get_model_files(zimage_path, "transformer/*.safetensors")

# Z-Image-Turbo
zimage_turbo_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image-Turbo"
text_encoder_files = get_model_files(zimage_turbo_path, "text_encoder/*.safetensors")
vae_file = get_model_files(zimage_turbo_path, "vae/diffusion_pytorch_model.safetensors")
tokenizer_path = zimage_turbo_path / "tokenizer"

# General Image Encoders
encoders_path = MODELS_DIR / "DiffSynth-Studio" / "General-Image-Encoders"
siglip_file = get_model_files(encoders_path, "SigLIP2-G384/model.safetensors")
dino_file = get_model_files(encoders_path, "DINOv3-7B/model.safetensors")

# Z-Image-i2L from HuggingFace
zimage_i2l_path = MODELS_DIR / "DiffSynth-Studio" / "Z-Image-i2L"
zimage_i2l_file = get_model_files(zimage_i2l_path, "model.safetensors")

print(f"  Z-Image transformer: {len(zimage_transformer_files)} file(s)")
print(f"  Text encoder: {len(text_encoder_files)} file(s)")
print(f"  VAE: {len(vae_file)} file(s)")
print(f"  Tokenizer: {tokenizer_path}")
print(f"  SigLIP2: {len(siglip_file)} file(s)")
print(f"  DINOv3: {len(dino_file)} file(s)")
print(f"  Z-Image-i2L: {len(zimage_i2l_file)} file(s)")

################
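# All four stages (offload, onload, preparing, computation) are pinned to CUDA
# in bfloat16, so every model stays resident on the GPU with no CPU offloading.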
vram_config = {
    "offload_dtype": DTYPE,
    "offload_device": "cuda",
    "onload_dtype": DTYPE,
    "onload_device": "cuda",
    "preparing_dtype": DTYPE,
    "preparing_device": "cuda",
    "computation_dtype": DTYPE,
    "computation_device": "cuda",
}
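# A lower-VRAM variant could presumably offload idle stages to CPU instead.
# Untested sketch, assuming DiffSynth honors per-stage devices:
#
#   vram_config_low = {**vram_config, "offload_device": "cpu"}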
model_configs = [
    # All models from HuggingFace - use path= for local files
    ModelConfig(path=zimage_transformer_files, **vram_config),
    ModelConfig(path=text_encoder_files),
    ModelConfig(path=vae_file),
    ModelConfig(path=siglip_file),
    ModelConfig(path=dino_file),
    ModelConfig(path=zimage_i2l_file),
]
pipe_lora = ZImagePipelineDs.from_pretrained(
    torch_dtype=DTYPE,
    device="cuda",
    model_configs=model_configs,
    tokenizer_config=ModelConfig(path=str(tokenizer_path)),
)
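# Two pipelines are kept side by side: pipe_lora (DiffSynth) encodes reference
# images into LoRA weights; pipe_imagen (diffusers, below) renders text-to-image
# with the generated LoRA applied.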
pipe_imagen = ZImagePipeline.from_pretrained(
    str(zimage_turbo_path),
    torch_dtype=DTYPE,
    low_cpu_mem_usage=False,
)
pipe_imagen.to("cuda")

@spaces.GPU  # assumes a ZeroGPU Space; `spaces` is imported above
def generate_lora(
    input_images,
    progress=gr.Progress(track_tqdm=True),
):
    ulid = str(ULID()).lower()[:12]
    print(f"ulid: {ulid}")
    if not input_images:
        # Raising gr.Error keeps the three outputs consistent; returning a
        # bare False would not match [lora_name, lora_download, imagen_button].
        raise gr.Error("Please upload at least one image.")
    progress(0.1, desc="Processing images...")
    print("progress: step 1")
    # Gallery items may arrive as file paths, (path, caption) tuples, or arrays.
    pil_images = []
    for img in input_images:
        if isinstance(img, str):
            pil_images.append(Image.open(img).convert("RGB"))
        elif isinstance(img, tuple):
            pil_images.append(Image.open(img[0]).convert("RGB"))
        else:
            pil_images.append(Image.fromarray(img).convert("RGB"))
    progress(0.3, desc="Encoding images to LoRA...")
    print("progress: step 2")
    # Model inference
    with torch.no_grad():
        embs = ZImageUnit_Image2LoRAEncode().process(pipe_lora, image2lora_images=pil_images)
        progress(0.7, desc="Decoding LoRA weights...")
        print("progress: step 3")
        lora = ZImageUnit_Image2LoRADecode().process(pipe_lora, **embs)["lora"]
    progress(0.9, desc="Saving LoRA file...")
    print("progress: step 4")
    lora_name = f"{ulid}.safetensors"
    Path("loras").mkdir(exist_ok=True)
    lora_path = f"loras/{lora_name}"
    save_file(lora, lora_path)
    progress(1.0, desc="Done!")
    return lora_name, gr.update(interactive=True, value=lora_path), gr.update(interactive=True)

@spaces.GPU  # assumes a ZeroGPU Space
def generate_image(
    lora_name,
    prompt,
    negative_prompt="blurry ugly bad",
    width=1024,
    height=1024,
    seed=42,
    randomize_seed=True,
    guidance_scale=4.0,
    lora_strength=1.25,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    # Drop any previously loaded adapter before loading the new one
    # (a no-op on the first call).
    pipe_imagen.unload_lora_weights()
    pipe_imagen.load_lora_weights(
        "loras",
        weight_name=lora_name,
        adapter_name="generated_lora",
    )
    pipe_imagen.set_adapters("generated_lora", adapter_weights=lora_strength)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # A CPU generator keeps seeding reproducible regardless of the GPU in use.
    generator = torch.Generator().manual_seed(seed)
    output_image = pipe_imagen(
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt.strip() else None,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        guidance_scale=guidance_scale,
    ).images[0]
    return output_image, seed

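# Example usage outside the UI (hypothetical file name, assuming a LoRA was
# already generated into ./loras):
#
#   image, used_seed = generate_image("abc123.safetensors", "a man in a fishing boat.")
#   image.save("out.png")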
def read_file(path: str) -> str:
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
h3 {
    text-align: center;
    display: block;
}
"""

with open("examples/0_examples.json", "r") as file:
    examples = json.load(file)
print(examples)

# Note: css belongs on gr.Blocks(), not on demo.launch().
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))
        with gr.Row():
            with gr.Column():
                input_images = gr.Gallery(
                    label="Input images",
                    file_types=["image"],
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                    object_fit="cover",
                    height=300,
                )
                lora_button = gr.Button("Generate LoRA", variant="primary")
            with gr.Column():
                lora_name = gr.Textbox(label="Generated LoRA path", lines=2, interactive=False)
                lora_download = gr.DownloadButton(label="Download LoRA", interactive=False)
        with gr.Column(elem_id="imagen-container") as imagen_container:
            gr.Markdown("### Once your LoRA is ready, you can try generating an image here.")
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        lines=2,
                        placeholder="Enter your prompt",
                        value="a man in a fishing boat.",
                        container=False,
                    )
                    imagen_button = gr.Button("Generate Image", variant="primary", interactive=False)
                    with gr.Accordion("Advanced Settings", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative prompt",
                            lines=2,
                            container=False,
                            placeholder="Enter your negative prompt",
                        )
                        num_inference_steps = gr.Slider(
                            label="Steps",
                            minimum=4,
                            maximum=12,
                            step=1,
                            value=9,
                        )
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=512,
                                maximum=2048,
                                step=32,
                                value=1024,
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=512,
                                maximum=2048,
                                step=32,
                                value=1024,
                            )
                        with gr.Row():
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42,
                            )
                            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=3.0,
                            step=0.1,
                            value=0.0,
                        )
                        lora_strength = gr.Slider(
                            label="LoRA strength",
                            minimum=0.0,
                            maximum=5.0,
                            step=0.05,
                            value=2.00,
                        )
                with gr.Column():
                    output_image = gr.Image(label="Generated image", show_label=False)
        gr.Examples(
            examples=examples,
            inputs=[input_images],
        )
        gr.Markdown(read_file("static/footer.md"))
    lora_button.click(
        fn=generate_lora,
        inputs=[input_images],
        outputs=[lora_name, lora_download, imagen_button],
    )
    imagen_button.click(
        fn=generate_image,
        inputs=[
            lora_name,
            prompt,
            negative_prompt,
            width,
            height,
            seed,
            randomize_seed,
            guidance_scale,
            lora_strength,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True)