Training Guide
This guide provides simple snippets to train diffnext models.
1. Build VQVAE cache
To speed up training, we preprocess images and videos into VQVAE latents ahead of time.
Requirements:
pip install protobuf==3.20.3 codewithgpu decord
Build T2I cache
The following snippet caches image latents:
import os, codewithgpu, torch, PIL.Image, numpy as np
from diffnext.models.autoencoders.autoencoder_vq import AutoencoderVQ
device, dtype = torch.device("cuda"), torch.float16
vae = AutoencoderVQ.from_pretrained("/path/to/BAAI/URSA-1.7B-IBQ1024/vae")
vae = vae.to(device=device, dtype=dtype).eval()
features = {"codes": "bytes", "caption": "string", "text": "string", "shape": ["int64"]}
os.makedirs("./datasets/ibq1024_dataset", exist_ok=True)
writer = codewithgpu.RecordWriter("./datasets/ibq1024_dataset", features)
img = PIL.Image.open("./assets/sample_image.jpg").convert("RGB")  # force 3 channels (H, W, 3)
x = torch.as_tensor(np.array(img)[None, ...].transpose(0, 3, 1, 2)).to(device).to(dtype)
with torch.no_grad():
    x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.unsqueeze(1).cpu().numpy()[0]
example = {"caption": "long caption", "text": "short text"}
# Write enough examples for the codewithgpu distributed dataset.
for _ in range(16):
    writer.write({"shape": x.shape, "codes": x.tobytes(), **example})
writer.close()
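The input layout conversion and the byte serialization used above can be checked without loading the VAE. The following is a minimal sketch using only NumPy. The helpers `to_model_input` and `decode_codes` are hypothetical names introduced for illustration, the random arrays stand in for a real image and real VAE output, and the `int64` dtype of the codes is an assumption. Each record stores only raw bytes and a shape, so the reader must use whatever dtype `latent_dist.parameters` actually has when the cache is written.

```python
import numpy as np


def to_model_input(img_hwc: np.ndarray) -> np.ndarray:
    """Convert an HWC uint8 image into an NCHW float array scaled to [-1, 1]."""
    x = img_hwc[None, ...].transpose(0, 3, 1, 2).astype(np.float32)
    return (x - 127.5) / 127.5


def decode_codes(codes: bytes, shape, dtype) -> np.ndarray:
    """Rebuild an array from raw bytes; dtype must match the one used when writing."""
    return np.frombuffer(codes, dtype=dtype).reshape(shape)


# Stand-in for a real image: 4x6 RGB with random pixel values.
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)
x = to_model_input(img)

# Stand-in for VAE output: the int64 dtype here is an assumption.
fake_codes = rng.integers(0, 1024, size=(1, 2, 3), dtype=np.int64)
record = {"shape": fake_codes.shape, "codes": fake_codes.tobytes()}
restored = decode_codes(record["codes"], record["shape"], np.int64)
```

Reading the bytes back with a different dtype than the one used for `tobytes()` either fails on reshape or silently yields wrong values, so it may be worth recording the dtype alongside the dataset.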
Build T2V cache
The following snippet caches video latents:
import os, codewithgpu, torch, decord, numpy as np
from diffnext.models.autoencoders.autoencoder_vq_cosmos3d import AutoencoderVQCosmos3D
device, dtype = torch.device("cuda"), torch.float16
vae = AutoencoderVQCosmos3D.from_pretrained("/path/to/URSA-1.7B-FSQ320/vae")
vae = vae.to(device=device, dtype=dtype).eval()
features = {"codes": "bytes", "caption": "string", "text": "string", "shape": ["int64"], "flow": "float64"}
os.makedirs("./datasets/fsq320_dataset", exist_ok=True)
writer = codewithgpu.RecordWriter("./datasets/fsq320_dataset", features)
resize, crop_size, frame_ids = 320, (320, 512), list(range(0, 97, 2))
vid = decord.VideoReader("./assets/sample_video.mp4")
h, w = vid[0].shape[:2]
scale = float(resize) / float(min(h, w))
size = int(h * scale + 0.5), int(w * scale + 0.5)
y, x = (size[0] - crop_size[0]) // 2, (size[1] - crop_size[1]) // 2
vid = decord.VideoReader("./assets/sample_video.mp4", height=size[0], width=size[1])
vid = vid.get_batch(frame_ids).asnumpy()
vid = vid[:, y : y + crop_size[0], x : x + crop_size[1]]
x = torch.as_tensor(vid[None, ...].transpose((0, 4, 1, 2, 3))).to(device).to(dtype)
with torch.no_grad():
    x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.cpu().numpy()[0]
example = {"caption": "long caption", "text": "short text", "flow": 9}
# Write enough examples for the codewithgpu distributed dataset.
for _ in range(16):
    writer.write({"shape": x.shape, "codes": x.tobytes(), **example})
writer.close()
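The resize and center-crop arithmetic in the snippet above can be isolated and checked for a given source resolution before decoding any video. The sketch below repeats the same formulas in a hypothetical helper, `resize_and_crop_params`, which is introduced here for illustration and is not part of diffnext. The 720x1280 source resolution is an assumed example.

```python
def resize_and_crop_params(h, w, resize=320, crop_size=(320, 512)):
    """Return the resized (height, width) and the top-left (y, x) of the center crop."""
    # Scale so that the shorter side equals `resize`, rounding to the nearest pixel.
    scale = float(resize) / float(min(h, w))
    size = int(h * scale + 0.5), int(w * scale + 0.5)
    # Center the crop window inside the resized frame.
    y, x = (size[0] - crop_size[0]) // 2, (size[1] - crop_size[1]) // 2
    return size, (y, x)


frame_ids = list(range(0, 97, 2))  # every second frame from index 0 to 96
size, (y, x) = resize_and_crop_params(720, 1280)
print(len(frame_ids), size, (y, x))  # 49 (320, 569) (0, 28)
```

With these defaults, a source narrower than the 320x512 crop aspect ratio (for example 480x640) produces a negative `x` offset, so checking the offsets before slicing may help catch clips that cannot fill the crop window. The video also needs at least 97 frames for `frame_ids` to be valid.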
2. Train models
Train T2I model
The following snippet shows basic T2I training arguments:
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
--machine_rank 0 --num_machines 1 --num_processes 8 \
scripts/train.py \
config="./configs/ursa_1.7b_ibq1024.yaml" \
experiment.name="ursa_1.7b_ibq1024" \
experiment.output_dir="./experiments/ursa_1.7b_ibq1024" \
pipeline.paths.pretrained_path="/path/to/URSA-1.7B-IBQ1024" \
train_dataloader.params.dataset="./datasets/ibq1024_dataset" \
model.gradient_checkpointing=3 \
training.batch_size=4 \
training.gradient_accumulation_steps=16
Train T2V model
The following snippet shows basic T2V training arguments:
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
--machine_rank 0 --num_machines 1 --num_processes 8 \
scripts/train.py \
config="./configs/ursa_1.7b_fsq320.yaml" \
experiment.name="ursa_1.7b_fsq320" \
experiment.output_dir="./experiments/ursa_1.7b_fsq320" \
pipeline.paths.pretrained_path="/path/to/URSA-1.7B-FSQ320" \
train_dataloader.params.dataset="./datasets/fsq320_dataset" \
model.gradient_checkpointing=3 \
training.batch_size=1 \
training.gradient_accumulation_steps=32
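When adapting these commands to a different number of GPUs, it can help to keep the number of samples per optimizer step in view. The sketch below computes it for both commands, assuming that `training.batch_size` is a per-process batch size (the guide does not state this explicitly). The helper `effective_batch_size` is a hypothetical name used only for illustration.

```python
def effective_batch_size(num_processes, batch_size, grad_accum_steps):
    """Samples contributing to one optimizer step across all processes."""
    return num_processes * batch_size * grad_accum_steps


# Values taken from the T2I and T2V launch commands above.
t2i = effective_batch_size(num_processes=8, batch_size=4, grad_accum_steps=16)
t2v = effective_batch_size(num_processes=8, batch_size=1, grad_accum_steps=32)
print(t2i, t2v)  # 512 256
```

Under that assumption, halving `--num_processes` while doubling the gradient accumulation steps would keep the same number of samples per optimizer step.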