import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
import os
import matplotlib.pyplot as plt
from Deep_ANC_model_trim import CRN
import logging
from Pre_processing import Preprocessing
import random

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def custom_loss_function(output, target):
    """MSE between spectrograms, trimming both to a common time length first."""
    if output.size() != target.size():
        # Spectrograms are [batch, channel, time, freq]; the model can emit a
        # slightly different number of frames, so trim along the time axis.
        min_size = min(output.size(2), target.size(2))
        output = output[:, :, :min_size, :]
        target = target[:, :, :min_size, :]
    return torch.mean((output - target) ** 2)


class NoisySpeechDataset(Dataset):
    """Paired noisy/clean spectrograms, precomputed and saved as .pt files."""

    def __init__(self, noisy_dir, clean_dir, subset_size=50000, shuffle=True, seed=0):
        self.noisy_files = sorted([os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.pt')])
        self.clean_files = sorted([os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.endswith('.pt')])
        assert len(self.noisy_files) == len(self.clean_files), "Mismatched noisy and clean datasets"

        # Shuffle the pairs together, with a fixed seed so that every DDP
        # process draws the same subset below.
        if shuffle:
            combined = list(zip(self.noisy_files, self.clean_files))
            random.Random(seed).shuffle(combined)
            self.noisy_files, self.clean_files = zip(*combined)

        subset_size = min(subset_size, len(self.noisy_files))
        self.noisy_files = self.noisy_files[:subset_size]
        self.clean_files = self.clean_files[:subset_size]

    def __len__(self):
        return len(self.noisy_files)

    def __getitem__(self, idx):
        noisy_spectrogram = torch.load(self.noisy_files[idx], weights_only=True)
        clean_spectrogram = torch.load(self.clean_files[idx], weights_only=True)
        return noisy_spectrogram, clean_spectrogram


def snr_improvement(noisy, clean, enhanced):
    """SNR improvement in dB.

    Since the clean-signal power cancels in the ratio, this reduces to
    10 * log10(input noise power / residual noise power).
    """
    min_size = min(noisy.size(2), clean.size(2), enhanced.size(2))
    noisy = noisy[:, :, :min_size, :]
    clean = clean[:, :, :min_size, :]
    enhanced = enhanced[:, :, :min_size, :]

    noise = noisy - clean          # noise present in the input
    noise_est = enhanced - clean   # residual noise left by the model

    noise_power = torch.mean(noise ** 2)
    noise_est_power = torch.mean(noise_est ** 2)

    if noise_power == 0 or noise_est_power == 0:
        return torch.tensor(0.0)

    snr_before = torch.mean(clean ** 2) / noise_power
    snr_after = torch.mean(clean ** 2) / noise_est_power

    return 10 * torch.log10(snr_after / snr_before)


def plot_metrics(train_metrics, val_metrics, metric_name):
    """Plot per-epoch training vs. validation curves (called from rank 0 only)."""
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo', label=f'Training {metric_name}')
    plt.plot(epochs, val_metrics, 'b', label=f'Validation {metric_name}')
    plt.title(f'Training and Validation {metric_name}')
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.legend()
    # plt.show() needs a display; on a headless server, plt.savefig(...) may be
    # preferable.
    plt.show()


def train_model(rank, world_size, model, train_loader, val_loader, num_epochs, learning_rate, save_path, best_save_path, checkpoint_path=None):
    try:
        # Anomaly detection helps localize NaN/Inf gradients but slows
        # training; disable it once the model is stable.
        torch.autograd.set_detect_anomaly(True)

        torch.cuda.set_device(rank)
        model = model.to(rank)
        model = DDP(model, device_ids=[rank])

        # Use the learning rate passed in rather than a hard-coded value.
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=True)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                         factor=0.1, patience=20)

        # Gradient scaler for mixed precision; without it, fp16 gradients from
        # the autocast regions below can underflow.
        scaler = torch.amp.GradScaler('cuda')

        start_epoch = 0
        best_val_loss = float('inf')
        best_val_snr_improvement = float('-inf')

        if checkpoint_path and os.path.exists(checkpoint_path):
            try:
                checkpoint = torch.load(checkpoint_path, map_location=torch.device(f'cuda:{rank}'))
                logger.info(f"Checkpoint keys: {checkpoint.keys()}")

                # Checkpoints written below store the weights under
                # 'model_state_dict'; fall back to treating the whole file as a
                # bare state dict. Load into model.module, i.e. the unwrapped
                # model, so the keys carry no 'module.' prefix.
                if 'model_state_dict' in checkpoint:
                    model.module.load_state_dict(checkpoint['model_state_dict'])
                else:
                    model.module.load_state_dict(checkpoint)
                logger.info("Model state loaded from checkpoint.")

                if 'optimizer_state_dict' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                if 'scheduler_state_dict' in checkpoint:
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                start_epoch = checkpoint.get('epoch', 0) + 1
                best_val_loss = checkpoint.get('best_val_loss', float('inf'))
                best_val_snr_improvement = checkpoint.get('best_val_snr_improvement', float('-inf'))
                logger.info(f"Resuming training from epoch {start_epoch}")
            except Exception as e:
                logger.error(f"Error loading checkpoint: {e}")
                raise

        model.train()
        training_snr_improvements = []
        validation_snr_improvements = []

        for epoch in range(start_epoch, start_epoch + num_epochs):
            # Re-seed the distributed sampler so each epoch shuffles the shards.
            train_loader.sampler.set_epoch(epoch)

            running_loss = 0.0
            total_samples = 0
            batch_snr_improvements = []

            for i, (noisy_spectrogram, clean_spectrogram) in enumerate(train_loader):
                noisy_spectrogram = noisy_spectrogram.cuda(rank, non_blocking=True)
                clean_spectrogram = clean_spectrogram.cuda(rank, non_blocking=True)

                optimizer.zero_grad()

                with torch.amp.autocast(device_type='cuda'):
                    output = model(noisy_spectrogram)
                    loss = custom_loss_function(output, clean_spectrogram)

                if torch.isnan(loss).any() or torch.isinf(loss).any():
                    logger.warning(f"NaN or Inf detected in loss at iteration {i}, epoch {epoch}")
                    continue

                # Scale the loss, unscale before clipping, then step through
                # the scaler so skipped (inf) steps are handled correctly.
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                running_loss += loss.item()

                # Per-sample SNR improvement, averaged over the batch.
                batch_snr_improvement = 0.0
                for j in range(noisy_spectrogram.size(0)):
                    single_snr_improvement = snr_improvement(
                        noisy_spectrogram[j:j+1], clean_spectrogram[j:j+1], output[j:j+1]
                    ).item()
                    batch_snr_improvement += single_snr_improvement

                batch_snr_improvement /= noisy_spectrogram.size(0)
                batch_snr_improvements.append(batch_snr_improvement)
                total_samples += noisy_spectrogram.size(0)

            training_snr_improvement_avg = sum(batch_snr_improvements) / len(batch_snr_improvements)
            training_snr_improvements.append(training_snr_improvement_avg)

            print(f"Epoch {epoch+1}, Training Loss: {running_loss / len(train_loader)}, "
                  f"Training SNR Improvement: {training_snr_improvement_avg}")
            print(f"Epoch {epoch+1}, Total Samples Processed: {total_samples}")

            model.eval()
            val_loss = 0.0
            val_snr_improvement = 0.0
            with torch.no_grad():
                for noisy_spectrogram, clean_spectrogram in val_loader:
                    noisy_spectrogram = noisy_spectrogram.cuda(rank, non_blocking=True)
                    clean_spectrogram = clean_spectrogram.cuda(rank, non_blocking=True)
                    with torch.amp.autocast(device_type='cuda'):
                        output = model(noisy_spectrogram)
                        loss = custom_loss_function(output, clean_spectrogram)

                    val_loss += loss.item()
                    val_snr_improvement += snr_improvement(noisy_spectrogram, clean_spectrogram, output).item()

            val_loss /= len(val_loader)
            val_snr_improvement /= len(val_loader)

            # Average the validation metrics across ranks so every process
            # steps the LR scheduler with the same value.
            metrics = torch.tensor([val_loss, val_snr_improvement], device=f'cuda:{rank}')
            dist.all_reduce(metrics, op=dist.ReduceOp.AVG)
            val_loss, val_snr_improvement = metrics[0].item(), metrics[1].item()

            validation_snr_improvements.append(val_snr_improvement)

            print(f"Epoch {epoch+1}, Validation Loss: {val_loss}, Validation SNR Improvement: {val_snr_improvement}")
            model.train()

            if rank == 0:
                # Periodic full checkpoint (model/optimizer/scheduler) in the
                # format the resume logic above expects.
                if (epoch + 1) % 50 == 0:
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'best_val_loss': best_val_loss,
                        'best_val_snr_improvement': best_val_snr_improvement,
                    }, save_path)
                    print(f"Checkpoint saved at epoch {epoch+1}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # Save the unwrapped weights so the file loads without DDP.
                    torch.save(model.module.state_dict(), best_save_path)
                    print(f"Best model saved at epoch {epoch+1} with validation loss {best_val_loss}")

                if val_snr_improvement > best_val_snr_improvement:
                    best_val_snr_improvement = val_snr_improvement

            scheduler.step(val_loss)

        if rank == 0:
            print(f"Training complete for batch size {train_loader.batch_size}, learning rate {learning_rate}, epochs {num_epochs}")
            print(f"Best Validation Loss: {best_val_loss}, Best Validation SNR Improvement: {best_val_snr_improvement}")
            plot_metrics(training_snr_improvements, validation_snr_improvements, 'SNR Improvement')

    except Exception as e:
        logger.error(f"Rank {rank} encountered an error: {e}")
        raise
    finally:
        # main_worker owns process-group teardown, so only synchronize here to
        # avoid destroying the process group twice.
        torch.cuda.synchronize()


def setup(rank, world_size):
    logger.info(f"Setting up distributed training on rank {rank}")
    # Rendezvous defaults for single-node runs; the address and port are
    # placeholders and are used only if the launcher has not already set them.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    try:
        dist.destroy_process_group()
    except Exception as e:
        print(f"Error during cleanup: {e}")


def main_worker(rank, world_size, noisy_dir, clean_dir, save_dir, num_epochs, learning_rate, batch_size, checkpoint_path):
    try:
        setup(rank, world_size)

        # Spectrogram .pt files are read from the 'noisy'/'clean' subfolders
        # of save_dir.
        dataset = NoisySpeechDataset(os.path.join(save_dir, 'noisy'), os.path.join(save_dir, 'clean'), subset_size=50000)

        # Seed the split so every rank partitions the dataset identically.
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, pin_memory=True)

        model = CRN()

        save_path = f"/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_trim_bs{batch_size}_lr{learning_rate}_ep{num_epochs}_og_trial.pth"
        best_save_path = f"/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_best_bs{batch_size}_lr{learning_rate}_ep{num_epochs}_og_trial.pth"

        train_model(rank, world_size, model, train_loader, val_loader, num_epochs, learning_rate, save_path, best_save_path, checkpoint_path)
    except Exception as e:
        logger.error(f"An error occurred on rank {rank}: {e}")
    finally:
        cleanup()
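

# Entry-point sketch (an assumption, not part of the original file): launches
# one process per visible GPU via mp.spawn, each running main_worker. The
# directory paths and hyperparameters below are placeholders; substitute the
# real values used in this project.
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(
        main_worker,
        args=(
            world_size,
            "/path/to/noisy",   # noisy_dir (placeholder)
            "/path/to/clean",   # clean_dir (placeholder)
            "/path/to/save",    # save_dir with 'noisy'/'clean' .pt subfolders (placeholder)
            100,                # num_epochs (placeholder)
            1e-3,               # learning_rate (placeholder)
            16,                 # batch_size (placeholder)
            None,               # checkpoint_path; point at a .pth file to resume
        ),
        nprocs=world_size,
        join=True,
    )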