| import gradio as gr
|
|
|
| from gradio_toggle import Toggle
|
| import argparse
|
| import json
|
| import os
|
| import random
|
| from datetime import datetime
|
| from pathlib import Path
|
| from diffusers.utils import logging
|
|
|
| import imageio
|
| import numpy as np
|
| import safetensors.torch
|
| import torch
|
| import torch.nn.functional as F
|
| from PIL import Image
|
| from transformers import T5EncoderModel, T5Tokenizer
|
| import tempfile
|
| from ltx_video.models.autoencoders.causal_video_autoencoder import (
|
| CausalVideoAutoencoder,
|
| )
|
| from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
|
| from ltx_video.models.transformers.transformer3d import Transformer3DModel
|
| from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
|
| from ltx_video.schedulers.rf import RectifiedFlowScheduler
|
| from ltx_video.utils.conditioning_method import ConditioningMethod
|
| from torchao.quantization import quantize_, int8_weight_only
|
|
|
# Suggested upper bounds for generation requests. Exceeding them does not
# abort the run — main() only logs a warning (see the size check there).
MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257
|
|
|
|
|
def load_vae(vae_dir, int8=False):
    """Build the causal video VAE from a checkpoint directory.

    Reads ``config.json`` and the safetensors weights from *vae_dir*,
    instantiates the model on CPU, and optionally applies int8
    weight-only quantization.

    Args:
        vae_dir: Path to the directory holding the VAE config and weights.
        int8: When True, quantize the weights in place via torchao.

    Returns:
        The loaded ``CausalVideoAutoencoder`` on CPU.
    """
    weights_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
    config_path = vae_dir / "config.json"
    with open(config_path, "r") as cfg_file:
        config = json.load(cfg_file)
    vae = CausalVideoAutoencoder.from_config(config)
    vae.load_state_dict(safetensors.torch.load_file(weights_path))

    vae = vae.to('cpu')
    if int8:
        print("vae - quantization = true")
        quantize_(vae, int8_weight_only())
    return vae
|
|
|
|
|
def load_unet(unet_dir, int8=False):
    """Build the 3D transformer (denoiser) from a checkpoint directory.

    Loads the model config and safetensors weights from *unet_dir*,
    places the model on CPU, and optionally applies int8 weight-only
    quantization.

    Args:
        unet_dir: Path to the directory holding the transformer config
            and weights.
        int8: When True, quantize the weights in place via torchao.

    Returns:
        The loaded ``Transformer3DModel`` on CPU.
    """
    weights_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
    config = Transformer3DModel.load_config(unet_dir / "config.json")
    transformer = Transformer3DModel.from_config(config)
    state_dict = safetensors.torch.load_file(weights_path)
    transformer.load_state_dict(state_dict, strict=True)

    transformer = transformer.to('cpu')
    if int8:
        print("unet - quantization = true")
        quantize_(transformer, int8_weight_only())
    return transformer
|
|
|
|
|
def load_scheduler(scheduler_dir):
    """Instantiate the rectified-flow scheduler from its saved config.

    Args:
        scheduler_dir: Path to the directory containing
            ``scheduler_config.json``.

    Returns:
        A ``RectifiedFlowScheduler`` built from that config.
    """
    config = RectifiedFlowScheduler.load_config(
        scheduler_dir / "scheduler_config.json"
    )
    return RectifiedFlowScheduler.from_config(config)
|
|
|
|
|
def load_image_to_tensor_with_resize_and_crop(image_path, target_height=512, target_width=768):
    """Load an image, center-crop it to the target aspect ratio, resize,
    and return a normalized video-frame tensor.

    Args:
        image_path: Path to the image file.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        A float tensor of shape (1, C, 1, H, W) with values in [-1, 1]
        (single-frame "video" layout).
    """
    img = Image.open(image_path).convert("RGB")
    src_w, src_h = img.size
    wanted_ratio = target_width / target_height
    actual_ratio = src_w / src_h
    if actual_ratio > wanted_ratio:
        # Source is wider than the target: crop equally from both sides.
        crop_w = int(src_h * wanted_ratio)
        crop_h = src_h
        left = (src_w - crop_w) // 2
        top = 0
    else:
        # Source is taller than the target: crop equally from top/bottom.
        crop_w = src_w
        crop_h = int(src_w / wanted_ratio)
        left = 0
        top = (src_h - crop_h) // 2

    img = img.crop((left, top, left + crop_w, top + crop_h))
    img = img.resize((target_width, target_height))
    # HWC uint8 -> CHW float, scaled from [0, 255] to [-1, 1].
    frame = torch.tensor(np.array(img)).permute(2, 0, 1).float()
    frame = (frame / 127.5) - 1.0

    # Add batch and temporal (single frame) dimensions.
    return frame.unsqueeze(0).unsqueeze(2)
|
|
|
|
|
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute symmetric padding that grows a source frame to the target size.

    The extra rows/columns are split as evenly as possible; when the
    difference is odd, the bottom/right side receives the extra pixel.

    Args:
        source_height: Height of the unpadded frame.
        source_width: Width of the unpadded frame.
        target_height: Desired padded height (>= source_height).
        target_width: Desired padded width (>= source_width).

    Returns:
        (pad_left, pad_right, pad_top, pad_bottom) — the order expected
        by ``torch.nn.functional.pad`` for the last two dimensions.
    """
    extra_rows = target_height - source_height
    extra_cols = target_width - source_width

    top = extra_rows // 2
    bottom = extra_rows - top
    left = extra_cols // 2
    right = extra_cols - left

    return (left, right, top, bottom)
|
|
|
|
|
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Turn a free-text prompt into a short hyphen-separated slug.

    Keeps only letters and spaces (lowercased), then greedily takes
    whole words while the cumulative letter count stays within
    *max_len*; stops at the first word that would overflow. Note the
    budget counts letters only, not the joining hyphens.

    Args:
        text: Arbitrary prompt text.
        max_len: Maximum total number of letters in the slug.

    Returns:
        A slug like ``"a-cat-on-a-mat"`` (possibly empty).
    """
    cleaned = "".join(
        ch.lower() for ch in text if ch.isalpha() or ch.isspace()
    )

    picked: list[str] = []
    used = 0
    for word in cleaned.split():
        if used + len(word) > max_len:
            break
        picked.append(word)
        used += len(word)

    return "-".join(picked)
|
|
|
|
|
|
|
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    """Find an output path that does not exist yet.

    Builds ``{base}_{slug}_{seed}_{HxWxF}_{i}{endswith}{ext}`` inside
    *dir* and returns the first candidate (i = 0..index_range-1) that is
    not already on disk.

    Args:
        base: Filename prefix (e.g. "img_to_vid").
        ext: File extension including the dot (e.g. ".mp4").
        prompt: Prompt text, slugged via convert_prompt_to_filename.
        seed: RNG seed, embedded for traceability.
        resolution: (height, width, num_frames) triple for the name.
        dir: Target directory.
        endswith: Optional suffix inserted just before the extension.
        index_range: How many indices to try before giving up.

    Returns:
        A ``Path`` that did not exist at check time.

    Raises:
        FileExistsError: If all index_range candidates already exist.
    """
    slug = convert_prompt_to_filename(prompt, max_len=30)
    stem = f"{base}_{slug}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    suffix = endswith if endswith else ''
    for idx in range(index_range):
        candidate = dir / f"{stem}_{idx}{suffix}{ext}"
        if not os.path.exists(candidate):
            return candidate
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )
|
|
|
|
|
def seed_everething(seed: int):
    """Seed the Python, NumPy and torch RNGs for reproducible runs.

    NOTE: the misspelled name ("everething") is kept — it is the public
    name callers use.

    Args:
        seed: The seed applied to all three generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
|
|
|
|
def main(
    img2vid_image="",
    prompt="",
    txt2vid_analytics_toggle=False,
    negative_prompt="",
    frame_rate=25,
    seed=0,
    num_inference_steps=30,
    guidance_scale=3,
    height=512,
    width=768,
    num_frames=121,
    progress=gr.Progress(),
):
    """Run LTX-Video generation on CPU: text-to-video, or image-to-video
    when ``img2vid_image`` is a non-empty path (the image conditions the
    first frame).

    Parameters mirror the Gradio controls. ``txt2vid_analytics_toggle``
    is forwarded to load_vae/load_unet as the int8-quantization flag.

    NOTE(review): the ``seed`` parameter is ignored — ``args["seed"]``
    is hard-coded to 0 below and that value is what seeds the run;
    likely a bug, confirm intent. ``progress`` is also never used in
    this visible body. The pipeline result is bound to ``images`` at the
    end; any saving/return presumably happens past this view — verify.
    """

    logger = logging.get_logger(__name__)

    # Single dict of run settings; several UI parameters are copied in,
    # but "seed" is pinned to 0 rather than taking the `seed` argument.
    args = {
        "ckpt_dir": "Lightricks/LTX-Video",
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "frame_rate": frame_rate,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "seed": 0,
        "output_path": os.path.join(tempfile.gettempdir(), "gradio"),
        "num_images_per_prompt": 1,
        "input_image_path": img2vid_image,
        "input_video_path": "",
        "bfloat16": True,
        "disable_load_needed_only": False
    }
    logger.warning(f"Running generation with arguments: {args}")

    # Seed python/numpy/torch RNGs (always 0 here — see note above).
    seed_everething(args['seed'])

    # output_path is always set above, so the dated fallback dir is
    # currently dead code; kept as written.
    output_dir = (
        Path(args['output_path'])
        if args['output_path']
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Optional conditioning image -> (1, C, 1, H, W) tensor in [-1, 1].
    if args['input_image_path']:
        media_items_prepad = load_image_to_tensor_with_resize_and_crop(
            args['input_image_path'], args['height'], args['width']
        )
    else:
        media_items_prepad = None

    # Fall back to the image's own size when height/width are falsy
    # (would raise AttributeError if no image was given either).
    height = args['height'] if args['height'] else media_items_prepad.shape[-2]
    width = args['width'] if args['width'] else media_items_prepad.shape[-1]
    num_frames = args['num_frames']

    # Oversized requests are only warned about, not rejected.
    if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES:
        logger.warning(
            f"Input resolution or number of frames {height}x{width}x{num_frames} is too big, it is suggested to use the resolution below {MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES}."
        )

    # Round spatial dims up to a multiple of 32 and frames to 8k+1.
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1

    # (left, right, top, bottom) padding in F.pad order.
    padding = calculate_padding(height, width, height_padded, width_padded)

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    # Pad the conditioning image with -1 (the "black" end of the
    # [-1, 1] normalization range used above).
    if media_items_prepad is not None:
        media_items = F.pad(
            media_items_prepad, padding, mode="constant", value=-1
        )
    else:
        media_items = None

    # Assemble all pipeline submodules on CPU. The toggle enables int8
    # weight-only quantization for the VAE and transformer.
    vae = load_vae(Path(args['ckpt_dir']) / "vae", txt2vid_analytics_toggle)
    unet = load_unet(Path(args['ckpt_dir']) / "unet", txt2vid_analytics_toggle)
    scheduler = load_scheduler(Path(args['ckpt_dir']) / "scheduler")
    patchifier = SymmetricPatchifier(patch_size=1)
    text_encoder = T5EncoderModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
    ).to('cpu')

    tokenizer = T5Tokenizer.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
    )

    submodel_dict = {
        "transformer": unet,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    pipeline = pipeline.to('cpu')

    # Prompt inputs; attention masks are left for the pipeline to build.
    sample = {
        "prompt": args['prompt'],
        "prompt_attention_mask": None,
        "negative_prompt": args['negative_prompt'],
        "negative_prompt_attention_mask": None,
        "media_items": media_items,
    }

    # Dedicated generator, seeded with the same (hard-coded 0) seed.
    generator = torch.Generator(device="cpu").manual_seed(args['seed'])

    # Run generation. First-frame conditioning is used only when an
    # input image was provided; otherwise fully unconditional.
    images = pipeline(
        num_inference_steps=args['num_inference_steps'],
        num_images_per_prompt=args['num_images_per_prompt'],
        guidance_scale=args['guidance_scale'],
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=args['frame_rate'],
        **sample,
        is_video=True,
        vae_per_channel_normalize=True,
        conditioning_method=(
            ConditioningMethod.FIRST_FRAME
            if media_items is not None
            else ConditioningMethod.UNCONDITIONAL
        ),
        mixed_precision=not args['bfloat16'],
        load_needed_only=not args['disable_load_needed_only']
    ).images
|
|
|
|
|
|
|
|
|
|
|