# Z-Image-to-LoRA / app.py
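"""Gradio app that distills a set of reference images into a Z-Image LoRA and
then generates new images with it.

Flow:
  1. Download model weights from Hugging Face (General Image Encoders,
     Z-Image / Z-Image-Turbo, and the Z-Image-i2L image-to-LoRA model).
  2. Encode uploaded images into LoRA weights with DiffSynth's
     ZImageUnit_Image2LoRAEncode / ZImageUnit_Image2LoRADecode units.
  3. Load the saved LoRA into a diffusers ZImagePipeline and render images.
"""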
import gradio as gr
import numpy as np
import torch, random, json, spaces
from ulid import ULID
from diffsynth.pipelines.z_image import (
ModelConfig, ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
)
from diffsynth.pipelines.z_image import ZImagePipeline as ZImagePipelineDs
from diffusers import ZImagePipeline
from safetensors.torch import save_file
from PIL import Image
from pathlib import Path
from huggingface_hub import snapshot_download
import glob
DTYPE = torch.bfloat16
MAX_SEED = np.iinfo(np.int32).max
MODELS_DIR = Path("./models")
def download_hf_models(output_dir: Path) -> dict:
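    """Download the required model repos into output_dir.

    A repo is skipped when its target directory already exists (a coarse
    check: a partially downloaded directory also counts as present).
    Returns a dict mapping repo_id to its local Path.
    """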
output_dir.mkdir(parents=True, exist_ok=True)
models = [
{
"repo_id": "DiffSynth-Studio/General-Image-Encoders",
"description": "General Image Encoders (SigLIP2-G384, DINOv3-7B)",
"allow_patterns": None,
},
{
"repo_id": "Tongyi-MAI/Z-Image-Turbo",
"description": "Z-Image Turbo",
# "allow_patterns": [
# "text_encoder/*.safetensors",
# "vae/*.safetensors",
# "tokenizer/*",
# ]
"allow_patterns": None,
},
{
"repo_id": "Tongyi-MAI/Z-Image",
"description": "Z-Image base model (transformer)",
"allow_patterns": ["transformer/*.safetensors"],
# "allow_patterns": None,
},
{
"repo_id": "DiffSynth-Studio/Z-Image-i2L",
"description": "Z-Image-i2L (Image to LoRA model)",
"allow_patterns": ["*.safetensors"],
},
]
downloaded_paths = {}
for model in models:
repo_id = model["repo_id"]
local_dir = output_dir / repo_id
# Check if already downloaded
if local_dir.exists():
print(f" ✓ {repo_id} (already downloaded)")
downloaded_paths[repo_id] = local_dir
continue
print(f" 📥 Downloading {repo_id}...")
print(f" {model['description']}")
try:
            result_path = snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                allow_patterns=model["allow_patterns"],
            )
downloaded_paths[repo_id] = Path(result_path)
print(f" ✓ {repo_id}")
except Exception as e:
print(f" ❌ Error downloading {repo_id}: {e}")
raise
return downloaded_paths
def get_model_files(base_path: Path, pattern: str) -> list:
"""Get list of files matching a glob pattern."""
full_pattern = str(base_path / pattern)
files = sorted(glob.glob(full_pattern))
return files
downloaded_paths = download_hf_models(MODELS_DIR)
zimage_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image"
zimage_transformer_files = get_model_files(zimage_path, "transformer/*.safetensors")
# Z-Image-Turbo
zimage_turbo_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image-Turbo"
text_encoder_files = get_model_files(zimage_turbo_path, "text_encoder/*.safetensors")
vae_file = get_model_files(zimage_turbo_path, "vae/diffusion_pytorch_model.safetensors")
tokenizer_path = zimage_turbo_path / "tokenizer"
# General Image Encoders
encoders_path = MODELS_DIR / "DiffSynth-Studio" / "General-Image-Encoders"
siglip_file = get_model_files(encoders_path, "SigLIP2-G384/model.safetensors")
dino_file = get_model_files(encoders_path, "DINOv3-7B/model.safetensors")
# Z-Image-i2L from HuggingFace
zimage_i2l_path = MODELS_DIR / "DiffSynth-Studio" / "Z-Image-i2L"
zimage_i2l_file = get_model_files(zimage_i2l_path, "model.safetensors")
print(f" Z-Image transformer: {len(zimage_transformer_files)} file(s)")
print(f" Text encoder: {len(text_encoder_files)} file(s)")
print(f" VAE: {len(vae_file)} file(s)")
print(f" Tokenizer: {tokenizer_path}")
print(f" SigLIP2: {len(siglip_file)} file(s)")
print(f" DINOv3: {len(dino_file)} file(s)")
print(f" Z-Image-i2L: {len(zimage_i2l_file)} file(s)")
################
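# Every stage (offload/onload/preparing/computation) is pinned to CUDA in
# bfloat16, i.e. no CPU offloading: the whole pipeline stays on the GPU.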
vram_config = {
    "offload_dtype": DTYPE,
    "offload_device": "cuda",
    "onload_dtype": DTYPE,
    "onload_device": "cuda",
    "preparing_dtype": DTYPE,
    "preparing_device": "cuda",
    "computation_dtype": DTYPE,
    "computation_device": "cuda",
}
model_configs = [
# All models from HuggingFace - use path= for local files
ModelConfig(path=zimage_transformer_files, **vram_config),
ModelConfig(path=text_encoder_files),
ModelConfig(path=vae_file),
ModelConfig(path=siglip_file),
ModelConfig(path=dino_file),
ModelConfig(path=zimage_i2l_file),
]
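# DiffSynth pipeline, used only for the image-to-LoRA encode/decode step.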
pipe_lora = ZImagePipelineDs.from_pretrained(
    torch_dtype=DTYPE,
device="cuda",
model_configs=model_configs,
tokenizer_config=ModelConfig(path=str(tokenizer_path)),
)
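# diffusers pipeline (Z-Image-Turbo), used for text-to-image generation with
# the freshly produced LoRA loaded on top.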
pipe_imagen = ZImagePipeline.from_pretrained(
"./models/Tongyi-MAI/Z-Image-Turbo",
    torch_dtype=DTYPE,
low_cpu_mem_usage=False,
)
pipe_imagen.to("cuda")
@spaces.GPU(duration=120)
def generate_lora(
input_images,
progress=gr.Progress(track_tqdm=True),
):
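    """Encode the uploaded gallery images into a LoRA and save it under loras/.

    Returns the LoRA filename, an update that enables the download button with
    the file attached, and an update that enables the image-generation button.
    """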
ulid = str(ULID()).lower()[:12]
print(f"ulid: {ulid}")
    if not input_images:
        raise gr.Error("Please upload at least one image.")
progress(0.1, desc="Processing images...")
print("progress: step 1")
    # Gallery items may arrive as file paths, (path, caption) tuples, or raw arrays.
    pil_images = []
    for img in input_images:
        if isinstance(img, str):
            pil_images.append(Image.open(img).convert("RGB"))
        elif isinstance(img, tuple):
            pil_images.append(Image.open(img[0]).convert("RGB"))
        else:
            pil_images.append(Image.fromarray(img).convert("RGB"))
progress(0.3, desc="Encoding images to LoRA...")
print("progress: step 2")
# Model inference
with torch.no_grad():
embs = ZImageUnit_Image2LoRAEncode().process(pipe_lora, image2lora_images=pil_images)
progress(0.7, desc="Decoding LoRA weights...")
print("progress: step 3")
lora = ZImageUnit_Image2LoRADecode().process(pipe_lora, **embs)["lora"]
    progress(0.9, desc="Saving LoRA file...")
    print("progress: step 4")
    Path("loras").mkdir(exist_ok=True)
    lora_name = f"{ulid}.safetensors"
    lora_path = f"loras/{lora_name}"
    save_file(lora, lora_path)
    progress(1.0, desc="Done!")
    return lora_name, gr.update(interactive=True, value=lora_path), gr.update(interactive=True)
@spaces.GPU
def generate_image(
lora_name,
prompt,
negative_prompt="blurry ugly bad",
width=1024,
height=1024,
seed=42,
randomize_seed=True,
    guidance_scale=4.0,
    lora_strength=1.25,
num_inference_steps=8,
progress=gr.Progress(track_tqdm=True),
):
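    """Load the named LoRA into the Turbo pipeline and render one image."""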
    if not lora_name:
        raise gr.Error("Generate a LoRA first.")
    # Swap in the freshly generated LoRA before each run.
    pipe_imagen.unload_lora_weights()
    pipe_imagen.load_lora_weights(
        "loras",
        weight_name=lora_name,
        adapter_name="generated_lora",
    )
    pipe_imagen.set_adapters("generated_lora", adapter_weights=lora_strength)
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator().manual_seed(seed)
output_image = pipe_imagen(
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt.strip() else None,
num_inference_steps=num_inference_steps,
width=width,
height=height,
generator=generator,
guidance_scale=guidance_scale
).images[0]
return output_image, seed
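# Programmatic usage (outside the UI) -- a minimal sketch with hypothetical
# image paths, assuming the models above are already downloaded; the progress
# callbacks expect a Gradio request context, so this is illustrative only:
#
#   lora_file, _, _ = generate_lora(["ref1.jpg", "ref2.jpg"])
#   image, used_seed = generate_image(lora_file, "a man in a fishing boat.")
#   image.save("out.png")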
def read_file(path: str) -> str:
with open(path, 'r', encoding='utf-8') as f:
content = f.read()
return content
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
h3{
text-align: center;
display:block;
}
"""
with open("examples/0_examples.json", "r") as file:
    examples = json.load(file)
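# --- Gradio UI ---
# Top row: upload reference images and build the LoRA.
# Lower block: once the LoRA is ready, generate images with it.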
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
with gr.Column():
gr.HTML(read_file("static/header.html"))
with gr.Row():
with gr.Column():
input_images = gr.Gallery(
label="Input images",
file_types=["image"],
show_label=False,
elem_id="gallery",
columns=2,
object_fit="cover",
height=300)
lora_button = gr.Button("Generate LoRA", variant="primary")
with gr.Column():
                lora_name = gr.Textbox(label="Generated LoRA path", lines=2, interactive=False)
                lora_download = gr.DownloadButton(label="Download LoRA", interactive=False)
with gr.Column(elem_id='imagen-container') as imagen_container:
gr.Markdown("### After your LoRA is ready, you can try generate image here.")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label="Prompt",
show_label=False,
lines=2,
placeholder="Enter your prompt",
value="a man in a fishing boat.",
container=False,
)
imagen_button = gr.Button("Generate Image", variant="primary", interactive=False)
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Textbox(
label="Negative prompt",
lines=2,
container=False,
placeholder="Enter your negative prompt"
)
num_inference_steps = gr.Slider(
label="Steps",
minimum=4,
maximum=12,
step=1,
value=9,
)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=512,
maximum=2048,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=512,
maximum=2048,
step=32,
value=1024,
)
with gr.Row():
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=3.0,
step=0.1,
value=0.0,
)
lora_strength = gr.Slider(
label="Lora Strength",
minimum=0.0,
maximum=5.0,
step=0.05,
value=2.00,
)
with gr.Column():
output_image = gr.Image(label="Generated image", show_label=False)
gr.Examples(
examples=examples,
inputs=[input_images],
)
gr.Markdown(read_file("static/footer.md"))
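    # Wire the buttons: building a LoRA enables the download and generate
    # buttons; generating an image echoes back the seed actually used.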
lora_button.click(
fn=generate_lora,
inputs=[
input_images
],
outputs=[lora_name, lora_download, imagen_button],
)
imagen_button.click(
fn=generate_image,
inputs=[
lora_name,
prompt,
negative_prompt,
width,
height,
seed,
randomize_seed,
guidance_scale,
lora_strength,
num_inference_steps,
],
outputs=[output_image, seed],
)
if __name__ == "__main__":
    demo.launch(mcp_server=True)