| | |
| | |
| | ''' |
| | @Project :Waveformer-main |
| | @File :CLAPSep.py |
| | @IDE :PyCharm |
| | @Author :Aisaka/Hao Ma @SDU |
| | @Date :2024/2/28 下午1:12 |
| | ''' |
| |
|
| | import torch |
| | import laion_clap |
| | from torchmetrics.audio.snr import( |
| | scale_invariant_signal_noise_ratio as si_snr, |
| | signal_noise_ratio as snr) |
| | from torchmetrics.audio.sdr import( |
| | signal_distortion_ratio as sdr, |
| | scale_invariant_signal_distortion_ratio as si_sdr) |
| | import copy |
| | import loralib as lora |
| | from torchlibrosa import ISTFT, STFT, SpecAugmentation |
| | from torchlibrosa.stft import magphase |
| | import librosa |
| | import pytorch_lightning as pl |
| |
|
| |
|
def loss_fn(pred, tgt):
    """Return the training objective for predicted vs. target waveforms.

    The objective is a weighted sum of the batch-mean SNR (weight 0.9) and the
    batch-mean scale-invariant SNR (weight 0.1), negated so that minimising
    the loss maximises both ratios.
    """
    snr_term = snr(pred, tgt).mean()
    si_snr_term = si_snr(pred, tgt).mean()
    return -0.9 * snr_term - 0.1 * si_snr_term
| |
|
| |
|
def set_module(model, submodule_key, module):
    """Rebind the attribute at dotted path ``submodule_key`` on ``model`` to ``module``.

    Every path component except the last is resolved with ``getattr``, starting
    from ``model``. The final component is then assigned on the resulting
    parent with ``setattr``. A missing intermediate attribute raises
    ``AttributeError``.
    """
    *parent_path, leaf_name = submodule_key.split('.')
    parent = model
    for attr_name in parent_path:
        parent = getattr(parent, attr_name)
    setattr(parent, leaf_name, module)
| |
|
| |
|
def process_model(model, rank):
    """Swap ``torch.nn.Linear`` layers inside WindowAttention blocks for LoRA layers.

    Every submodule whose type string contains ``'WindowAttention'`` is
    searched for ``torch.nn.Linear`` layers. Each one is replaced by a
    ``lora.Linear`` of rank ``rank`` (with ``merge_weights=False``). The new
    layer reuses the original ``weight`` Parameter object, and the original
    ``bias`` when one exists.

    Args:
        model: module tree to patch; it is modified in place.
        rank: LoRA rank ``r`` passed to ``lora.Linear``.

    Returns:
        The same ``model`` object, after patching.
    """
    # Collect all targets first so the module tree is not mutated while the
    # named_modules() generators are still walking it.
    targets = []
    for name, module in model.named_modules():
        if 'WindowAttention' not in str(type(module)):
            continue
        for sub_name, layer in module.named_modules():
            if isinstance(layer, torch.nn.Linear):
                targets.append((name + '.' + sub_name, layer))

    for key, layer in targets:
        # nn.Linear always has a ``bias`` attribute (it is None when the layer
        # was built with bias=False), so hasattr() cannot detect a real bias.
        has_bias = layer.bias is not None
        lora_layer = lora.Linear(layer.in_features, layer.out_features, r=rank,
                                 bias=has_bias, merge_weights=False)
        # Share the pretrained Parameters rather than copying them.
        lora_layer.weight = layer.weight
        if has_bias:
            lora_layer.bias = layer.bias
        set_module(model, key, lora_layer)
    return model
| |
|
| |
|
class LightningModule(pl.LightningModule):
    """Lightning wrapper that trains a CLAP-conditioned mask decoder.

    Pipeline as implemented here:

    * ``clap_model`` (all parameters frozen) produces text and audio query
      embeddings.
    * ``audio_branch`` is a deep copy of CLAP's audio encoder. Forward hooks
      (see ``install_forward_hooks``) collect its intermediate outputs into
      ``self.features``, after the mixture magnitude spectrogram.
    * ``decoder_model`` predicts a mask from those features and an
      L2-normalised concatenation of the positive/negative query embeddings.
    * ``wav_reconstruct`` applies the mask to the mixture STFT and inverts it
      back to a waveform.
    """

    def __init__(self, clap_model, decoder_model, lr, use_lora=False, rank=8, nfft=1024):
        """Build the module.

        Args:
            clap_model: CLAP wrapper exposing ``model.audio_branch``,
                ``get_text_embedding`` and ``get_audio_embedding_from_data``
                (presumably a ``laion_clap`` module; confirm). All of its
                parameters are frozen here.
            decoder_model: mask decoder. It must expose a ``phase`` attribute
                and accept ``hidden_state``, ``skip_features`` and ``embed``
                keyword arguments.
            lr: AdamW learning rate.
            use_lora: if True, wrap the Linear layers of the copied encoder's
                WindowAttention blocks with LoRA and mark only LoRA
                parameters as trainable (loralib ``bias='lora_only'``).
            rank: LoRA rank; only used when ``use_lora`` is True.
            nfft: FFT size and window length of the STFT/ISTFT pair.
        """
        super().__init__()
        self.phase = decoder_model.phase
        self.lr = lr
        self.clap_model = clap_model
        for param in self.clap_model.parameters():
            param.requires_grad = False
        # Copy the (already frozen) CLAP audio encoder to serve as the mixture
        # feature extractor. deepcopy also copies registered hooks, so this
        # must happen before install_forward_hooks() attaches the SpecAugment
        # hook to the original encoder.
        self.audio_branch = copy.deepcopy(self.clap_model.model.audio_branch)
        if use_lora:
            process_model(self.audio_branch, rank)
            lora.mark_only_lora_as_trainable(self.audio_branch, bias='lora_only')

        self.decoder_model = decoder_model
        self.stft = STFT(n_fft=nfft, hop_length=320, win_length=nfft, window='hann',
                         center=True, pad_mode='reflect', freeze_parameters=True)
        self.istft = ISTFT(n_fft=nfft, hop_length=320, win_length=nfft, window='hann',
                           center=True, pad_mode='reflect', freeze_parameters=True)
        # Shared list filled by the forward hooks. It is only ever cleared in
        # place (``del self.features[:]``) and never rebound, because the hook
        # closures hold a reference to this exact list object.
        self.features = self.install_forward_hooks()

    @staticmethod
    def _condition_embedding(embed_pos, embed_neg):
        """Concatenate positive/negative query embeddings along the last dim and L2-normalise.

        Used by training, validation and test, so the decoder always receives
        conditioning with the same scale it was trained on.
        """
        return torch.nn.functional.normalize(torch.concat([embed_pos, embed_neg], dim=-1), dim=-1)

    def _extract_features(self, mag, mixed_resample):
        """Refill ``self.features`` for one forward pass.

        The list is cleared and ``mag`` is appended first. Running
        ``audio_branch`` then appends the hooked patch-embedding and
        per-layer outputs.
        """
        del self.features[:]
        self.features.append(mag)
        self.audio_branch({"waveform": mixed_resample})

    def _separate(self, embed, mag, cos, sin, length):
        """Decode a mask from the captured features and rebuild the waveform.

        The last captured feature is the decoder's ``hidden_state``; all
        earlier entries (starting with ``mag``) are its ``skip_features``.
        """
        mask = self.decoder_model(hidden_state=self.features[-1],
                                  skip_features=self.features[:-1], embed=embed)
        return self.wav_reconstruct(mask, mag, cos, sin, length=length)

    def training_step(self, batch, batch_idx):
        """Run one training step with randomly mixed audio/text queries.

        The batch is unpacked as ``(mixed, mixed_resample, pos_cap, neg_cap,
        gt, pos_sample, neg_sample)``.

        Query embeddings are built as ``w * audio + (1 - w) * text``, where
        ``w`` is a single uniform random weight per step. The audio part comes
        from CLAP embeddings of the positive/negative samples; the text part
        from CLAP embeddings of the captions.

        The query mode is then drawn at random:
        * positive only (25%),
        * negative only (25%),
        * both (50%).
        """
        # Presumably needed because Lightning calls .train() on the whole
        # module; keep the frozen CLAP model and the encoder copy in eval mode.
        self.clap_model.eval()
        self.audio_branch.eval()

        mixed, mixed_resample, pos_cap, neg_cap, gt, pos_sample, neg_sample = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            mix_weight = torch.rand((1,)).type_as(gt)
            embed_pos_a, embed_neg_a = torch.chunk(
                self.clap_model.get_audio_embedding_from_data(
                    torch.concat([pos_sample, neg_sample], dim=0), use_tensor=True),
                dim=0, chunks=2)
            embed_pos_t, embed_neg_t = torch.chunk(
                self.clap_model.get_text_embedding(pos_cap + neg_cap, use_tensor=True),
                dim=0, chunks=2)
            embed_pos = mix_weight * embed_pos_a + (1 - mix_weight) * embed_pos_t
            embed_neg = mix_weight * embed_neg_a + (1 - mix_weight) * embed_neg_t

        # Outside no_grad so LoRA parameters in audio_branch (if any) get gradients.
        self._extract_features(mag, mixed_resample)

        draw = torch.rand((1,))
        if draw < 0.25:
            embed_p, embed_n = embed_pos, torch.zeros_like(embed_pos)
        elif draw < 0.5:
            embed_p, embed_n = torch.zeros_like(embed_neg), embed_neg
        else:
            embed_p, embed_n = embed_pos, embed_neg
        loss = self.cal_loss(embed_p, embed_n, mag, cos, sin, length=mixed.size(-1), gt=gt)

        # detach() instead of item(): avoids forcing a device sync every step.
        self.log("train_loss", loss.detach(), on_epoch=True, prog_bar=True, sync_dist=True,
                 batch_size=len(mixed))
        del self.features[:]
        return loss

    def cal_loss(self, embed_p, embed_n, mag, cos, sin, length, gt):
        """Separate with the given query embeddings and return ``loss_fn(pred, gt)``.

        Expects ``self.features`` to have been filled for the current mixture.
        """
        embed = self._condition_embedding(embed_p, embed_n)
        pred = self._separate(embed, mag, cos, sin, length=length)
        return loss_fn(pred, gt)

    def wav_reconstruct(self, mask, mag_x, cos_x, sin_x, length):
        """Apply the predicted mask to the mixture spectrogram and invert it.

        When ``self.phase`` is set, ``mask`` must have three entries:
        * ``mask[0]`` scales the magnitude;
        * ``(mask[1], mask[2])`` are turned into a phase-correction
          ``(cos, sin)`` pair via ``magphase``, which is added to the mixture
          phase using the angle-addition formulas.

        Otherwise ``mask`` is a magnitude-only mask and the mixture phase is
        reused. In both cases the masked magnitude is clamped at zero with
        ReLU. The result is passed through ISTFT and trimmed or padded to
        ``length``.
        """
        if self.phase:
            mag_y = torch.nn.functional.relu_(mag_x * mask[0])
            _, mask_cos, mask_sin = magphase(mask[1], mask[2])
            cos_y = cos_x * mask_cos - sin_x * mask_sin
            sin_y = sin_x * mask_cos + cos_x * mask_sin
        else:
            mag_y = torch.nn.functional.relu_(mag_x * mask)
            cos_y = cos_x
            sin_y = sin_x
        return self.istft(mag_y * cos_y, mag_y * sin_y, length=length)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss``: SI-SNR improvement of the separated output over the mixture.

        Batch layout: ``(mixed, mixed_resample, label, neg_label, gt, _, _)``.
        Only text queries are used. Higher ``val_loss`` is better, which is
        why the LR scheduler monitors it in ``mode='max'``.
        """
        mixed, mixed_resample, label, neg_label, gt, _, _ = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            embed_pos = self.clap_model.get_text_embedding(label, use_tensor=True)
            embed_neg = self.clap_model.get_text_embedding(neg_label, use_tensor=True)
        embed = self._condition_embedding(embed_pos, embed_neg)
        self._extract_features(mag, mixed_resample)
        pred = self._separate(embed, mag, cos, sin, length=mixed.size(-1))
        loss = si_snr(pred, gt).mean() - si_snr(mixed, gt).mean()
        del self.features[:]
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=len(mixed))
        return {"val_loss": loss}

    def on_test_start(self) -> None:
        """Reset the per-example metric accumulators (CPU tensors)."""
        self.sdr_vals = torch.tensor([])
        self.sdri_vals = torch.tensor([])
        self.sisdr_vals = torch.tensor([])
        self.sisdri_vals = torch.tensor([])

    def test_step(self, batch, batch_idx):
        """Accumulate SDR / SI-SDR and their improvements over the mixture.

        Batch layout: ``(mixed, mixed_resample, label, neg_label, gt)``.
        Positive and negative text queries are embedded in one call on
        ``label + neg_label`` (presumably lists of strings; confirm), then
        split back into halves.
        """
        mixed, mixed_resample, label, neg_label, gt = batch
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        with torch.no_grad():
            embed_pos, embed_neg = torch.chunk(
                self.clap_model.get_text_embedding(label + neg_label, use_tensor=True),
                chunks=2, dim=0)
        embed = self._condition_embedding(embed_pos, embed_neg)
        self._extract_features(mag, mixed_resample)
        pred = self._separate(embed, mag, cos, sin, length=mixed.size(-1))

        sisdr = si_sdr(pred, gt).cpu()
        self.sisdr_vals = torch.concat([self.sisdr_vals, sisdr])
        self.sisdri_vals = torch.concat([self.sisdri_vals, sisdr - si_sdr(mixed, gt).cpu()])
        sdr_ = sdr(pred, gt).cpu()
        self.sdr_vals = torch.concat([self.sdr_vals, sdr_])
        self.sdri_vals = torch.concat([self.sdri_vals, sdr_ - sdr(mixed, gt).cpu()])
        del self.features[:]

    def on_test_end(self) -> None:
        """Print the mean and standard deviation of every accumulated metric."""
        for metric_name, values in (("SDR", self.sdr_vals), ("SDRi", self.sdri_vals),
                                    ("SISDR", self.sisdr_vals), ("SISDRi", self.sisdri_vals)):
            print(f"{metric_name}-mean: {torch.mean(values).cpu().numpy():.4f}, "
                  f"{metric_name}-std: {torch.std(values).cpu().numpy():.4f}")

    def configure_optimizers(self):
        """Set up AdamW plus ReduceLROnPlateau on ``val_loss``.

        ``mode='max'`` is used because ``val_loss`` is an SI-SNR improvement,
        so higher is better.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=5,
                                                               verbose=True, min_lr=5e-6)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val_loss"
            },
        }

    def install_forward_hooks(self):
        """Register the forward hooks and return the list they append to.

        * ``clap_model.model.audio_branch.bn0``: runs SpecAugmentation on its
          output, transposing dims 1 and 3 around the augmenter. Only the
          original CLAP encoder gets this hook.
        * ``audio_branch.spectrogram_extractor``: pads dim 2 of its output up
          to 1024 (a larger size yields negative padding, i.e. cropping).
        * ``audio_branch.patch_embed``: appends its output to the list.
        * each module in ``audio_branch.layers``: appends ``output[0]``.
        """
        features = []
        # NOTE(review): spec_augmenter is a local, not a registered submodule,
        # so self.train()/self.eval() never change its mode; it stays in its
        # default mode. Confirm augmentation is meant to apply whenever the
        # CLAP audio encoder runs.
        spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
                                          freq_drop_width=8, freq_stripes_num=2)

        def get_features_list(module, inputs, output):
            features.append(output)

        def get_features_list_basic_layer(module, inputs, output):
            features.append(output[0])

        def spec_augmentation_hook(module, inputs, out):
            out = out.transpose(1, 3)
            out = spec_augmenter(out)
            return out.transpose(1, 3)

        def spectrogram_padding(module, inputs, out):
            return torch.nn.functional.pad(out, (0, 0, 0, 1024 - out.size(2)))

        self.clap_model.model.audio_branch.bn0.register_forward_hook(spec_augmentation_hook)
        self.audio_branch.spectrogram_extractor.register_forward_hook(spectrogram_padding)
        self.audio_branch.patch_embed.register_forward_hook(get_features_list)
        for layer in self.audio_branch.layers:
            layer.register_forward_hook(get_features_list_basic_layer)
        return features
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|