| import numpy as np |
| import torch |
| from torchvision import transforms |
| from tqdm import tqdm |
| from PIL import Image |
| import soundfile as sf |
| from mel_module import Mel |
|
|
class Generator:
    """Generates audio from a diffusion model using classifier-free guidance.

    Samples a spectrogram image with the UNet/scheduler pair, conditioned on
    ``embedding`` (with a VAE-encoded all-zeros image as the unconditional
    branch), converts the result to a ``Mel`` object and saves it as audio.
    """

    def __init__(self, config, unet, scheduler, vae, embedding, progress_callback=None):
        """
        Args:
            config: Run configuration; must expose ``image_size``, ``device``
                and ``guidance_scale``.
            unet: Denoising UNet, called as
                ``unet(x, t, encoder_hidden_states=...)``.
            scheduler: Diffusion scheduler providing ``timesteps``,
                ``scale_model_input`` and ``step``.
            vae: VAE whose ``encode`` returns ``(mu, log_var)``; used to build
                the unconditional latent from an all-zeros image.
            embedding: Conditioning latent, concatenated with the
                unconditional latent along the batch dimension.
            progress_callback: Optional object exposing a ``tqdm(iterable)``
                method (e.g. a UI progress tracker). Falls back to plain
                ``tqdm`` when omitted.
        """
        self.config = config
        self.unet = unet
        self.scheduler = scheduler
        self.vae = vae
        self.embedding = embedding
        self.progress_callback = progress_callback

    def tensor_to_mel(self, tensor):
        """Convert a [-1, 1]-normalized spectrogram tensor to a ``Mel``.

        Inverts ``transforms.Normalize(mean=[0.5], std=[0.5])`` and rescales
        to the 0-255 pixel range expected by ``Mel``.
        """
        # Inverse of Normalize(mean=0.5, std=0.5): x -> x * 0.5 + 0.5
        denormalize = transforms.Normalize(
            mean=[-m / s for m, s in zip([0.5], [0.5])],
            std=[1 / s for s in [0.5]],
        )
        denormalized = denormalize(tensor.detach().cpu())
        spectrogram = np.array(denormalized.squeeze()) * 255
        return Mel(spectrogram=spectrogram)

    def generate(self):
        """Run the full sampling loop and save the resulting audio.

        Returns:
            The ``Mel`` object built from the final denoised spectrogram.
        """
        # Sampling never needs gradients; keep the entire loop (not just the
        # UNet call) under no_grad to avoid building autograd graphs.
        with torch.no_grad():
            # Unconditional latent: VAE-encode an all-zeros image and stack
            # (mu, log_var) along the channel dimension.
            uncond_image = torch.zeros(
                (1, 1, self.config.image_size, self.config.image_size),
                device=self.config.device,
            )
            mu, log_var = self.vae.encode(uncond_image)
            uncond_latent = torch.cat((mu, log_var), dim=1).unsqueeze(0)

            # Batch unconditional + conditional latents so one UNet call
            # yields both predictions needed for classifier-free guidance.
            embeddings = torch.cat([uncond_latent, self.embedding])

            generator = torch.Generator(device=self.config.device)
            noise = torch.randn(
                (1, 1, self.config.image_size, self.config.image_size),
                generator=generator,
                device=self.config.device,
            )

            # Fall back to plain tqdm when no progress callback was supplied
            # (previously raised AttributeError on None.tqdm(...)).
            timesteps = self.scheduler.timesteps
            if self.progress_callback is not None:
                progress = self.progress_callback.tqdm(timesteps)
            else:
                progress = tqdm(timesteps)

            for t in progress:
                # Duplicate the current sample so the batched UNet call
                # covers both the unconditional and conditional branches.
                image_model_input = torch.cat([noise] * 2)
                image_model_input = self.scheduler.scale_model_input(
                    image_model_input, timestep=t
                )

                noise_pred = self.unet(
                    image_model_input, t, encoder_hidden_states=embeddings
                ).sample
                noise_pred_uncond, noise_pred_img = noise_pred.chunk(2)
                # Classifier-free guidance: push the unconditional prediction
                # toward the conditional one by guidance_scale.
                noise_pred = noise_pred_uncond + self.config.guidance_scale * (
                    noise_pred_img - noise_pred_uncond
                )
                noise = self.scheduler.step(noise_pred, t, noise).prev_sample

        mel = self.tensor_to_mel(noise.squeeze(1))
        mel.save_audio()
        return mel
|