Spaces:
Build error
Build error
| import torch | |
| from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel | |
| from transformers import AutoTokenizer, T5EncoderModel | |
| from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights | |
| from finetrainers.models.wan import WanControlModelSpecification | |
class DummyWanControlModelSpecification(WanControlModelSpecification):
    """Test double for ``WanControlModelSpecification`` that builds tiny models locally.

    The VAE and transformer are constructed directly from small hand-written
    configs right after ``torch.manual_seed(0)``, rather than loaded from
    checkpoints. The T5 text encoder and tokenizer are loaded from the
    ``hf-internal-testing/tiny-random-t5`` repository.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Force the input channel count to 16, matching the in_channels the
        # dummy transformer is constructed with in `load_diffusion_models`.
        # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded
        # with ModelSpecification::_load_configs
        self.transformer_config.in_channels = 16

    def load_condition_models(self):
        """Load the tiny T5 text encoder and its tokenizer.

        The encoder is loaded with ``torch_dtype=self.text_encoder_dtype``.

        Returns:
            dict with keys ``"text_encoder"`` and ``"tokenizer"``.
        """
        repo_id = "hf-internal-testing/tiny-random-t5"
        encoder = T5EncoderModel.from_pretrained(repo_id, torch_dtype=self.text_encoder_dtype)
        tok = AutoTokenizer.from_pretrained(repo_id)
        return dict(text_encoder=encoder, tokenizer=tok)

    def load_latent_models(self):
        """Build a small, randomly initialised Wan VAE.

        Seeds the global torch RNG with 0 immediately before construction, so
        the random weights are reproducible. Casts the VAE to
        ``self.vae_dtype`` and stores its config on ``self.vae_config``.

        Returns:
            dict with the single key ``"vae"``.
        """
        vae_kwargs = dict(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            # `temperal_downsample` is the keyword's actual spelling; do not "fix" it.
            temperal_downsample=[False, True, True],
        )
        torch.manual_seed(0)
        vae = AutoencoderKLWan(**vae_kwargs)

        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
        # Doing so overrides things like _keep_in_fp32_modules
        vae.to(self.vae_dtype)
        self.vae_config = vae.config
        return dict(vae=vae)

    def load_diffusion_models(self, new_in_features: int):
        """Build a tiny Wan transformer whose patch embedding takes ``new_in_features`` channels.

        The transformer is first built with 16 input channels (after seeding the
        global torch RNG with 0). Its ``patch_embedding`` is then replaced by
        the widened module returned from ``_expand_conv3d_with_zeroed_weights``.
        Finally, the model config is updated to ``in_channels=new_in_features``.

        Args:
            new_in_features: Input channel count for the widened patch embedding.

        Returns:
            dict with keys ``"transformer"`` and ``"scheduler"`` (a default
            ``FlowMatchEulerDiscreteScheduler``).
        """
        transformer_kwargs = dict(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=16,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
        )
        torch.manual_seed(0)
        transformer = WanTransformer3DModel(**transformer_kwargs).to(self.transformer_dtype)

        # Swap in a patch embedding that accepts `new_in_features` input channels.
        # Judging by the helper's name, the added channels start with zero
        # weights — NOTE(review): confirm against finetrainers.models.utils.
        widened_embedding = _expand_conv3d_with_zeroed_weights(
            transformer.patch_embedding, new_in_channels=new_in_features
        )
        transformer.patch_embedding = widened_embedding
        # Keep the stored config consistent with the new embedding width.
        transformer.register_to_config(in_channels=new_in_features)

        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
        # Doing so overrides things like _keep_in_fp32_modules
        transformer.to(self.transformer_dtype)

        return dict(transformer=transformer, scheduler=FlowMatchEulerDiscreteScheduler())