| """VibeToken model definition.""" |
|
|
| import torch |
| import torch.nn as nn |
| from einops import rearrange |
|
|
| from modeling.modules.base_model import BaseModel |
| from modeling.modules.encoder_decoder import ResolutionEncoder, ResolutionDecoder |
| from modeling.quantizer import VectorQuantizer, DiagonalGaussianDistribution, VectorQuantizerMVQ, SoftVectorQuantizer |
| from modeling.modules.maskgit_vqgan import Encoder as Pixel_Eecoder |
| from modeling.modules.maskgit_vqgan import Decoder as Pixel_Decoder |
| from modeling.modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer |
| import json |
| from omegaconf import OmegaConf |
| from pathlib import Path |
|
|
| from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
class PretrainedTokenizer(nn.Module):
    """Frozen MaskGIT-VQGAN pixel tokenizer, used for inference only."""

    def __init__(self, pretrained_weight):
        super().__init__()
        conf = OmegaConf.create(
            {"channel_mult": [1, 1, 2, 2, 4],
             "num_resolutions": 5,
             "dropout": 0.0,
             "hidden_channels": 128,
             "num_channels": 3,
             "num_res_blocks": 2,
             "resolution": 256,
             "z_channels": 256})
        self.encoder = Pixel_Encoder(conf)
        self.decoder = Pixel_Decoder(conf)
        self.quantize = Pixel_Quantizer(
            num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)

        self.load_state_dict(torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True)

        # Freeze all parameters: this tokenizer is never trained.
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode(self, x):
        """Map images to discrete codebook indices."""
        hidden_states = self.encoder(x)
        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states)
        return codebook_indices.detach()

    @torch.no_grad()
    def decode(self, codes):
        """Map codebook indices back to images, clamped to [0, 1]."""
        quantized_states = self.quantize.get_codebook_entry(codes)
        rec_images = self.decoder(quantized_states)
        rec_images = torch.clamp(rec_images, 0.0, 1.0)
        return rec_images.detach()

    @torch.no_grad()
    def decode_tokens(self, codes):
        return self.decode(codes)
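
# A minimal usage sketch for PretrainedTokenizer. The checkpoint path is
# hypothetical; any MaskGIT-VQGAN state dict matching the config above works:
#
#     tokenizer = PretrainedTokenizer("maskgit_vqgan.bin")
#     images = torch.rand(1, 3, 256, 256)       # values in [0, 1]
#     codes = tokenizer.encode(images)          # discrete codebook indices
#     recon = tokenizer.decode_tokens(codes)    # reconstruction in [0, 1]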


class VibeTokenModel(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-tokenization"]):
    def __init__(self, config):
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        super().__init__()
        self.config = config

        self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True)

        self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
        if self.quantize_mode not in ["vq", "vae", "softvq", "mvq"]:
            raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.")

        if self.finetune_decoder and self.quantize_mode not in ["vq", "softvq", "mvq"]:
            raise ValueError("finetune_decoder is only supported with vq, softvq, or mvq quantization for now.")

        self.encoder = ResolutionEncoder(config)
        self.decoder = ResolutionDecoder(config)

        # Learnable latent tokens the encoder compresses each image into.
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        scale = self.encoder.width ** -0.5
        self.latent_tokens = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.encoder.width))

        # Initialize encoder/decoder weights; the quantizer is created below
        # so that it keeps its own initialization.
        self.apply(self._init_weights)

        if self.quantize_mode == "vq":
            self.quantize = VectorQuantizer(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm)
        elif self.quantize_mode == "vae":
            self.quantize = DiagonalGaussianDistribution
        elif self.quantize_mode == "mvq":
            self.quantize = VectorQuantizerMVQ(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm,
                num_codebooks=config.model.vq_model.num_codebooks,
            )
        elif self.quantize_mode == "softvq":
            self.quantize = SoftVectorQuantizer(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm,
                num_codebooks=config.model.vq_model.num_codebooks,
            )
        else:
            raise NotImplementedError

        if self.finetune_decoder:
            # Decoder finetuning: freeze the latent tokens, encoder, and
            # quantizer, and decode through a frozen pixel codebook.
            self.latent_tokens.requires_grad_(False)
            self.encoder.eval()
            self.encoder.requires_grad_(False)
            self.quantize.eval()
            self.quantize.requires_grad_(False)

            self.pixel_quantize = Pixel_Quantizer(
                num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
            self.pixel_decoder = Pixel_Decoder(OmegaConf.create(
                {"channel_mult": [1, 1, 2, 2, 4],
                 "num_resolutions": 5,
                 "dropout": 0.0,
                 "hidden_channels": 128,
                 "num_channels": 3,
                 "num_res_blocks": 2,
                 "resolution": 256,
                 "z_channels": 256}))

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config to a local directory."""
        dict_config = OmegaConf.to_container(self.config)
        file_path = Path(save_directory) / "config.json"
        with open(file_path, "w") as json_file:
            json.dump(dict_config, json_file, indent=4)
        super()._save_pretrained(save_directory)
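
    # Sketch of the Hub round trip that PyTorchModelHubMixin provides
    # (the directory name is hypothetical):
    #
    #     model.save_pretrained("./vibetoken-checkpoint")
    #     model = VibeTokenModel.from_pretrained("./vibetoken-checkpoint")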
    def _init_weights(self, module):
        """Initialize the weights.

        :param module: torch.nn.Module to initialize.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def encode(self, x, attention_mask=None, encode_patch_size=None, train=True, length=None):
        if self.finetune_decoder:
            # Encoder and quantizer are frozen; zero their losses so only the
            # reconstruction objective trains the decoder.
            with torch.no_grad():
                self.encoder.eval()
                self.quantize.eval()
                z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
                z_quantized, result_dict = self.quantize(z)
                result_dict["quantizer_loss"] *= 0
                result_dict["commitment_loss"] *= 0
                result_dict["codebook_loss"] *= 0
        else:
            if length is not None:
                # Truncate the latent sequence when an explicit length is given.
                attention_mask = None
                latent_tokens = self.latent_tokens[:length + 1]
            else:
                latent_tokens = self.latent_tokens
            z = self.encoder(pixel_values=x, latent_tokens=latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
            if self.quantize_mode in ["vq", "mvq", "softvq"]:
                z_quantized, result_dict = self.quantize(z)
            elif self.quantize_mode == "vae":
                posteriors = self.quantize(z)
                z_quantized = posteriors.sample()
                result_dict = posteriors
        return z_quantized, result_dict

    def decode(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        decoded = self.decoder(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
        if self.finetune_decoder:
            # Turn decoder logits into a soft mixture of frozen pixel-codebook
            # embeddings, then render pixels with the frozen pixel decoder.
            quantized_states = torch.einsum(
                'nchw,cd->ndhw', decoded.softmax(1),
                self.pixel_quantize.embedding.weight)
            decoded = self.pixel_decoder(quantized_states)
        return decoded

    def decode_tokens(self, tokens, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode discrete token indices (or sampled VAE latents) to images."""
        if self.quantize_mode in ["vq", "softvq"]:
            tokens = tokens.squeeze(1)
            batch, seq_len = tokens.shape
            z_quantized = self.quantize.get_codebook_entry(
                tokens.reshape(-1)).reshape(batch, 1, seq_len, -1)
            z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
        elif self.quantize_mode == "mvq":
            z_quantized = self.quantize.get_codebook_entry(tokens)
        elif self.quantize_mode == "vae":
            z_quantized = tokens
        z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
        decoded = self.decode(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
        return decoded

    def forward(self, x, key_attention_mask=None, height=None, width=None, train=True):
        if height is None:
            # Fall back to the input resolution when no target size is given.
            _, _, height, width = x.shape
        z_quantized, result_dict = self.encode(x, attention_mask=key_attention_mask, train=train)
        z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
        decoded = self.decode(z_quantized, attention_mask=key_attention_mask, height=height, width=width, train=train)
        return decoded, result_dict
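

if __name__ == "__main__":
    # Minimal smoke-test sketch. The config path is hypothetical; it should
    # point at an OmegaConf file defining model.vq_model (codebook_size,
    # token_size, commitment_cost, use_l2_norm, num_latent_tokens, ...) and
    # the ResolutionEncoder/ResolutionDecoder settings used by this repo.
    config = OmegaConf.load("configs/vibetoken.yaml")
    model = VibeTokenModel(config).eval()
    images = torch.rand(1, 3, 256, 256)
    with torch.no_grad():
        reconstruction, result_dict = model(images, train=False)
    print(reconstruction.shape)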