import math

import numpy as np
import torch
from torchvision.utils import save_image, make_grid

from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise


def deterministic_sample(
    model,
    noise_scheduler,
    device,
    n_samples=4,
    timesteps=None,
    output_path="simplified_samples.png",
):
    """Sample images using a short schedule of large denoising steps.

    Despite the name, the output is only reproducible if the caller seeds the
    torch RNG: both the starting latent and the small perturbation added
    between steps are drawn from ``torch.randn``.

    Args:
        model: Noise-prediction network, called as ``model(x, t)`` where ``t``
            is a ``long`` tensor holding one timestep per sample. It is
            switched to eval mode and left that way.
        noise_scheduler: Object exposing ``alpha_bars``, indexed by integer
            timestep; each entry must support ``.item()``.
        device: Device on which the samples are created.
        n_samples: Number of images to generate.
        timesteps: Timesteps to visit, in the order given. Defaults to
            ``400, 300, 200, 150, 100, 70, 50, 30, 20, 10, 5, 1``. Every
            value must be a valid index into ``noise_scheduler.alpha_bars``.
        output_path: File the image grid is written to.

    Returns:
        A tuple ``(x, grid)``: the final samples clamped to ``[-1, 1]`` and
        the grid built by ``make_grid`` from the samples rescaled to
        ``[0, 1]``.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    if timesteps is None:
        timesteps = (400, 300, 200, 150, 100, 70, 50, 30, 20, 10, 5, 1)
    timesteps = list(timesteps)
    last_index = len(timesteps) - 1

    config = Config()
    model.eval()

    with torch.no_grad():
        # Start from noise at half the unit-Gaussian scale.
        x = torch.randn(
            n_samples, 3, config.image_size, config.image_size, device=device
        ) * 0.5

        print(f"Starting simplified sampling for {n_samples} samples...")

        for i, t_val in enumerate(timesteps):
            print(f"Step {i+1}/{len(timesteps)}, t={t_val}")

            t_tensor = torch.full(
                (n_samples,), t_val, device=device, dtype=torch.long
            )
            predicted_noise = model(x, t_tensor)

            # math.sqrt keeps the coefficients plain Python floats, so the
            # tensor arithmetic below does not rely on how NumPy scalars and
            # torch tensors dispatch binary operators between each other.
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            sqrt_one_minus_abar = math.sqrt(max(0.0, 1.0 - alpha_bar_t))

            # Estimate the clean image from the current sample and the
            # predicted noise, then keep it inside the image range.
            pred_x0 = (x - sqrt_one_minus_abar * predicted_noise) / math.sqrt(alpha_bar_t)
            pred_x0 = torch.clamp(pred_x0, -1, 1)

            if i < last_index:
                # Jump to the next timestep: rescale the estimate and add a
                # deliberately damped (0.1x) random perturbation.
                alpha_bar_next = noise_scheduler.alpha_bars[timesteps[i + 1]].item()
                noise_scale = math.sqrt(max(0.0, 1.0 - alpha_bar_next))
                noise = torch.randn_like(x) * 0.1
                x = math.sqrt(alpha_bar_next) * pred_x0 + noise_scale * noise
            else:
                # Last step: the clean-image estimate is the result.
                x = pred_x0

            # Loose clamp so intermediate samples cannot blow up.
            x = torch.clamp(x, -1.5, 1.5)

            if i % 3 == 0:
                print(f" Current range: [{x.min():.3f}, {x.max():.3f}], std: {x.std():.3f}")

        x = torch.clamp(x, -1, 1)

        print("Final samples:")
        print(f" Range: [{x.min():.3f}, {x.max():.3f}]")
        print(f" Mean: {x.mean():.3f}, Std: {x.std():.3f}")

        # Map [-1, 1] to [0, 1] for saving as an image.
        x_display = torch.clamp((x + 1) / 2, 0, 1)

        grid = make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)
        save_image(grid, output_path)
        print(f"Samples saved to {output_path}")

    return x, grid


def progressive_sample(
    model,
    noise_scheduler,
    device,
    n_samples=4,
    start_t=200,
    output_path="progressive_samples.png",
):
    """Sample images by denoising step by step from an intermediate timestep.

    Starts from low-amplitude noise and walks every timestep from
    ``start_t - 1`` down to ``0``, applying a posterior-mean update with
    damped (0.5x) noise at each step.

    Args:
        model: Noise-prediction network, called as ``model(x, t)`` where ``t``
            is a ``long`` tensor holding one timestep per sample. It is
            switched to eval mode and left that way.
        noise_scheduler: Object exposing ``alphas``, ``alpha_bars`` and
            ``betas``, each indexed by integer timestep with entries that
            support ``.item()``.
        device: Device on which the samples are created.
        n_samples: Number of images to generate.
        start_t: Number of denoising steps; the first timestep visited is
            ``start_t - 1``, which must be a valid index into the scheduler
            arrays.
        output_path: File the image grid is written to.

    Returns:
        A tuple ``(x, grid)``: the final samples clamped to ``[-1, 1]`` and
        the grid built by ``make_grid`` from the samples rescaled to
        ``[0, 1]``.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    config = Config()
    model.eval()

    with torch.no_grad():
        # Start from noise at 0.3x the unit-Gaussian scale.
        x = torch.randn(
            n_samples, 3, config.image_size, config.image_size, device=device
        ) * 0.3

        print(f"Starting progressive denoising for {n_samples} samples...")

        for step, t in enumerate(range(start_t - 1, -1, -1)):
            if step % 50 == 0:
                print(f"Denoising step {step}/{start_t}, t={t}")

            t_tensor = torch.full((n_samples,), t, device=device, dtype=torch.long)
            predicted_noise = model(x, t_tensor)

            # math.sqrt keeps the coefficients plain Python floats, so the
            # tensor arithmetic below does not rely on how NumPy scalars and
            # torch tensors dispatch binary operators between each other.
            alpha_bar_t = noise_scheduler.alpha_bars[t].item()
            sqrt_one_minus_abar = math.sqrt(max(0.0, 1.0 - alpha_bar_t))

            # Clean-image estimate from the current sample and predicted noise.
            pred_x0 = (x - sqrt_one_minus_abar * predicted_noise) / math.sqrt(alpha_bar_t)

            if t > 0:
                alpha_t = noise_scheduler.alphas[t].item()
                beta_t = noise_scheduler.betas[t].item()
                alpha_bar_prev = noise_scheduler.alpha_bars[t - 1].item()

                pred_x0 = torch.clamp(pred_x0, -1, 1)

                # Posterior mean: a weighted blend of the current sample and
                # the clamped clean-image estimate.
                coeff1 = math.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                coeff2 = math.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
                mean = coeff1 * x + coeff2 * pred_x0

                if t > 1:
                    posterior_variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                    noise = torch.randn_like(x)
                    # Noise is deliberately halved relative to the posterior std.
                    x = mean + math.sqrt(max(0.0, posterior_variance)) * noise * 0.5
                else:
                    # No noise on the step into t=0.
                    x = mean
            else:
                # Final step: take the (unclamped) clean-image estimate.
                x = pred_x0

            # Loose clamp so intermediate samples cannot blow up.
            x = torch.clamp(x, -1.2, 1.2)

        x = torch.clamp(x, -1, 1)

        print("Progressive samples:")
        print(f" Range: [{x.min():.3f}, {x.max():.3f}]")
        print(f" Mean: {x.mean():.3f}, Std: {x.std():.3f}")

        # Map [-1, 1] to [0, 1] for saving as an image.
        x_display = torch.clamp((x + 1) / 2, 0, 1)
        grid = make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)
        save_image(grid, output_path)
        print(f"Samples saved to {output_path}")

    return x, grid


def main():
    """Load the trained model from ``model_final.pth`` and run both samplers.

    Uses CUDA when available, otherwise the CPU. Each sampler writes its own
    image grid to disk and prints progress to stdout.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The loaded object is handed directly to load_state_dict, so it is a
    # plain state dict; weights_only=True restricts unpickling to tensors and
    # primitive containers instead of arbitrary pickled objects.
    checkpoint = torch.load("model_final.pth", map_location=device, weights_only=True)

    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)

    print("=== TRYING DETERMINISTIC SAMPLING ===")
    deterministic_sample(model, noise_scheduler, device, n_samples=4)

    print("\n=== TRYING PROGRESSIVE SAMPLING ===")
    progressive_sample(model, noise_scheduler, device, n_samples=4)


# Run the sampling demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|