| import io
|
| import nodes
|
| import node_helpers
|
| import torch
|
| import comfy.model_management
|
| import comfy.model_sampling
|
| import comfy.utils
|
| import math
|
| import numpy as np
|
| import av
|
| from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
|
|
class EmptyLTXVLatentVideo:
    """Node that creates an all-zero latent tensor for LTXV video generation."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/video/ltxv"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema: pixel width/height, frame count and batch size."""
        max_res = nodes.MAX_RESOLUTION
        required = {
            "width": ("INT", {"default": 768, "min": 64, "max": max_res, "step": 32}),
            "height": ("INT", {"default": 512, "min": 64, "max": max_res, "step": 32}),
            "length": ("INT", {"default": 97, "min": 1, "max": max_res, "step": 8}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        }
        return {"required": required}

    def generate(self, width, height, length, batch_size=1):
        """Build a zero latent of shape [batch, 128, frames, height // 32, width // 32].

        The frame axis holds ``((length - 1) // 8) + 1`` entries.

        Returns:
            A 1-tuple containing ``{"samples": latent}``.
        """
        latent_frames = ((length - 1) // 8) + 1
        shape = [batch_size, 128, latent_frames, height // 32, width // 32]
        target_device = comfy.model_management.intermediate_device()
        latent = torch.zeros(shape, device=target_device)
        return ({"samples": latent},)
|
|
|
|
|
class LTXVImgToVideo:
    """Node that seeds an LTXV video latent with a VAE-encoded image."""

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema for conditioning, VAE, image and latent size."""
        max_res = nodes.MAX_RESOLUTION
        required = {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "vae": ("VAE",),
            "image": ("IMAGE",),
            "width": ("INT", {"default": 768, "min": 64, "max": max_res, "step": 32}),
            "height": ("INT", {"default": 512, "min": 64, "max": max_res, "step": 32}),
            "length": ("INT", {"default": 97, "min": 9, "max": max_res, "step": 8}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        }
        return {"required": required}

    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
        """Encode ``image`` and place it at the start of a fresh zero latent.

        The image is resized to ``width`` x ``height``, reduced to its first
        three channels and encoded with ``vae``. The encoded frames overwrite
        the leading frames of the latent, and the returned noise mask is 0 for
        those frames and 1 everywhere else.

        Returns:
            ``(positive, negative, {"samples": ..., "noise_mask": ...})`` with
            the conditioning passed through unchanged.
        """
        # common_upscale is fed a channels-first view; move channels back afterwards.
        resized = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center")
        rgb_pixels = resized.movedim(1, -1)[:, :, :, :3]
        encoded = vae.encode(rgb_pixels)
        num_cond_frames = encoded.shape[2]

        latent_frames = ((length - 1) // 8) + 1
        samples = torch.zeros(
            [batch_size, 128, latent_frames, height // 32, width // 32],
            device=comfy.model_management.intermediate_device(),
        )
        samples[:, :, :num_cond_frames] = encoded

        frames_mask = torch.ones(
            (batch_size, 1, samples.shape[2], 1, 1),
            dtype=torch.float32,
            device=samples.device,
        )
        # A mask value of 0 marks the frames that came from the encoded image.
        frames_mask[:, :, :num_cond_frames] = 0

        return (positive, negative, {"samples": samples, "noise_mask": frames_mask}, )
|
|
|
|
|
def conditioning_get_any_value(conditioning, key, default=None):
    """Return the value of ``key`` from the first conditioning entry that has it.

    Each entry of ``conditioning`` is indexed as ``entry[1]`` to reach its
    options mapping. ``default`` is returned when no entry contains ``key``.
    """
    found = (entry[1][key] for entry in conditioning if key in entry[1])
    return next(found, default)
|
|
|
|
|
def get_noise_mask(latent):
    """Return a private noise mask for ``latent``.

    If ``latent`` carries a ``"noise_mask"`` entry, a clone of it is returned.
    Otherwise an all-ones float32 mask of shape ``(batch, 1, frames, 1, 1)`` is
    created on the device of ``latent["samples"]``, which must then be 5-D.
    """
    existing_mask = latent.get("noise_mask", None)
    samples = latent["samples"]
    if existing_mask is not None:
        # Clone so that callers can edit the mask without touching the input.
        return existing_mask.clone()

    batch, _, num_frames, _, _ = samples.shape
    return torch.ones(
        (batch, 1, num_frames, 1, 1),
        dtype=torch.float32,
        device=samples.device,
    )
|
|
|
def get_keyframe_idxs(cond):
    """Look up the keyframe index tensor stored in ``cond``.

    Returns:
        ``(keyframe_idxs, num_keyframes)``, where ``num_keyframes`` is the
        number of distinct values found in ``keyframe_idxs[:, 0]``. Returns
        ``(None, 0)`` when the conditioning has no ``"keyframe_idxs"`` value.
    """
    idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
    if idxs is not None:
        distinct = torch.unique(idxs[:, 0])
        return idxs, distinct.shape[0]
    return None, 0
|
|
|
class LTXVAddGuide:
    """Node that conditions an LTXV video latent on a guide image or clip.

    The first encoded guide frames (at most ``_num_prefix_frames``) are
    appended to the end of the latent as keyframes and recorded in the
    conditioning under ``"keyframe_idxs"``. Any remaining guide frames
    overwrite latent frames in place.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema of the node."""
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE",),
                             "latent": ("LATENT",),
                             # The first literal ends with a space so the two
                             # sentences do not run together when concatenated.
                             "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames. "
                                                            "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
                             "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
                                                   "tooltip": "Frame index to start the conditioning at. For single-frame images or "
                                                              "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
                                                              "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
                                                              "the nearest multiple of 8. Negative values are counted from the end of the video."}),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }
                }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def __init__(self):
        # Upper bound on how many leading guide latent frames are appended as keyframes.
        self._num_prefix_frames = 2
        self._patchifier = SymmetricPatchifier(1)

    def encode(self, vae, latent_width, latent_height, images, scale_factors):
        """Resize and VAE-encode the guide frames.

        ``images`` is first cropped to the largest ``k * time_scale_factor + 1``
        leading frames, then resized to the pixel size implied by the latent
        size and scale factors.

        Returns:
            ``(encode_pixels, t)``: the resized pixels limited to their first
            three channels, and the encoded latent.
        """
        time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
        images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        return encode_pixels, t

    def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
        """Resolve ``frame_idx`` to a pixel frame index and a latent frame index.

        Keyframes already recorded in ``cond`` are subtracted from
        ``latent_length`` before a negative ``frame_idx`` is resolved from the
        end (clamped at 0). For guides longer than one frame the index is
        rounded down to a multiple of the time scale factor. The latent index
        is the pixel index divided by the time scale factor, rounded up.

        Returns:
            ``(frame_idx, latent_idx)``.
        """
        time_scale_factor, _, _ = scale_factors
        _, num_keyframes = get_keyframe_idxs(cond)
        latent_count = latent_length - num_keyframes
        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
        if guide_length > 1:
            frame_idx = frame_idx // time_scale_factor * time_scale_factor

        # Ceiling division.
        latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor

        return frame_idx, latent_idx

    def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
        """Record the pixel coordinates of ``guiding_latent`` in ``cond``.

        The coordinates are offset along index 0 of dim 1 by ``frame_idx`` and
        concatenated (dim 2) to any ``"keyframe_idxs"`` already present.

        Returns:
            The conditioning with its ``"keyframe_idxs"`` value updated.
        """
        keyframe_idxs, _ = get_keyframe_idxs(cond)
        _, latent_coords = self._patchifier.patchify(guiding_latent)
        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
        pixel_coords[:, 0] += frame_idx
        if keyframe_idxs is None:
            keyframe_idxs = pixel_coords
        else:
            keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})

    def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
        """Append ``guiding_latent`` to the end of the latent as keyframes.

        Both conditionings get the keyframe coordinates, and the noise mask is
        extended with ``1.0 - strength`` for the appended frames.

        Returns:
            ``(positive, negative, latent_image, noise_mask)``.
        """
        positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
        negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)

        mask = torch.full(
            (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = torch.cat([latent_image, guiding_latent], dim=2)
        noise_mask = torch.cat([noise_mask, mask], dim=2)
        return positive, negative, latent_image, noise_mask

    def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
        """Overwrite latent frames starting at ``latent_idx`` with ``guiding_latent``.

        The inputs are cloned, so the caller's tensors are left untouched. The
        noise mask for the replaced frames is set to ``1.0 - strength``.

        Returns:
            ``(latent_image, noise_mask)`` as new tensors.

        Raises:
            AssertionError: if the guide frames do not fit into the latent.
        """
        cond_length = guiding_latent.shape[2]
        assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."

        mask = torch.full(
            (noise_mask.shape[0], 1, cond_length, 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = latent_image.clone()
        noise_mask = noise_mask.clone()

        latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
        noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask

        return latent_image, noise_mask

    def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
        """Add the guide ``image`` to ``latent`` and to both conditionings.

        Returns:
            ``(positive, negative, {"samples": ..., "noise_mask": ...})``.

        Raises:
            AssertionError: if the encoded guide does not fit into the latent
                at the resolved position.
        """
        scale_factors = vae.downscale_index_formula
        latent_image = latent["samples"]
        noise_mask = get_noise_mask(latent)

        _, _, latent_length, latent_height, latent_width = latent_image.shape
        image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)

        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
        assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."

        num_prefix_frames = min(self._num_prefix_frames, t.shape[2])

        positive, negative, latent_image, noise_mask = self.append_keyframe(
            positive,
            negative,
            frame_idx,
            latent_image,
            noise_mask,
            t[:, :, :num_prefix_frames],
            strength,
            scale_factors,
        )

        latent_idx += num_prefix_frames

        # Guide frames beyond the keyframe prefix replace latent frames in place.
        t = t[:, :, num_prefix_frames:]
        if t.shape[2] > 0:
            latent_image, noise_mask = self.replace_latent_frames(
                latent_image,
                noise_mask,
                t,
                latent_idx,
                strength,
            )

        return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
|
|
|
|
class LTXVCropGuides:
    """Node that strips appended keyframe frames from an LTXV latent."""

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "crop"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema: both conditionings and the latent."""
        required = {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "latent": ("LATENT",),
        }
        return {"required": required}

    def __init__(self):
        self._patchifier = SymmetricPatchifier(1)

    def crop(self, positive, negative, latent):
        """Drop trailing keyframe frames and clear ``"keyframe_idxs"``.

        The number of keyframes is read from ``positive``. When it is zero the
        conditionings pass through untouched; otherwise that many trailing
        frames are cut from the samples and the noise mask, and
        ``"keyframe_idxs"`` is set to ``None`` in both conditionings.

        Returns:
            ``(positive, negative, {"samples": ..., "noise_mask": ...})``.
        """
        samples = latent["samples"].clone()
        mask = get_noise_mask(latent)

        _, num_keyframes = get_keyframe_idxs(positive)
        if num_keyframes != 0:
            samples = samples[:, :, :-num_keyframes]
            mask = mask[:, :, :-num_keyframes]

            positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
            negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})

        return (positive, negative, {"samples": samples, "noise_mask": mask},)
|
|
|
|
|
class LTXVConditioning:
    """Node that attaches a frame rate to positive and negative conditioning."""

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "append"

    CATEGORY = "conditioning/video_models"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema: both conditionings and the frame rate."""
        required = {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
        }
        return {"required": required}

    def append(self, positive, negative, frame_rate):
        """Set ``"frame_rate"`` on both conditionings and return them as a pair."""
        positive, negative = (
            node_helpers.conditioning_set_values(cond, {"frame_rate": frame_rate})
            for cond in (positive, negative)
        )
        return (positive, negative)
|
|
|
|
|
class ModelSamplingLTXV:
    """Node that patches a model with a token-count dependent sampling shift."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema: model, the two shift bounds and an optional latent."""
        required = {
            "model": ("MODEL",),
            "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
            "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
        }
        optional = {"latent": ("LATENT",), }
        return {"required": required, "optional": optional}

    def patch(self, model, max_shift, base_shift, latent=None):
        """Return a clone of ``model`` with a patched ``model_sampling`` object.

        The shift is a linear function of the token count that equals
        ``base_shift`` at 1024 tokens and ``max_shift`` at 4096 tokens. The
        token count is the product of the latent's dimensions from index 2
        onward, or 4096 when no latent is supplied.
        """
        patched = model.clone()

        tokens = 4096 if latent is None else math.prod(latent["samples"].shape[2:])

        low_tokens, high_tokens = 1024, 4096
        slope = (max_shift - base_shift) / (high_tokens - low_tokens)
        intercept = base_shift - slope * low_tokens
        shift = (tokens) * slope + intercept

        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingFlux, comfy.model_sampling.CONST):
            pass

        sampling = ModelSamplingAdvanced(model.model.model_config)
        sampling.set_parameters(shift=shift)
        patched.add_object_patch("model_sampling", sampling)

        return (patched, )
|
|
|
|
|
| class LTXVScheduler:
|
| @classmethod
|
| def INPUT_TYPES(s):
|
| return {"required":
|
| {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
| "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
|
| "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
|
| "stretch": ("BOOLEAN", {
|
| "default": True,
|
| "tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
|
| }),
|
| "terminal": (
|
| "FLOAT",
|
| {
|
| "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
|
| "tooltip": "The terminal value of the sigmas after stretching."
|
| },
|
| ),
|
| },
|
| "optional": {"latent": ("LATENT",), }
|
| }
|
|
|
| RETURN_TYPES = ("SIGMAS",)
|
| CATEGORY = "sampling/custom_sampling/schedulers"
|
|
|
| FUNCTION = "get_sigmas"
|
|
|
| def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
|
| if latent is None:
|
| tokens = 4096
|
| else:
|
| tokens = math.prod(latent["samples"].shape[2:])
|
|
|
| sigmas = torch.linspace(1.0, 0.0, steps + 1)
|
|
|
| x1 = 1024
|
| x2 = 4096
|
| mm = (max_shift - base_shift) / (x2 - x1)
|
| b = base_shift - mm * x1
|
| sigma_shift = (tokens) * mm + b
|
|
|
| power = 1
|
| sigmas = torch.where(
|
| sigmas != 0,
|
| math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
|
| 0,
|
| )
|
|
|
|
|
| if stretch:
|
| non_zero_mask = sigmas != 0
|
| non_zero_sigmas = sigmas[non_zero_mask]
|
| one_minus_z = 1.0 - non_zero_sigmas
|
| scale_factor = one_minus_z[-1] / (1.0 - terminal)
|
| stretched = 1.0 - (one_minus_z / scale_factor)
|
| sigmas[non_zero_mask] = stretched
|
|
|
| return (sigmas,)
|
|
|
def encode_single_frame(output_file, image_array: np.ndarray, crf):
    """Write ``image_array`` to ``output_file`` as a one-frame h264 MP4.

    Args:
        output_file: Destination handed to ``av.open`` in write mode.
        image_array: Frame data read as ``rgb24``; its first two dimensions
            are used as height and width.
        crf: Quality value passed to the encoder as the ``crf`` option.
    """
    # The container is closed when the block exits, including on error.
    with av.open(output_file, "w", format="mp4") as container:
        encoder_options = {"crf": str(crf), "preset": "veryfast"}
        stream = container.add_stream("h264", rate=1, options=encoder_options)
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]

        rgb_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24")
        yuv_frame = rgb_frame.reformat(format="yuv420p")

        container.mux(stream.encode(yuv_frame))
        # Encoding with no frame flushes whatever the encoder still buffers.
        container.mux(stream.encode())
|
|
|
|
|
def decode_single_frame(video_file):
    """Decode the first frame of the first video stream in ``video_file``.

    Returns:
        The frame converted to an ``rgb24`` ndarray.

    Raises:
        StopIteration: if there is no video stream or no decodable frame.
    """
    # The container is closed when the block exits, including on error.
    with av.open(video_file) as container:
        video_streams = (s for s in container.streams if s.type == "video")
        first_frame = next(container.decode(next(video_streams)))
    return first_frame.to_ndarray(format="rgb24")
|
|
|
|
|
def preprocess(image: torch.Tensor, crf=29):
    """Round-trip ``image`` through single-frame h264 encoding.

    With ``crf == 0`` the input tensor is returned as is. Otherwise the first
    two dimensions are cropped to even sizes, the values are scaled by 255 and
    cast to bytes, the frame is encoded and decoded in memory, and the result
    is returned divided by 255 with the dtype and device of ``image``.

    NOTE(review): presumably ``image`` is (height, width, channels) with
    values in [0, 1]; values are not clamped before the byte cast. Confirm
    against callers.
    """
    if crf == 0:
        return image

    even_rows = (image.shape[0] // 2) * 2
    even_cols = (image.shape[1] // 2) * 2
    cropped = image[:even_rows, :even_cols]
    frame_bytes = (cropped * 255.0).byte().cpu().numpy()

    with io.BytesIO() as encoded_buffer:
        encode_single_frame(encoded_buffer, frame_bytes, crf)
        video_bytes = encoded_buffer.getvalue()

    with io.BytesIO(video_bytes) as decode_buffer:
        decoded = decode_single_frame(decode_buffer)

    return torch.tensor(decoded, dtype=image.dtype, device=image.device) / 255.0
|
|
|
|
|
class LTXVPreprocess:
    """Node that applies the compression round-trip to every image of a batch."""

    FUNCTION = "preprocess"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("output_image",)
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema: the image batch and the compression amount."""
        compression_options = {
            "default": 35,
            "min": 0,
            "max": 100,
            "tooltip": "Amount of compression to apply on image.",
        }
        return {
            "required": {
                "image": ("IMAGE",),
                "img_compression": ("INT", compression_options),
            }
        }

    def preprocess(self, image, img_compression):
        """Run the module-level ``preprocess`` on each image and stack the results.

        Returns:
            A 1-tuple containing the stacked output tensor.
        """
        # The bare name below resolves to the module-level function, not this method.
        processed = [
            preprocess(image[index], img_compression)
            for index in range(image.shape[0])
        ]
        return (torch.stack(processed),)
|
|
|
|
|
# Maps each node identifier string to the class defined in this module that
# implements it. Every key equals the name of its class.
NODE_CLASS_MAPPINGS = {
    "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
    "LTXVImgToVideo": LTXVImgToVideo,
    "ModelSamplingLTXV": ModelSamplingLTXV,
    "LTXVConditioning": LTXVConditioning,
    "LTXVScheduler": LTXVScheduler,
    "LTXVAddGuide": LTXVAddGuide,
    "LTXVPreprocess": LTXVPreprocess,
    "LTXVCropGuides": LTXVCropGuides,
}
|
|
|