from typing import Literal, Optional
import json

import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model
from torchvision.transforms import Normalize

from ...ext.autoencoder import AutoEncoderModule
from ...ext.mel_converter import get_mel_converter
from ...ext.synchformer.synchformer import Synchformer
from ...model.utils.distributions import DiagonalGaussianDistribution
from shared.utils import files_locator as fl


def patch_clip(clip_model):
    """Patch `encode_text` so it returns per-token features rather than the pooled,
    projected sentence embedding that open_clip produces by default."""

    def new_encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # no pooling, no text projection: keep every token
        return F.normalize(x, dim=-1) if normalize else x

    # Bind the new method to this model instance.
    clip_model.encode_text = new_encode_text.__get__(clip_model)
    return clip_model
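
# Illustrative only: with the DFN5B ViT-H-14 text tower, the patched encode_text maps
# token ids of shape (B, 77) to per-token features of shape (B, 77, 1024), whereas the
# unpatched open_clip encode_text would return a pooled (B, 1024) embedding.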


def get_model_config(model_name):
    # Drop-in replacement for open_clip.factory.get_model_config: ignores `model_name`
    # and always returns the locally stored DFN5B-CLIP-ViT-H-14-378 config.
    with open(fl.locate_file("DFN5B-CLIP-ViT-H-14-378/open_clip_config.json"), 'r', encoding='utf-8') as f:
        return json.load(f)["model_cfg"]


class FeaturesUtils(nn.Module):

    def __init__(
        self,
        *,
        tod_vae_ckpt: Optional[str] = None,
        bigvgan_vocoder_ckpt: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        mode: Literal['16k', '44k'] = '16k',
        need_vae_encoder: bool = True,
    ):
        super().__init__()
        self.device = "cuda"
        if enable_conditions:
            # Temporarily redirect open_clip's config lookup so create_model reads the
            # locally stored DFN5B config instead of resolving it by model name.
            old_get_model_config = open_clip.factory.get_model_config
            open_clip.factory.get_model_config = get_model_config
            with open(fl.locate_file("DFN5B-CLIP-ViT-H-14-378/open_clip_config.json"), 'r', encoding='utf-8') as f:
                override_preprocess = json.load(f)["preprocess_cfg"]

            self.clip_model = create_model(
                'DFN5B-CLIP-ViT-H-14-378',
                pretrained=fl.locate_file('DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin'),
                force_preprocess_cfg=override_preprocess)
            open_clip.factory.get_model_config = old_get_model_config

            self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                             std=[0.26862954, 0.26130258, 0.27577711])
            self.clip_model = patch_clip(self.clip_model)

            self.synchformer = Synchformer()
            self.synchformer.load_state_dict(
                torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))

            self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')
        else:
            self.clip_model = None
            self.synchformer = None
            self.tokenizer = None

        if tod_vae_ckpt is not None:
            self.mel_converter = get_mel_converter(mode)
            self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                         vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                                         mode=mode,
                                         need_vae_encoder=need_vae_encoder)
        else:
            self.tod = None
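
    # Note: enable_conditions=False skips CLIP and Synchformer entirely (no text or
    # visual conditioning), and tod_vae_ckpt=None skips the audio VAE and vocoder.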


    def compile(self):
        if self.clip_model is not None:
            self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
            self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
        if self.synchformer is not None:
            self.synchformer = torch.compile(self.synchformer)
        self.decode = torch.compile(self.decode)
        self.vocode = torch.compile(self.vocode)


    def train(self, mode: bool = True) -> 'FeaturesUtils':
        # The feature extractors are frozen; stay in eval mode regardless of `mode`.
        return super().train(False)


    @torch.inference_mode()
    def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'

        # x: (B, T, 3, 384, 384) video frames; Normalize applies the CLIP mean/std.
        b, t, c, h, w = x.shape
        assert c == 3 and h == 384 and w == 384
        x = self.clip_preprocess(x)
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        outputs = []
        if batch_size < 0:
            batch_size = b * t
        # Encode frames in chunks to bound peak memory.
        for i in range(0, b * t, batch_size):
            outputs.append(self.clip_model.encode_image(x[i:i + batch_size], normalize=True))
        x = torch.cat(outputs, dim=0)

        x = rearrange(x, '(b t) d -> b t d', b=b)
        return x
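
    # Illustrative shapes: (B, T, 3, 384, 384) frames yield (B, T, D) embeddings, where D
    # is the CLIP image embedding dim (1024 for the ViT-H-14 tower used here).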


    @torch.inference_mode()
    def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.synchformer is not None, 'Synchformer is not loaded'

        # x: (B, T, 3, 224, 224)
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224

        # Slice the clip into overlapping 16-frame segments with a stride of 8 frames.
        segment_size = 16
        step_size = 8
        num_segments = (t - segment_size) // step_size + 1
        segments = []
        for i in range(num_segments):
            segments.append(x[:, i * step_size:i * step_size + segment_size])
        x = torch.stack(segments, dim=1)  # (B, S, segment_size, C, H, W)

        outputs = []
        if batch_size < 0:
            batch_size = b
        x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
        # Run Synchformer over the segments in chunks to bound peak memory.
        for i in range(0, b * num_segments, batch_size):
            outputs.append(self.synchformer(x[i:i + batch_size]))
        x = torch.cat(outputs, dim=0)
        x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
        return x
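
    # Worked example (illustrative): t = 40 input frames gives (40 - 16) // 8 + 1 = 4
    # overlapping segments; the output concatenates the per-segment tokens along dim 1.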


    @torch.inference_mode()
    def encode_text(self, text: list[str]) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'
        assert self.tokenizer is not None, 'Tokenizer is not loaded'

        tokens = self.tokenizer(text).to(self.device)
        # The patched encode_text (see patch_clip) returns per-token features.
        return self.clip_model.encode_text(tokens, normalize=True)


    @torch.inference_mode()
    def encode_audio(self, x) -> DiagonalGaussianDistribution:
        assert self.tod is not None, 'VAE is not loaded'

        # waveform -> mel spectrogram -> latent posterior
        mel = self.mel_converter(x)
        dist = self.tod.encode(mel)

        return dist
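
    # The returned posterior is turned into latents downstream (assumption: the local
    # DiagonalGaussianDistribution follows the usual LDM interface with .sample() and .mode()).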


    @torch.inference_mode()
    def vocode(self, mel: torch.Tensor) -> torch.Tensor:
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.vocode(mel)


    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.decode(z.transpose(1, 2))


    @property
    def dtype(self):
        return next(self.parameters()).dtype
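

# Minimal usage sketch (illustrative only; checkpoint paths and input tensors are
# hypothetical, and the weights must exist locally):
#
#     utils = FeaturesUtils(
#         tod_vae_ckpt='ckpts/v1-16.pth',
#         bigvgan_vocoder_ckpt='ckpts/best_netG.pt',
#         synchformer_ckpt='ckpts/synchformer_state_dict.pth',
#         enable_conditions=True,
#         mode='16k',
#     ).to('cuda').eval()
#
#     text_feats = utils.encode_text(['dog barking'])        # per-token text features
#     clip_feats = utils.encode_video_with_clip(frames_384)  # (B, T, D) frame embeddings
#     sync_feats = utils.encode_video_with_sync(frames_224)  # synchronization features
#     posterior = utils.encode_audio(waveform)               # latent posterior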