| """ |
| PlainMLP vs ResMLP Comparison on Distant Identity Task |
| |
| This experiment demonstrates the vanishing gradient problem in deep networks |
| and how residual connections solve it. |
| """ |
|
|
import json
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
|
|
| |
# Fix RNG seeds for both torch and numpy so every run is reproducible.
torch.manual_seed(42)
np.random.seed(42)
|
|
| |
# Experiment hyperparameters (shared by both models).
NUM_LAYERS = 20        # network depth; deep enough to expose vanishing gradients
HIDDEN_DIM = 64        # layer width; input dim == hidden dim == output dim
NUM_SAMPLES = 1024     # size of the synthetic identity dataset
TRAINING_STEPS = 500   # number of optimizer updates per model
LEARNING_RATE = 1e-3   # Adam learning rate
BATCH_SIZE = 64        # mini-batch size sampled per step


print(f"[Config] Layers: {NUM_LAYERS}, Hidden Dim: {HIDDEN_DIM}")
print(f"[Config] Samples: {NUM_SAMPLES}, Steps: {TRAINING_STEPS}, LR: {LEARNING_RATE}")
|
|
|
|
class PlainMLP(nn.Module):
    """A vanilla feed-forward stack: each layer computes ReLU(Linear(x)).

    There are no skip connections, so the signal (and its gradient) must
    pass through every weight matrix and ReLU in sequence.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            linear = nn.Linear(dim, dim)
            # He (Kaiming) initialization, the standard choice for ReLU nets.
            nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(linear.bias)
            self.layers.append(linear)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for linear in self.layers:
            out = self.activation(linear(out))
        return out
|
|
|
|
class ResMLP(nn.Module):
    """A residual feed-forward stack: each layer computes x + ReLU(Linear(x)).

    The identity skip connection gives gradients a direct path back to
    early layers, the key difference from PlainMLP.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            linear = nn.Linear(dim, dim)
            # He (Kaiming) initialization, the standard choice for ReLU nets.
            nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(linear.bias)
            self.layers.append(linear)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for linear in self.layers:
            out = out + self.activation(linear(out))
        return out
|
|
|
|
def generate_identity_data(num_samples: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build an identity-mapping dataset: the target equals the input.

    Inputs are drawn i.i.d. from U(-1, 1). Because the ideal output is the
    input itself, the task isolates how well each architecture propagates
    a signal through many layers.
    """
    inputs = torch.empty(num_samples, dim).uniform_(-1, 1)
    targets = inputs.clone()
    return inputs, targets
|
|
|
|
def train_model(model: nn.Module, X: torch.Tensor, Y: torch.Tensor,
                steps: int, lr: float, batch_size: int) -> List[float]:
    """Run `steps` Adam updates of `model` on (X, Y); return per-step losses.

    Mini-batches are drawn uniformly with replacement from the dataset,
    and the MSE loss of every step is recorded.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    history: List[float] = []

    total = X.shape[0]

    for step in range(steps):
        # Sample a random mini-batch (with replacement).
        idx = torch.randint(0, total, (batch_size,))
        xb = X[idx]
        yb = Y[idx]

        optimizer.zero_grad()
        loss = criterion(model(xb), yb)

        loss.backward()
        optimizer.step()

        history.append(loss.item())

        # Periodic progress report.
        if step % 100 == 0:
            print(f" Step {step}/{steps}, Loss: {loss.item():.6f}")

    return history
|
|
|
|
class ActivationGradientHook:
    """Hook to capture activations and gradients at each layer"""

    def __init__(self):
        # Layer outputs, appended in forward-pass order (first layer first).
        self.activations: List[torch.Tensor] = []
        # Gradients w.r.t. layer outputs, appended in backward-pass order
        # (last layer first) as autograd invokes the backward hooks.
        self.gradients: List[torch.Tensor] = []
        # Hook handles kept so the hooks can be removed later.
        self.handles = []

    def register_hooks(self, model: nn.Module):
        """Register forward and backward hooks on each layer"""
        # NOTE(review): assumes `model` exposes an iterable `.layers` of
        # submodules — true for PlainMLP/ResMLP in this file.
        for layer in model.layers:
            # Fires after the layer's forward pass with its output tensor.
            handle_fwd = layer.register_forward_hook(self._forward_hook)
            # Fires during backward with the gradient w.r.t. the layer's output.
            handle_bwd = layer.register_full_backward_hook(self._backward_hook)
            self.handles.extend([handle_fwd, handle_bwd])

    def _forward_hook(self, module, input, output):
        # detach().clone() so stored tensors don't keep the autograd graph alive.
        self.activations.append(output.detach().clone())

    def _backward_hook(self, module, grad_input, grad_output):
        # grad_output[0] is dLoss/d(layer output).
        self.gradients.append(grad_output[0].detach().clone())

    def clear(self):
        # Drop captured tensors; registered hooks stay attached.
        self.activations = []
        self.gradients = []

    def remove_hooks(self):
        # Detach every registered hook from the model.
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def get_activation_stats(self) -> Tuple[List[float], List[float]]:
        """Get mean and std of activations for each layer"""
        means = [act.mean().item() for act in self.activations]
        stds = [act.std().item() for act in self.activations]
        return means, stds

    def get_gradient_norms(self) -> List[float]:
        """Get L2 norm of gradients for each layer"""
        # Backward hooks fire last layer first, so reverse to report norms
        # from the first layer to the last (matching get_activation_stats).
        norms = [grad.norm(2).item() for grad in reversed(self.gradients)]
        return norms
|
|
|
|
def analyze_final_state(model: nn.Module, dim: int, batch_size: int = 64) -> Dict:
    """Perform forward/backward pass and capture activation/gradient stats"""
    recorder = ActivationGradientHook()
    recorder.register_hooks(model)

    # Fresh probe batch drawn from the same identity-task distribution.
    X_test = torch.empty(batch_size, dim).uniform_(-1, 1)
    Y_test = X_test.clone()

    # One forward pass (hooks record activations)...
    model.zero_grad()
    output = model(X_test)
    loss = nn.MSELoss()(output, Y_test)

    # ...and one backward pass (hooks record gradients).
    loss.backward()

    act_means, act_stds = recorder.get_activation_stats()
    grad_norms = recorder.get_gradient_norms()

    # Detach hooks so the model is left clean for further use.
    recorder.remove_hooks()

    stats = {
        'activation_means': act_means,
        'activation_stds': act_stds,
        'gradient_norms': grad_norms,
        'final_loss': loss.item()
    }
    return stats
|
|
|
|
def plot_training_loss(plain_losses: List[float], res_losses: List[float], save_path: str):
    """Plot training loss curves for both models and save to `save_path`.

    Args:
        plain_losses: Per-step MSE losses for PlainMLP.
        res_losses: Per-step MSE losses for ResMLP.
        save_path: Destination PNG path. The parent directory is created
            if missing (previously savefig raised FileNotFoundError when
            'plots/' did not exist).
    """
    plt.figure(figsize=(10, 6))
    steps = range(len(plain_losses))

    plt.plot(steps, plain_losses, label='PlainMLP', color='red', alpha=0.8)
    plt.plot(steps, res_losses, label='ResMLP', color='blue', alpha=0.8)

    plt.xlabel('Training Steps', fontsize=12)
    plt.ylabel('MSE Loss', fontsize=12)
    plt.title('Training Loss: PlainMLP vs ResMLP on Identity Task', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    # Log scale keeps late-training differences between the curves visible.
    plt.yscale('log')

    plt.tight_layout()
    # Ensure the output directory exists before writing the figure.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved training loss plot to {save_path}")
|
|
|
|
def plot_gradient_magnitudes(plain_grads: List[float], res_grads: List[float], save_path: str):
    """Plot gradient magnitude vs layer depth and save to `save_path`.

    Args:
        plain_grads: Per-layer gradient L2 norms for PlainMLP (layer 1 first).
        res_grads: Per-layer gradient L2 norms for ResMLP (layer 1 first).
        save_path: Destination PNG path. The parent directory is created
            if missing (previously savefig raised FileNotFoundError when
            'plots/' did not exist).
    """
    plt.figure(figsize=(10, 6))
    layers = range(1, len(plain_grads) + 1)

    plt.plot(layers, plain_grads, 'o-', label='PlainMLP', color='red', markersize=6)
    plt.plot(layers, res_grads, 's-', label='ResMLP', color='blue', markersize=6)

    plt.xlabel('Layer Depth', fontsize=12)
    plt.ylabel('Gradient L2 Norm', fontsize=12)
    plt.title('Gradient Magnitude vs Layer Depth (After Training)', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    # Log scale: vanishing gradients span many orders of magnitude.
    plt.yscale('log')

    plt.tight_layout()
    # Ensure the output directory exists before writing the figure.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved gradient magnitude plot to {save_path}")
|
|
|
|
def plot_activation_means(plain_means: List[float], res_means: List[float], save_path: str):
    """Plot activation mean vs layer depth and save to `save_path`.

    Args:
        plain_means: Per-layer activation means for PlainMLP (layer 1 first).
        res_means: Per-layer activation means for ResMLP (layer 1 first).
        save_path: Destination PNG path. The parent directory is created
            if missing (previously savefig raised FileNotFoundError when
            'plots/' did not exist).
    """
    plt.figure(figsize=(10, 6))
    layers = range(1, len(plain_means) + 1)

    plt.plot(layers, plain_means, 'o-', label='PlainMLP', color='red', markersize=6)
    plt.plot(layers, res_means, 's-', label='ResMLP', color='blue', markersize=6)

    plt.xlabel('Layer Depth', fontsize=12)
    plt.ylabel('Activation Mean', fontsize=12)
    plt.title('Activation Mean vs Layer Depth (After Training)', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    # Ensure the output directory exists before writing the figure.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved activation mean plot to {save_path}")
|
|
|
|
def plot_activation_stds(plain_stds: List[float], res_stds: List[float], save_path: str):
    """Plot activation standard deviation vs layer depth and save to `save_path`.

    Args:
        plain_stds: Per-layer activation stds for PlainMLP (layer 1 first).
        res_stds: Per-layer activation stds for ResMLP (layer 1 first).
        save_path: Destination PNG path. The parent directory is created
            if missing (previously savefig raised FileNotFoundError when
            'plots/' did not exist).
    """
    plt.figure(figsize=(10, 6))
    layers = range(1, len(plain_stds) + 1)

    plt.plot(layers, plain_stds, 'o-', label='PlainMLP', color='red', markersize=6)
    plt.plot(layers, res_stds, 's-', label='ResMLP', color='blue', markersize=6)

    plt.xlabel('Layer Depth', fontsize=12)
    plt.ylabel('Activation Std', fontsize=12)
    plt.title('Activation Standard Deviation vs Layer Depth (After Training)', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    # Ensure the output directory exists before writing the figure.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved activation std plot to {save_path}")
|
|
|
|
def main():
    """Run the full PlainMLP-vs-ResMLP experiment end to end.

    Generates the identity dataset, trains both models with identical
    hyperparameters, captures per-layer activation/gradient statistics,
    writes four diagnostic plots under 'plots/', and dumps a summary to
    results.json. Returns the results dict.
    """
    print("=" * 60)
    print("PlainMLP vs ResMLP: Distant Identity Task Experiment")
    print("=" * 60)

    # [1] Data: targets are identical to inputs (identity task).
    print("\n[1] Generating synthetic identity data...")
    X, Y = generate_identity_data(NUM_SAMPLES, HIDDEN_DIM)
    print(f" Data shape: X={X.shape}, Y={Y.shape}")
    print(f" X range: [{X.min():.3f}, {X.max():.3f}]")

    # [2] Models: same depth/width; only the skip connection differs.
    print("\n[2] Initializing models...")
    plain_mlp = PlainMLP(HIDDEN_DIM, NUM_LAYERS)
    res_mlp = ResMLP(HIDDEN_DIM, NUM_LAYERS)

    plain_params = sum(p.numel() for p in plain_mlp.parameters())
    res_params = sum(p.numel() for p in res_mlp.parameters())
    print(f" PlainMLP parameters: {plain_params:,}")
    print(f" ResMLP parameters: {res_params:,}")

    # [3]-[4] Train both models with identical hyperparameters.
    print("\n[3] Training PlainMLP...")
    plain_losses = train_model(plain_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {plain_losses[-1]:.6f}")

    print("\n[4] Training ResMLP...")
    res_losses = train_model(res_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {res_losses[-1]:.6f}")

    # [5] Capture per-layer activation/gradient statistics after training.
    print("\n[5] Analyzing final state of trained models...")
    print(" Analyzing PlainMLP...")
    plain_stats = analyze_final_state(plain_mlp, HIDDEN_DIM)
    print(" Analyzing ResMLP...")
    res_stats = analyze_final_state(res_mlp, HIDDEN_DIM)

    # [6] Headline numbers.
    print("\n[6] Analysis Summary:")
    print(f" PlainMLP - Final Loss: {plain_stats['final_loss']:.6f}")
    print(f" ResMLP - Final Loss: {res_stats['final_loss']:.6f}")
    print(f" PlainMLP - Gradient norm range: [{min(plain_stats['gradient_norms']):.2e}, {max(plain_stats['gradient_norms']):.2e}]")
    print(f" ResMLP - Gradient norm range: [{min(res_stats['gradient_norms']):.2e}, {max(res_stats['gradient_norms']):.2e}]")

    # [7] Plots. Create the output directory up front: plt.savefig raises
    # FileNotFoundError if 'plots/' does not already exist.
    print("\n[7] Generating plots...")
    Path('plots').mkdir(parents=True, exist_ok=True)
    plot_training_loss(plain_losses, res_losses, 'plots/training_loss.png')
    plot_gradient_magnitudes(plain_stats['gradient_norms'], res_stats['gradient_norms'],
                             'plots/gradient_magnitude.png')
    plot_activation_means(plain_stats['activation_means'], res_stats['activation_means'],
                          'plots/activation_mean.png')
    plot_activation_stds(plain_stats['activation_stds'], res_stats['activation_stds'],
                         'plots/activation_std.png')

    # [8] Persist a machine-readable summary of config + metrics.
    results = {
        'config': {
            'num_layers': NUM_LAYERS,
            'hidden_dim': HIDDEN_DIM,
            'num_samples': NUM_SAMPLES,
            'training_steps': TRAINING_STEPS,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE
        },
        'plain_mlp': {
            'final_loss': plain_losses[-1],
            'initial_loss': plain_losses[0],
            'gradient_norms': plain_stats['gradient_norms'],
            'activation_means': plain_stats['activation_means'],
            'activation_stds': plain_stats['activation_stds']
        },
        'res_mlp': {
            'final_loss': res_losses[-1],
            'initial_loss': res_losses[0],
            'gradient_norms': res_stats['gradient_norms'],
            'activation_means': res_stats['activation_means'],
            'activation_stds': res_stats['activation_stds']
        }
    }

    with open('results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print("\n[8] Results saved to results.json")

    print("\n" + "=" * 60)
    print("Experiment completed successfully!")
    print("=" * 60)

    return results
|
|
|
|
| if __name__ == "__main__": |
| results = main() |
|
|