| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
|
|
| import imageio |
| import numpy as np |
| import torch |
|
|
| from ar_model import AutoRegressiveModel |
| from text2world_prompt_upsampler_inference import ( |
| create_prompt_upsampler, |
| run_chat_completion, |
| ) |
| from presets import ( |
| create_text_guardrail_runner, |
| create_video_guardrail_runner, |
| run_text_guardrail, |
| run_video_guardrail, |
| ) |
| from .log import log |
|
|
|
|
def get_upsampled_prompt(
    prompt_upsampler_model: AutoRegressiveModel, input_prompt: str, temperature: float = 0.01
) -> str:
    """
    Ask the prompt upsampler model to expand a short caption into a long one.

    Args:
        prompt_upsampler_model: Model instance passed through to ``run_chat_completion``.
        input_prompt (str): Short caption to expand.
        temperature (float): Temperature forwarded to the chat completion call (default: 0.01).

    Returns:
        str: The upsampled prompt, as returned by ``run_chat_completion``.
    """
    # The request is one conversation containing one user turn.
    user_turn = {
        "role": "user",
        "content": f"Upsample the short caption to a long caption: {input_prompt}",
    }
    return run_chat_completion(prompt_upsampler_model, [[user_turn]], temperature=temperature)
|
|
|
|
def print_rank_0(string: str):
    """
    Log a message at INFO level from the rank-0 process only.

    If ``torch.distributed`` is unavailable, or no default process group has
    been initialized (for example in a single-process run), the message is
    logged unconditionally instead of failing on the rank query.

    Args:
        string (str): Message to log.
    """
    dist = torch.distributed
    # get_rank() requires an initialized default process group; without one,
    # treat the current process as the only rank.
    is_distributed = dist.is_available() and dist.is_initialized()
    if not is_distributed or dist.get_rank() == 0:
        log.info(string)
|
|
|
|
def process_prompt(
    prompt: str,
    checkpoint_dir: str,
    prompt_upsampler_dir: str,
    guardrails_dir: str,
    image_path: str = None,
    enable_prompt_upsampler: bool = True,
) -> str:
    """
    Run the text guardrail on a prompt and, if enabled, upsample it and check it again.

    Args:
        prompt (str): The original text prompt.
        checkpoint_dir (str): Base checkpoint directory.
        prompt_upsampler_dir (str): Prompt upsampler directory, joined onto ``checkpoint_dir``.
        guardrails_dir (str): Guardrails directory, joined onto ``checkpoint_dir``.
        image_path (str, optional): Path to an image, if any. Upsampling is not
            implemented when this is set.
        enable_prompt_upsampler (bool): Whether to upsample the prompt.

    Returns:
        str: The upsampled prompt, or the original prompt when upsampling is disabled.

    Raises:
        ValueError: If the guardrail rejects the original or the upsampled prompt.
        NotImplementedError: If upsampling is enabled and ``image_path`` is provided.
    """
    guardrail_runner = create_text_guardrail_runner(os.path.join(checkpoint_dir, guardrails_dir))

    # The original prompt is always checked, whether or not upsampling follows.
    if not run_text_guardrail(str(prompt), guardrail_runner):
        raise ValueError("Guardrail blocked world generation.")

    if not enable_prompt_upsampler:
        return prompt

    if image_path:
        raise NotImplementedError("Prompt upsampling is not supported for image generation")

    upsampler_path = os.path.join(checkpoint_dir, prompt_upsampler_dir)
    upsampler = create_prompt_upsampler(checkpoint_dir=upsampler_path)
    long_prompt = get_upsampled_prompt(upsampler, prompt)
    print_rank_0(f"Original prompt: {prompt}\nUpsampled prompt: {long_prompt}\n")
    # Drop the local reference to the upsampler once it is no longer needed.
    del upsampler

    # The upsampled text is model-generated, so it is checked as well.
    if not run_text_guardrail(str(long_prompt), guardrail_runner):
        raise ValueError("Guardrail blocked world generation.")

    return long_prompt
|
|
|
|
def save_video(
    grid: np.ndarray,
    fps: int,
    H: int,
    W: int,
    video_save_quality: int,
    video_save_path: str,
    checkpoint_dir: str,
    guardrails_dir: str,
):
    """
    Save video frames to file, applying a safety check before writing.

    Args:
        grid (np.ndarray): Video frames array [T, H, W, C].
        fps (int): Frames per second.
        H (int): Frame height, passed to ffmpeg as part of the ``-s WxH`` size.
        W (int): Frame width, passed to ffmpeg as part of the ``-s WxH`` size.
        video_save_quality (int): Video encoding quality (0-10).
        video_save_path (str): Output video file path.
        checkpoint_dir (str): Directory containing model checkpoints.
        guardrails_dir (str): Guardrails directory, joined onto ``checkpoint_dir``.

    Raises:
        ValueError: If the video guardrail returns no frames to write.
    """
    video_classifier_guardrail = create_video_guardrail_runner(os.path.join(checkpoint_dir, guardrails_dir))

    # The guardrail's return value replaces the input frames.
    grid = run_video_guardrail(grid, video_classifier_guardrail)
    # NOTE(review): assumes run_video_guardrail signals a blocked video by
    # returning None -- confirm against its definition. Failing here gives a
    # clear error instead of an obscure failure inside the encoder.
    if grid is None:
        raise ValueError("Guardrail blocked world generation.")

    kwargs = {
        "fps": fps,
        "quality": video_save_quality,
        "macro_block_size": 1,
        "ffmpeg_params": ["-s", f"{W}x{H}"],
        "output_params": ["-f", "mp4"],
    }

    imageio.mimsave(video_save_path, grid, "mp4", **kwargs)
|
|