| """Class-to-Image generation pipeline for NOVA.""" |
|
|
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
| import numpy as np |
| import torch |
|
|
| from diffnext.image_processor import VaeImageProcessor |
| from diffnext.pipelines.pipeline_utils import NOVAPipelineOutput, PipelineMixin |
|
|
|
|
class NOVAC2IPipeline(DiffusionPipeline, PipelineMixin):
    """NOVA C2I generation pipeline."""

    _optional_components = ["transformer", "scheduler", "vae"]

    def __init__(self, transformer=None, scheduler=None, vae=None, trust_remote_code=True):
        super(NOVAC2IPipeline, self).__init__()
        self.vae = self.register_module(vae, "vae")
        self.transformer = self.register_module(transformer, "transformer")
        self.scheduler = self.register_module(scheduler, "scheduler")
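        # Reuse the diffusion scheduler as the transformer's sampling scheduler and
        # set the default classifier-free guidance scale.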
        self.transformer.sample_scheduler, self.guidance_scale = self.scheduler, 5.0
        self.image_processor = VaeImageProcessor()

    @torch.no_grad()
    def __call__(
        self,
        prompt=None,
        num_inference_steps=64,
        num_diffusion_steps=25,
        guidance_scale=5,
        min_guidance_scale=None,
        negative_prompt=None,
        num_images_per_prompt=1,
        generator=None,
        latents=None,
        disable_progress_bar=False,
        output_type="pil",
        **kwargs,
    ) -> NOVAPipelineOutput:
| """The call function to the pipeline for generation. |
| |
| Args: |
| prompt (int or List[int], *optional*): |
| The prompt to be encoded. |
| num_inference_steps (int, *optional*, defaults to 64): |
| The number of autoregressive steps. |
| num_diffusion_steps (int, *optional*, defaults to 25): |
| The number of denoising steps. |
| guidance_scale (float, *optional*, defaults to 5): |
| The classifier guidance scale. |
| min_guidance_scale (float, *optional*): |
| The minimum classifier guidance scale. |
| negative_prompt (int or List[int], *optional*): |
| The prompt or prompts to guide what to not include in image generation. |
| num_images_per_prompt (int, *optional*, defaults to 1): |
| The number of images that should be generated per prompt. |
| generator (torch.Generator, *optional*): |
| The random generator. |
| disable_progress_bar (bool, *optional*) |
| Whether to disable all progress bars. |
| output_type (str, *optional*, defaults to `"pil"`): |
| The output format of the generated image. Choose between `PIL.Image` or `np.array`. |
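
        Examples:
            A minimal usage sketch. The checkpoint path and the class index below are
            illustrative placeholders, not values shipped with this file; the import
            path for the pipeline class may also differ in your installation.

            >>> import torch
            >>> from diffnext.pipelines import NOVAC2IPipeline  # assumed export path

            >>> pipe = NOVAC2IPipeline.from_pretrained("path/to/nova-c2i", torch_dtype=torch.float16)
            >>> pipe = pipe.to("cuda")
            >>> output = pipe(prompt=279, guidance_scale=5.0, num_images_per_prompt=1)
            >>> output.images[0].save("sample.png")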

        Returns:
            NOVAPipelineOutput: The pipeline output.
        """
        self.guidance_scale = guidance_scale
        inputs = {"generator": generator, **locals()}
        num_patches = int(np.prod(self.transformer.config.image_base_size))
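        # Cosine masking schedule: the masked fraction decays from 1 to 0 over the
        # autoregressive steps, so "num_preds[i]" patch tokens are predicted at step i.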
        mask_ratios = np.cos(0.5 * np.pi * np.arange(num_inference_steps + 1) / num_inference_steps)
        mask_length = np.round(mask_ratios * num_patches).astype("int64")
        inputs["num_preds"] = mask_length[:-1] - mask_length[1:]
| inputs["tqdm1"], inputs["tqdm2"], inputs["latents"] = False, not disable_progress_bar, [] |
| inputs["c"] = [self.encode_prompt(**dict(_ for _ in inputs.items() if "prompt" in _[0]))] |
| inputs["batch_size"] = len(inputs["c"][0]) // (2 if guidance_scale > 1 else 1) |
        _, outputs = inputs.pop("self"), self.transformer(inputs)
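        # Decode latents to images with the VAE unless raw latents were requested.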
        if output_type != "latent":
            outputs["x"] = self.image_processor.decode_latents(self.vae, outputs["x"])
            outputs["x"] = self.image_processor.postprocess(outputs["x"], output_type)
        return NOVAPipelineOutput(**{"images": outputs["x"]})

    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt=1,
        negative_prompt=None,
    ) -> torch.Tensor:
        """Encode class prompts.

        Args:
            prompt (int or List[int], *optional*):
                The class label index (or indices) to be encoded.
            num_images_per_prompt (int, *optional*, defaults to 1):
                The number of images that should be generated per prompt.
            negative_prompt (int or List[int], *optional*):
                The class label or labels to guide what to not include in image generation.

        Returns:
            torch.Tensor: The encoded class embeddings.
        """

        def select_or_pad(a, b, n=1):
            return [a or b] * n if isinstance(a or b, int) else (a or b)

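        # The label embedder reserves index "num_classes" as the unconditional (null)
        # class, used as the default negative prompt for classifier-free guidance.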
        num_classes = self.transformer.label_embed.num_classes
        prompt = [prompt] if isinstance(prompt, int) else prompt
        negative_prompt = select_or_pad(negative_prompt, num_classes, len(prompt))
        prompts = prompt + (negative_prompt if self.guidance_scale > 1 else [])
        c = self.transformer.label_embed(torch.as_tensor(prompts, device=self.device))
        return c.repeat_interleave(num_images_per_prompt, dim=0)