| import torch |
| import torchvision |
| from torchvision.utils import save_image, make_grid |
| import os |
| import argparse |
| from datetime import datetime |
| from config import Config |
| from model import SmoothDiffusionUNet |
| from noise_scheduler_simple import FrequencyAwareNoise |
| from sample_simple import simple_sample |
|
|
def load_model(checkpoint_path, device):
    """Load a SmoothDiffusionUNet and its noise scheduler from a checkpoint.

    Two checkpoint layouts are accepted:

    * a dict with a ``'model_state_dict'`` entry (optionally also
      ``'config'``, ``'epoch'`` and ``'loss'``), presumably written by the
      training script;
    * a bare model ``state_dict``.

    If the checkpoint has no ``'config'`` entry, a default ``Config()`` is
    used, so the saved weights must match the default architecture.

    Args:
        checkpoint_path: Path to the checkpoint file passed to ``torch.load``.
        device: Device that tensors are mapped to and the model is moved to.

    Returns:
        Tuple ``(model, noise_scheduler, config)``.

    Note:
        The file is unpickled with ``weights_only=False`` because it may hold
        a pickled ``Config`` object. Unpickling can execute arbitrary code,
        so only load checkpoints from trusted sources.
    """
    print(f"Loading model from: {checkpoint_path}")

    try:
        # PyTorch >= 2.6 defaults to weights_only=True, which rejects
        # arbitrary pickled objects such as a Config stored under 'config'.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 does not accept the ``weights_only`` keyword.
        checkpoint = torch.load(checkpoint_path, map_location=device)

    # Prefer the configuration saved alongside the weights; otherwise fall
    # back to the project defaults.
    config = checkpoint['config'] if 'config' in checkpoint else Config()

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        # Full training checkpoint: weights plus bookkeeping metadata.
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {epoch}, loss: {loss}")
    else:
        # The checkpoint itself is the model's state_dict.
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    return model, noise_scheduler, config
|
|
def test_checkpoint(checkpoint_path, device, n_samples=16):
    """Load a checkpoint, draw samples with ``simple_sample`` and save a grid.

    The grid image is written to
    ``test_samples_simple_<YYYYmmdd_HHMMSS>.png`` in the current working
    directory.

    Args:
        checkpoint_path: Path to the checkpoint to evaluate.
        device: torch device used for loading and sampling.
        n_samples: Number of samples to generate.

    Returns:
        Tuple ``(samples, grid)`` exactly as returned by ``simple_sample``.
    """
    model, noise_scheduler, _config = load_model(checkpoint_path, device)
    # Sampling is inference only: use eval-mode behaviour for layers such as
    # dropout / normalization.
    model.eval()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = f"test_samples_simple_{timestamp}.png"

    print(f"Testing checkpoint with {n_samples} samples...")
    # No gradients are needed; this avoids holding an autograd graph across
    # every denoising step.
    with torch.no_grad():
        samples, grid = simple_sample(model, noise_scheduler, device, n_samples=n_samples)

    # normalize=False writes the grid values as-is; this assumes simple_sample
    # already returns pixel values in [0, 1] — TODO confirm.
    save_image(grid, save_path, normalize=False)
    print(f"Samples saved to: {save_path}")

    return samples, grid
|
|
def main():
    """Command-line entry point: parse arguments and test one checkpoint.

    Invalid input is reported through ``argparse`` (usage message, exit
    status 2) instead of failing later with a traceback. This covers:

    * a checkpoint path that is not an existing file;
    * a non-positive ``--n_samples``;
    * an unparseable ``--device``;
    * a CUDA device requested on a machine where CUDA is not available.
    """
    parser = argparse.ArgumentParser(description='Test trained diffusion model (simple version)')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to checkpoint file')
    parser.add_argument('--n_samples', type=int, default=16, help='Number of samples to generate')
    parser.add_argument('--device', type=str, default='auto', help='Device to use (cuda/cpu/auto)')

    args = parser.parse_args()

    if not os.path.isfile(args.checkpoint):
        parser.error(f"checkpoint file not found: {args.checkpoint}")
    if args.n_samples <= 0:
        parser.error(f"--n_samples must be a positive integer, got {args.n_samples}")

    if args.device == 'auto':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        try:
            device = torch.device(args.device)
        except RuntimeError as err:
            # torch.device raises RuntimeError for unknown device strings.
            parser.error(f"invalid --device {args.device!r}: {err}")
        if device.type == 'cuda' and not torch.cuda.is_available():
            parser.error("CUDA device requested but CUDA is not available")

    print(f"Using device: {device}")

    print("=== Testing Checkpoint with Simple DDPM ===")
    test_checkpoint(args.checkpoint, device, args.n_samples)
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|