| | import os |
| | |
| | import random |
| | import hydra |
| | import numpy as np |
| | import librosa |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.optim as optim |
| | import pytorch_lightning as pl |
| | from vq import CodecEncoder, CodecDecoderVocos |
| | from module import HiFiGANMultiPeriodDiscriminator, SpecDiscriminator |
| | from criterions import GANLoss, MultiResolutionMelSpectrogramLoss, MultiResolutionSTFTLoss |
| | from common.schedulers import WarmupLR |
| | from transformers import AutoModel |
| | from vq.module import SemanticDecoder,SemanticEncoder |
| | from transformers import AutoFeatureExtractor, Wav2Vec2BertModel |
| | import sys |
| | |
| | |
| | |
| |
|
| |
|
| | from transformers import AutoModel, AutoFeatureExtractor |
| |
|
| |
|
| | class CodecLightningModule(pl.LightningModule): |
    def __init__(self, cfg):
        """Lightning module wrapping the codec GAN training setup.

        Args:
            cfg: hydra/omegaconf config with ``model``, ``train``, ``dataset``
                and ``preprocess`` sections (accessed throughout this module).
        """
        super().__init__()
        self.cfg = cfg
        # Original working directory (hydra changes cwd per run).
        self.ocwd = hydra.utils.get_original_cwd()
        self.construct_model()
        self.construct_criteria()
        self.save_hyperparameters()
        # GAN training: two optimizers stepped by hand in training_step.
        self.automatic_optimization = False
|
    def construct_model(self):
        """Instantiate all submodules: acoustic encoder, VQ decoder/vocoder,
        GAN discriminators, frozen speaker/semantic models, and the
        projection layers between the streams."""
        # --- acoustic encoder ---
        enccfg = self.cfg.model.codec_encoder
        self.CodecEnc = CodecEncoder(
            ngf=enccfg.ngf,
            up_ratios=enccfg.up_ratios,
            dilations=enccfg.dilations,
            hidden_dim=enccfg['hidden_dim'],
            depth=enccfg['depth'],
            heads=enccfg['heads'],
            pos_meb_dim=enccfg['pos_meb_dim'],
        )

        # --- quantizer + vocoder decoder ---
        deccfg = self.cfg.model.codec_decoder
        self.generator = CodecDecoderVocos(
            hidden_dim=deccfg.hidden_dim,
            depth=deccfg.depth,
            heads=deccfg.heads,
            pos_meb_dim=deccfg.pos_meb_dim,
            # NOTE(review): hard-coded hop length — confirm it matches the
            # preprocessing frame rate rather than a config value.
            hop_length=960,
            vq_num_quantizers=deccfg.vq_num_quantizers,
            vq_dim=deccfg.vq_dim,
            vq_commit_weight=deccfg.vq_commit_weight,
            vq_weight_init=deccfg.vq_weight_init,
            vq_full_commit_loss=deccfg.vq_full_commit_loss,
            codebook_size=deccfg.codebook_size,
            codebook_dim=deccfg.codebook_dim ,
        )

        # --- multi-period waveform discriminator ---
        mpdcfg = self.cfg.model.mpd
        self.discriminator = HiFiGANMultiPeriodDiscriminator(
            periods=mpdcfg.periods,
            max_downsample_channels=mpdcfg.max_downsample_channels,
            channels=mpdcfg.channels,
            channel_increasing_factor=mpdcfg.channel_increasing_factor,
        )

        # --- spectral (STFT-domain) discriminator ---
        mstftcfg = self.cfg.model.mstft
        self.spec_discriminator = SpecDiscriminator(
            stft_params=mstftcfg.stft_params,
            in_channels=mstftcfg.in_channels,
            out_channels=mstftcfg.out_channels,
            kernel_sizes=mstftcfg.kernel_sizes,
            channels=mstftcfg.channels,
            max_downsample_channels=mstftcfg.max_downsample_channels,
            downsample_scales=mstftcfg.downsample_scales,
            use_weight_norm=mstftcfg.use_weight_norm,
        )

        # --- frozen speaker-verification model (used in validation_step
        # to compute a speaker-similarity metric) ---
        self.speaker_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus-sv")
        self.speaker_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv")
        self.speaker_model.eval()
        self.speaker_model.requires_grad_(False)

        # Projections: fused (semantic 1024 + acoustic 1024) -> vq_dim, and
        # vq_dim back to the decoder hidden size / semantic width.
        self.fc_prior = nn.Linear(1024 + 1024, deccfg.vq_dim, )
        self.fc_post_a = nn.Linear(deccfg.vq_dim, deccfg.hidden_dim )
        self.fc_post_s = nn.Linear(deccfg.vq_dim, 1024)

        # --- semantic branch: trainable encoder/decoder around a frozen
        # w2v-bert model (hidden states consumed in forward()) ---
        self.SemanticDecoder_module = SemanticDecoder(1024, 1024, 1024)
        self.SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024)
        self.semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", output_hidden_states=True)
        self.semantic_model.eval()
        self.semantic_model.requires_grad_(False)
|
| | def construct_criteria(self): |
| | cfg = self.cfg.train |
| | self.criteria = nn.ModuleDict() |
| | if cfg.use_mel_loss: |
| | self.criteria['mel_loss'] = MultiResolutionMelSpectrogramLoss(sample_rate=self.cfg.preprocess.audio.sr) |
| | if cfg.use_stft_loss: |
| | self.criteria['stft_loss'] = MultiResolutionSTFTLoss( |
| | fft_sizes=cfg.stft_loss_params.fft_sizes, |
| | hop_sizes=cfg.stft_loss_params.hop_sizes, |
| | win_sizes=cfg.stft_loss_params.win_lengths |
| | ) |
| | if cfg.use_feat_match_loss: |
| | self.criteria['fm_loss'] = nn.L1Loss() |
| | self.criteria['gan_loss'] = GANLoss() |
| | self.criteria['l1_loss'] = nn.L1Loss() |
| | self.criteria['l2_loss'] = nn.MSELoss() |
| | print(self.criteria) |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
    def forward(self, batch):
        """Training-time pass: encode audio, fuse semantic features, quantize,
        and decode back to a waveform.

        Args:
            batch: dict with 'wav' (waveform tensor; presumably (B, T_samples)
                — TODO confirm against the dataloader) and 'feats' (w2v-bert
                input features).

        Returns:
            dict with ground-truth / generated waveforms, VQ loss and codes,
            and the semantic reconstruction MSE.
        """
        wav = batch['wav']
        feats = batch['feats']

        # Acoustic encoder; add a channel dim -> (B, 1, T_samples) input.
        # Output is time-major (B, T_codec, C): shape[1] is used as the time
        # axis below and concatenation happens on dim=2.
        vq_emb = self.CodecEnc(wav.unsqueeze(1))

        # Frozen w2v-bert semantic target: hidden layer 16, no gradients.
        with torch.no_grad():
            semantic_target = self.semantic_model(feats)
            semantic_target = semantic_target.hidden_states[16].detach()

        # Codec and semantic streams can run at different frame rates.
        T_codec = vq_emb.shape[1]
        T_semantic = semantic_target.shape[1]

        # Keep an untouched copy at the original semantic rate for the loss.
        semantic_target_for_loss = semantic_target.clone()

        # Resample the semantic target to the codec frame rate.
        if T_codec != T_semantic:
            semantic_target = F.interpolate(
                semantic_target.transpose(1, 2),
                size=T_codec,
                mode='linear',
                align_corners=False
            ).transpose(1, 2)

        # SemanticEncoder works channel-first, hence the transposes around it.
        semantic_target_transposed = semantic_target.transpose(1, 2)
        semantic_target_processed = self.SemanticEncoder_module(semantic_target_transposed)
        semantic_target_processed = semantic_target_processed.transpose(1, 2)

        # Fuse semantic + acoustic streams and project to the VQ dimension.
        vq_emb = torch.cat([semantic_target_processed, vq_emb], dim=2)
        vq_emb = self.fc_prior(vq_emb)

        # Quantize: generator with vq=True runs only the quantization stage
        # on channel-first input.
        vq_emb = vq_emb.transpose(1, 2)
        vq_post_emb, vq_code, vq_loss = self.generator(vq_emb, vq=True)

        # Back to time-major for the linear projection heads.
        vq_post_emb_t = vq_post_emb.transpose(1, 2)

        # Semantic reconstruction head from the quantized latents.
        semantic_recon = self.fc_post_s(vq_post_emb_t)

        semantic_recon_transposed = semantic_recon.transpose(1, 2)
        semantic_recon = self.SemanticDecoder_module(semantic_recon_transposed)
        semantic_recon = semantic_recon.transpose(1, 2)

        # Resample the reconstruction back to the semantic rate for the loss.
        if T_codec != T_semantic:
            semantic_recon_for_loss = F.interpolate(
                semantic_recon.transpose(1, 2),
                size=T_semantic,
                mode='linear',
                align_corners=False
            ).transpose(1, 2)
        else:
            semantic_recon_for_loss = semantic_recon

        # Acoustic decode: project to the decoder hidden size, then vocode
        # (generator with vq=False returns a (waveform, aux) tuple).
        gen_input = self.fc_post_a(vq_post_emb_t)
        y_, _ = self.generator(gen_input.transpose(1, 2), vq=False)
        y = wav.unsqueeze(1)

        output = {
            'gt_wav': y,
            'gen_wav': y_,
            'vq_loss': vq_loss,
            'vq_code': vq_code,
            'semantic_recon_loss': F.mse_loss(semantic_recon_for_loss, semantic_target_for_loss),
        }
        return output
| | |
| | @torch.inference_mode() |
| | def inference(self, wav): |
| | vq_emb = self.CodecEnc(wav.unsqueeze(1)) |
| | vq_post_emb, vq_code, vq_loss = self.generator(vq_emb, vq=True) |
| | y_ = self.generator(vq_post_emb, vq=False).squeeze(1) |
| | return y_ |
| |
|
| | def compute_disc_loss(self, batch, output): |
| | y, y_ = output['gt_wav'], output['gen_wav'] |
| | y_ = y_.detach() |
| | p = self.discriminator(y) |
| | p_ = self.discriminator(y_) |
| |
|
| | real_loss_list, fake_loss_list = [], [] |
| | for i in range(len(p)): |
| | real_loss, fake_loss = self.criteria['gan_loss'].disc_loss(p[i][-1], p_[i][-1]) |
| | real_loss_list.append(real_loss) |
| | fake_loss_list.append(fake_loss) |
| |
|
| | if hasattr(self, 'spec_discriminator'): |
| | sd_p = self.spec_discriminator(y) |
| | sd_p_ = self.spec_discriminator(y_) |
| |
|
| | for i in range(len(sd_p)): |
| | real_loss, fake_loss = self.criteria['gan_loss'].disc_loss(sd_p[i][-1], sd_p_[i][-1]) |
| | real_loss_list.append(real_loss) |
| | fake_loss_list.append(fake_loss) |
| |
|
| | real_loss = sum(real_loss_list) |
| | fake_loss = sum(fake_loss_list) |
| |
|
| | disc_loss = real_loss + fake_loss |
| | disc_loss = self.cfg.train.lambdas.lambda_disc * disc_loss |
| |
|
| | output = { |
| | 'real_loss': real_loss, |
| | 'fake_loss': fake_loss, |
| | 'disc_loss': disc_loss, |
| | } |
| | return output |
| |
|
    def compute_gen_loss(self, batch, output):
        """Aggregate the generator-side losses: mel reconstruction,
        adversarial, feature matching, VQ, and semantic reconstruction.

        Discriminator gradients are disabled for the duration of this call so
        the adversarial terms only train the generator.
        """
        y, y_ = output['gt_wav'], output['gen_wav']
        vq_loss, vq_code = output['vq_loss'], output['vq_code']
        semantic_recon_loss = output['semantic_recon_loss']

        gen_loss = 0.0
        # Freeze discriminator params while the generator loss is built.
        self.set_discriminator_gradients(False)
        output_dict = {}
        cfg = self.cfg.train

        # Mel-spectrogram reconstruction loss.
        if cfg.use_mel_loss:
            mel_loss = self.criteria['mel_loss'](y_.squeeze(1), y.squeeze(1))
            gen_loss += mel_loss * cfg.lambdas.lambda_mel_loss
            output_dict['mel_loss'] = mel_loss

        # Adversarial loss over every discriminator head (last-layer scores).
        p_ = self.discriminator(y_)
        adv_loss_list = []
        for i in range(len(p_)):
            adv_loss_list.append(self.criteria['gan_loss'].gen_loss(p_[i][-1]))
        if hasattr(self, 'spec_discriminator'):
            sd_p_ = self.spec_discriminator(y_)
            for i in range(len(sd_p_)):
                adv_loss_list.append(self.criteria['gan_loss'].gen_loss(sd_p_[i][-1]))
        adv_loss = sum(adv_loss_list)
        gen_loss += adv_loss * cfg.lambdas.lambda_adv
        output_dict['adv_loss'] = adv_loss

        # Feature matching: L1 between intermediate discriminator activations
        # for generated audio and (gradient-free) real audio. The last entry
        # of each head's outputs is the score, hence `len(...) - 1`.
        if cfg.use_feat_match_loss:
            fm_loss = 0.0
            with torch.no_grad():
                p = self.discriminator(y)
            for i in range(len(p_)):
                for j in range(len(p_[i]) - 1):
                    fm_loss += self.criteria['fm_loss'](p_[i][j], p[i][j].detach())
            gen_loss += fm_loss * cfg.lambdas.lambda_feat_match_loss
            output_dict['fm_loss'] = fm_loss
            if hasattr(self, 'spec_discriminator'):
                spec_fm_loss = 0.0
                with torch.no_grad():
                    sd_p = self.spec_discriminator(y)
                for i in range(len(sd_p_)):
                    for j in range(len(sd_p_[i]) - 1):
                        spec_fm_loss += self.criteria['fm_loss'](sd_p_[i][j], sd_p[i][j].detach())
                gen_loss += spec_fm_loss * cfg.lambdas.lambda_feat_match_loss
                output_dict['spec_fm_loss'] = spec_fm_loss

        # Quantizer losses (a per-quantizer sequence; summed, unweighted here).
        if vq_loss is not None:
            vq_loss = sum(vq_loss)
            gen_loss += vq_loss
            output_dict['vq_loss'] = vq_loss

        # Semantic reconstruction loss computed in forward().
        output_dict['semantic_recon_loss'] = semantic_recon_loss
        gen_loss += output_dict['semantic_recon_loss'] * cfg.lambdas.lambda_semantic_loss

        # Re-enable discriminator grads for the next discriminator step.
        self.set_discriminator_gradients(True)
        output_dict['gen_loss'] = gen_loss
        return output_dict
| |
|
    def training_step(self, batch, batch_idx):
        """One manual-optimization GAN step: discriminator update first, then
        generator update, each with its own optimizer/scheduler and gradient
        clipping; both loss dicts are logged."""
        output = self(batch)

        # Optimizer/scheduler order matches configure_optimizers().
        gen_opt, disc_opt = self.optimizers()
        gen_sche, disc_sche = self.lr_schedulers()

        # --- discriminator update (generated audio is detached inside) ---
        disc_losses = self.compute_disc_loss(batch, output)
        disc_loss = disc_losses['disc_loss']
        disc_opt.zero_grad()
        self.manual_backward(disc_loss)
        self.clip_gradients(
            disc_opt,
            gradient_clip_val=self.cfg.train.disc_grad_clip,
            gradient_clip_algorithm='norm'
        )
        disc_opt.step()
        disc_sche.step()

        # --- generator update (discriminator grads are gated inside) ---
        gen_losses = self.compute_gen_loss(batch, output)
        gen_loss = gen_losses['gen_loss']
        gen_opt.zero_grad()
        self.manual_backward(gen_loss)
        self.clip_gradients(
            gen_opt,
            gradient_clip_val=self.cfg.train.gen_grad_clip,
            gradient_clip_algorithm='norm'
        )
        gen_opt.step()
        gen_sche.step()

        # Log both loss dicts per-step and epoch-aggregated.
        self.log_dict(
            disc_losses,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=self.cfg.dataset.train.batch_size,
            sync_dist=True
        )
        self.log_dict(
            gen_losses,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=self.cfg.dataset.train.batch_size,
            sync_dist=True
        )
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | def validation_step(self, batch, batch_idx): |
| | output = self(batch) |
| | y = output['gt_wav'] |
| | y_ = output['gen_wav'] |
| | |
| | |
| | y_audio = y.squeeze(1).cpu().numpy() |
| | y_recon_audio = y_.squeeze(1).cpu().numpy() |
| | |
| | embeddings1_list = [] |
| | embeddings2_list = [] |
| | |
| | |
| | for i in range(y_audio.shape[0]): |
| | |
| | y_16k = librosa.resample(y_audio[i], orig_sr=self.cfg.preprocess.audio.sr, target_sr=16000) |
| | y_recon_16k = librosa.resample(y_recon_audio[i], orig_sr=self.cfg.preprocess.audio.sr, target_sr=16000) |
| | |
| | |
| | inputs1 = self.speaker_feature_extractor( |
| | y_16k, |
| | sampling_rate=16000, |
| | return_tensors="pt" |
| | ).to(self.device) |
| | |
| | inputs2 = self.speaker_feature_extractor( |
| | y_recon_16k, |
| | sampling_rate=16000, |
| | return_tensors="pt" |
| | ).to(self.device) |
| | |
| | |
| | with torch.no_grad(): |
| | outputs1 = self.speaker_model(**inputs1) |
| | outputs2 = self.speaker_model(**inputs2) |
| | |
| | |
| | embedding1 = torch.mean(outputs1.last_hidden_state, dim=1) |
| | embedding2 = torch.mean(outputs2.last_hidden_state, dim=1) |
| | |
| | |
| | embedding1 = F.normalize(embedding1, p=2, dim=1) |
| | embedding2 = F.normalize(embedding2, p=2, dim=1) |
| | |
| | embeddings1_list.append(embedding1) |
| | embeddings2_list.append(embedding2) |
| | |
| | |
| | embeddings1 = torch.cat(embeddings1_list, dim=0) |
| | embeddings2 = torch.cat(embeddings2_list, dim=0) |
| | |
| | |
| | sim = F.cosine_similarity(embeddings1, embeddings2) |
| | sim = sim.mean() |
| | |
| | self.log('val/sim', sim, on_step=False, on_epoch=True, prog_bar=True, logger=True) |
| | |
| | return {'sim': sim} |
| |
|
| | |
| |
|
    def test_step(self, batch, batch_idx):
        """Intentionally a no-op; test-time evaluation is handled elsewhere."""
        pass
| |
|
| | def configure_optimizers(self): |
| | from itertools import chain |
| |
|
| | |
| | disc_params = self.discriminator.parameters() |
| | |
| | disc_params = chain(disc_params, self.spec_discriminator.parameters()) |
| |
|
| | |
| | gen_params = chain( |
| | self.CodecEnc.parameters(), |
| | self.generator.parameters(), |
| | |
| | self.fc_prior.parameters(), |
| | self.fc_post_a.parameters(), |
| | self.fc_post_s.parameters(), |
| | self.SemanticDecoder_module.parameters(), |
| | self.SemanticEncoder_module.parameters() |
| | ) |
| |
|
| | |
| | gen_opt = optim.AdamW(gen_params, **self.cfg.train.gen_optim_params) |
| | disc_opt = optim.AdamW(disc_params, **self.cfg.train.disc_optim_params) |
| |
|
| | |
| | gen_sche = WarmupLR(gen_opt, **self.cfg.train.gen_schedule_params) |
| | disc_sche = WarmupLR(disc_opt, **self.cfg.train.disc_schedule_params) |
| |
|
| | print(f'Generator optim: {gen_opt}') |
| | print(f'Discriminator optim: {disc_opt}') |
| |
|
| | return [gen_opt, disc_opt], [gen_sche, disc_sche] |
| |
|
| | def set_discriminator_gradients(self, flag=True): |
| | for p in self.discriminator.parameters(): |
| | p.requires_grad = flag |
| |
|
| | if hasattr(self, 'spec_discriminator'): |
| | for p in self.spec_discriminator.parameters(): |
| | p.requires_grad = flag |
| |
|