# VibeToken/modeling/vibetoken_model.py
"""VibeToken model definition."""
import json
from pathlib import Path

import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import OmegaConf

from modeling.modules.base_model import BaseModel
from modeling.modules.encoder_decoder import ResolutionEncoder, ResolutionDecoder
from modeling.quantizer import VectorQuantizer, DiagonalGaussianDistribution, VectorQuantizerMVQ, SoftVectorQuantizer
from modeling.modules.maskgit_vqgan import Encoder as Pixel_Encoder
from modeling.modules.maskgit_vqgan import Decoder as Pixel_Decoder
from modeling.modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer


class PretrainedTokenizer(nn.Module):
def __init__(self, pretrained_weight):
super().__init__()
conf = OmegaConf.create(
{"channel_mult": [1, 1, 2, 2, 4],
"num_resolutions": 5,
"dropout": 0.0,
"hidden_channels": 128,
"num_channels": 3,
"num_res_blocks": 2,
"resolution": 256,
"z_channels": 256})
        self.encoder = Pixel_Encoder(conf)
self.decoder = Pixel_Decoder(conf)
self.quantize = Pixel_Quantizer(
num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
# Load pretrained weights
self.load_state_dict(torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True)
self.eval()
for param in self.parameters():
param.requires_grad = False
@torch.no_grad()
def encode(self, x):
hidden_states = self.encoder(x)
quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states)
return codebook_indices.detach()
@torch.no_grad()
def decode(self, codes):
quantized_states = self.quantize.get_codebook_entry(codes)
rec_images = self.decoder(quantized_states)
rec_images = torch.clamp(rec_images, 0.0, 1.0)
return rec_images.detach()
@torch.no_grad()
def decode_tokens(self, codes):
return self.decode(codes)
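
# A minimal usage sketch for PretrainedTokenizer (the checkpoint path below is
# hypothetical; substitute the actual MaskGiT-VQGAN weights file):
#
#   tokenizer = PretrainedTokenizer("maskgit_vqgan.bin")
#   images = torch.rand(2, 3, 256, 256)       # pixel values in [0, 1]
#   codes = tokenizer.encode(images)           # discrete codebook indices
#   recon = tokenizer.decode_tokens(codes)     # reconstructions clamped to [0, 1]
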
class VibeTokenModel(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-tokenization"]):
def __init__(self, config):
if isinstance(config, dict):
config = OmegaConf.create(config)
super().__init__()
self.config = config
        # This should be False for stage 1 and True for stage 2 (decoder finetuning).
self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True)
self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
if self.quantize_mode not in ["vq", "vae", "softvq", "mvq"]:
raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.")
        if self.finetune_decoder and self.quantize_mode not in ["vq", "softvq", "mvq"]:
            raise ValueError("finetune_decoder is only supported with vq, softvq, or mvq quantization for now.")
self.encoder = ResolutionEncoder(config)
self.decoder = ResolutionDecoder(config)
self.num_latent_tokens = config.model.vq_model.num_latent_tokens
scale = self.encoder.width ** -0.5
self.latent_tokens = nn.Parameter(
scale * torch.randn(self.num_latent_tokens, self.encoder.width))
self.apply(self._init_weights)
if self.quantize_mode == "vq":
self.quantize = VectorQuantizer(
codebook_size=config.model.vq_model.codebook_size,
token_size=config.model.vq_model.token_size,
commitment_cost=config.model.vq_model.commitment_cost,
use_l2_norm=config.model.vq_model.use_l2_norm,)
elif self.quantize_mode == "vae":
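            # Note: for VAE mode this stores the distribution class itself; it is
            # instantiated from the encoder output at encode() time.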
self.quantize = DiagonalGaussianDistribution
elif self.quantize_mode == "mvq":
self.quantize = VectorQuantizerMVQ(
codebook_size=config.model.vq_model.codebook_size,
token_size=config.model.vq_model.token_size,
commitment_cost=config.model.vq_model.commitment_cost,
use_l2_norm=config.model.vq_model.use_l2_norm,
num_codebooks=config.model.vq_model.num_codebooks,
)
elif self.quantize_mode == "softvq":
self.quantize = SoftVectorQuantizer(
codebook_size=config.model.vq_model.codebook_size,
token_size=config.model.vq_model.token_size,
commitment_cost=config.model.vq_model.commitment_cost,
use_l2_norm=config.model.vq_model.use_l2_norm,
num_codebooks=config.model.vq_model.num_codebooks,
)
else:
raise NotImplementedError
if self.finetune_decoder:
# Freeze encoder/quantizer/latent tokens
self.latent_tokens.requires_grad_(False)
self.encoder.eval()
self.encoder.requires_grad_(False)
self.quantize.eval()
self.quantize.requires_grad_(False)
# Include MaskGiT-VQGAN's quantizer and decoder
self.pixel_quantize = Pixel_Quantizer(
num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
self.pixel_decoder = Pixel_Decoder(OmegaConf.create(
{"channel_mult": [1, 1, 2, 2, 4],
"num_resolutions": 5,
"dropout": 0.0,
"hidden_channels": 128,
"num_channels": 3,
"num_res_blocks": 2,
"resolution": 256,
"z_channels": 256}))
def _save_pretrained(self, save_directory: Path) -> None:
"""Save weights and config to a local directory."""
        # Convert the DictConfig to a plain dict and save it as JSON.
        dict_config = OmegaConf.to_container(self.config)
file_path = Path(save_directory) / "config.json"
with open(file_path, 'w') as json_file:
json.dump(dict_config, json_file, indent=4)
super()._save_pretrained(save_directory)
def _init_weights(self, module):
""" Initialize the weights.
:param:
module -> torch.nn.Module: module to initialize
"""
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def encode(self, x, attention_mask=None, encode_patch_size=None, train=True, length=None):
if self.finetune_decoder:
with torch.no_grad():
self.encoder.eval()
self.quantize.eval()
z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
z_quantized, result_dict = self.quantize(z)
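                # The encoder and quantizer are frozen during decoder finetuning,
                # so zero out their losses below.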
result_dict["quantizer_loss"] *= 0
result_dict["commitment_loss"] *= 0
result_dict["codebook_loss"] *= 0
else:
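            # Without decoder finetuning, optionally keep only a prefix of the
            # learned latent tokens (the attention mask is dropped in that case).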
if length is not None:
attention_mask = None
latent_tokens = self.latent_tokens[:length+1]
else:
latent_tokens = self.latent_tokens
z = self.encoder(pixel_values=x, latent_tokens=latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
if self.quantize_mode in ["vq", "mvq", "softvq"]:
z_quantized, result_dict = self.quantize(z)
elif self.quantize_mode == "vae":
posteriors = self.quantize(z)
z_quantized = posteriors.sample()
result_dict = posteriors
return z_quantized, result_dict
def decode(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
decoded = self.decoder(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
if self.finetune_decoder:
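            # Stage 2: treat the decoder output channels as logits over the
            # MaskGiT-VQGAN codebook, softmax them into per-position mixture
            # weights, and mix the codebook embeddings into a feature map for
            # the pixel decoder.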
quantized_states = torch.einsum(
'nchw,cd->ndhw', decoded.softmax(1),
self.pixel_quantize.embedding.weight)
decoded = self.pixel_decoder(quantized_states)
return decoded
def decode_tokens(self, tokens, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
if self.quantize_mode in ["vq", "softvq"]:
tokens = tokens.squeeze(1)
batch, seq_len = tokens.shape # B x N
z_quantized = self.quantize.get_codebook_entry(
tokens.reshape(-1)).reshape(batch, 1, seq_len, -1)
z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
elif self.quantize_mode == "mvq":
z_quantized = self.quantize.get_codebook_entry(tokens)
elif self.quantize_mode == "vae":
z_quantized = tokens
z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
decoded = self.decode(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
return decoded
def forward(self, x, key_attention_mask=None, height=None, width=None, train=True):
if height is None:
batch_size, channels, height, width = x.shape
z_quantized, result_dict = self.encode(x, attention_mask=key_attention_mask, train=train)
z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
decoded = self.decode(z_quantized, attention_mask=key_attention_mask, height=height, width=width, train=train)
return decoded, result_dict
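

if __name__ == "__main__":
    # Minimal smoke-test sketch. The config fields below mirror the ones this
    # module reads directly; the exact values, and any additional fields that
    # ResolutionEncoder/ResolutionDecoder may require, are assumptions rather
    # than a verified training config.
    cfg = OmegaConf.create({
        "model": {"vq_model": {
            "finetune_decoder": False,
            "quantize_mode": "vq",
            "codebook_size": 4096,
            "token_size": 12,
            "commitment_cost": 0.25,
            "use_l2_norm": True,
            "num_latent_tokens": 32,
        }}
    })
    model = VibeTokenModel(cfg)
    images = torch.rand(1, 3, 256, 256)
    decoded, result_dict = model(images)
    print(decoded.shape)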