Training Guide

This guide provides simple snippets to train diffnext models.

1. Build VQVAE cache

To streamline the training workflow, we preprocess images or videos into VQVAE latents ahead of time.

Requirements:

pip install protobuf==3.20.3 codewithgpu decord

Build T2I cache

The following snippet can be used to cache image latents:

import os, codewithgpu, torch, PIL.Image, numpy as np
from diffnext.models.autoencoders.autoencoder_vq import AutoencoderVQ

device, dtype = torch.device("cuda"), torch.float16
vae = AutoencoderVQ.from_pretrained("/path/to/BAAI/URSA-1.7B-IBQ1024/vae")
vae = vae.to(device=device, dtype=dtype).eval()

features = {"codes": "bytes", "caption": "string", "text": "string", "shape": ["int64"]}
os.makedirs("./datasets/ibq1024_dataset", exist_ok=True)
writer = codewithgpu.RecordWriter("./datasets/ibq1024_dataset", features)

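# Force 3-channel RGB so the (N, C, H, W) transpose below also holds for grayscale or RGBA files.
img = PIL.Image.open("./assets/sample_image.jpg").convert("RGB")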
x = torch.as_tensor(np.array(img)[None, ...].transpose(0, 3, 1, 2)).to(device).to(dtype)
with torch.no_grad():
    x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.unsqueeze(1).cpu().numpy()[0]
example = {"caption": "long caption", "text": "short text"}
# Ensure enough examples for the codewithgpu distributed dataset.
for _ in range(16):
    writer.write({"shape": x.shape, "codes": x.tobytes(), **example})
writer.close()
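
Because each record stores the latents as raw bytes plus a shape, reading them back requires knowing the array dtype. The following is a minimal sketch of that round trip using only NumPy. The helper names `pack_codes` and `unpack_codes` are hypothetical, and the `int64` dtype is an assumption for illustration; check `x.dtype` in the snippet above for the actual value.

```python
import numpy as np


def pack_codes(x):
    """Serialize an array the way the caching snippet does (hypothetical helper)."""
    return {"shape": list(x.shape), "codes": x.tobytes()}


def unpack_codes(record, dtype):
    """Rebuild the array from a record; dtype must match the one used when writing."""
    return np.frombuffer(record["codes"], dtype=dtype).reshape(record["shape"])


# Stand-in latent grid; int64 is an assumed dtype, not confirmed by the source.
codes = np.arange(12, dtype=np.int64).reshape(1, 3, 4)
restored = unpack_codes(pack_codes(codes), dtype=codes.dtype)
assert np.array_equal(codes, restored)
```

Passing a mismatched dtype to `unpack_codes` either fails at the reshape or silently yields wrong values, so it may be worth recording the dtype alongside the dataset.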

Build T2V cache

The following snippet can be used to cache video latents:

import os, codewithgpu, torch, decord, numpy as np
from diffnext.models.autoencoders.autoencoder_vq_cosmos3d import AutoencoderVQCosmos3D

device, dtype = torch.device("cuda"), torch.float16
vae = AutoencoderVQCosmos3D.from_pretrained("/path/to/URSA-1.7B-FSQ320/vae")
vae = vae.to(device=device, dtype=dtype).eval()

features = {"codes": "bytes", "caption": "string", "text": "string", "shape": ["int64"], "flow": "float64"}
os.makedirs("./datasets/fsq320_dataset", exist_ok=True)
writer = codewithgpu.RecordWriter("./datasets/fsq320_dataset", features)

resize, crop_size, frame_ids = 320, (320, 512), list(range(0, 97, 2))
vid = decord.VideoReader("./assets/sample_video.mp4")
h, w = vid[0].shape[:2]
scale = float(resize) / float(min(h, w))
size = int(h * scale + 0.5), int(w * scale + 0.5)
y, x = (size[0] - crop_size[0]) // 2, (size[1] - crop_size[1]) // 2
vid = decord.VideoReader("./assets/sample_video.mp4", height=size[0], width=size[1])
vid = vid.get_batch(frame_ids).asnumpy()
vid = vid[:, y : y + crop_size[0], x : x + crop_size[1]]
x = torch.as_tensor(vid[None, ...].transpose((0, 4, 1, 2, 3))).to(device).to(dtype)
with torch.no_grad():
    x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.cpu().numpy()[0]
example = {"caption": "long caption", "text": "short text", "flow": 9}
# Ensure enough examples for the codewithgpu distributed dataset.
for _ in range(16):
    writer.write({"shape": x.shape, "codes": x.tobytes(), **example})
writer.close()
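
The resize and crop arithmetic in the snippet above can be checked without loading a video. The sketch below reproduces the short-side resize and center-crop computation in plain Python and adds two sanity checks that the original snippet does not perform: that the resized frame is at least as large as the crop, and that the clip has enough frames for the requested indices. The function names `resize_and_center_crop_box` and `check_frame_ids` are hypothetical, and the 720x1280 source size is only an example.

```python
def resize_and_center_crop_box(h, w, resize, crop_size):
    """Return ((new_h, new_w), (top, left)) for a short-side resize then center crop."""
    scale = float(resize) / float(min(h, w))
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    top, left = (new_h - crop_size[0]) // 2, (new_w - crop_size[1]) // 2
    if top < 0 or left < 0:
        raise ValueError(f"resized frame {new_h}x{new_w} is smaller than crop {crop_size}")
    return (new_h, new_w), (top, left)


def check_frame_ids(num_frames, frame_ids):
    """Raise if any requested frame index is past the end of the clip."""
    if max(frame_ids) >= num_frames:
        raise ValueError(f"need at least {max(frame_ids) + 1} frames, got {num_frames}")


# Same settings as the caching snippet, applied to an example 720x1280 source.
frame_ids = list(range(0, 97, 2))  # 49 frames sampled with stride 2
size, offset = resize_and_center_crop_box(720, 1280, 320, (320, 512))
print(size, offset)  # (320, 569) (0, 28)
check_frame_ids(97, frame_ids)  # passes: index 96 is the last valid frame
```

With these settings the source clip needs at least 97 frames, and portrait videos whose resized width falls below 512 would need a different `resize` or `crop_size`.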

2. Train models

Train T2I model

The following snippet provides simple T2I training arguments:

accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
--machine_rank 0 --num_machines 1 --num_processes 8 \
scripts/train.py \
config="./configs/ursa_1.7b_ibq1024.yaml" \
experiment.name="ursa_1.7b_ibq1024" \
experiment.output_dir="./experiments/ursa_1.7b_ibq1024" \
pipeline.paths.pretrained_path="/path/to/URSA-1.7B-IBQ1024" \
train_dataloader.params.dataset="./datasets/ibq1024_dataset" \
model.gradient_checkpointing=3 \
training.batch_size=4 \
training.gradient_accumulation_steps=16

Train T2V model

The following snippet provides simple T2V training arguments:

accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
--machine_rank 0 --num_machines 1 --num_processes 8 \
scripts/train.py \
config="./configs/ursa_1.7b_fsq320.yaml" \
experiment.name="ursa_1.7b_fsq320" \
experiment.output_dir="./experiments/ursa_1.7b_fsq320" \
pipeline.paths.pretrained_path="/path/to/URSA-1.7B-FSQ320" \
train_dataloader.params.dataset="./datasets/fsq320_dataset" \
model.gradient_checkpointing=3 \
training.batch_size=1 \
training.gradient_accumulation_steps=32
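
When adapting these commands to a different GPU count, it can help to keep the effective (global) batch size fixed. The sketch below computes it for the two commands in this section, assuming that `training.batch_size` is the per-process batch size and that `--num_processes` is the total process count across all machines. Both are assumptions about how the training script interprets these values, and the helper name `effective_batch_size` is hypothetical.

```python
def effective_batch_size(num_processes, batch_size, gradient_accumulation_steps):
    """Samples consumed per optimizer step, assuming batch_size is per process."""
    return num_processes * batch_size * gradient_accumulation_steps


t2i = effective_batch_size(num_processes=8, batch_size=4, gradient_accumulation_steps=16)
t2v = effective_batch_size(num_processes=8, batch_size=1, gradient_accumulation_steps=32)
print(t2i, t2v)  # 512 256
```

Under the same assumptions, halving `--num_processes` to 4 while doubling `training.gradient_accumulation_steps` would leave the effective batch size unchanged.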