| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import torch |
| |
|
def plot_losses(log_dir):
    """Plot training losses from TensorBoard logs.

    Placeholder: no plotting is implemented yet. ``log_dir`` is accepted
    but never read, and the function always returns ``None``.
    """
    # TODO: implement reading the logs under log_dir and plotting them.
    return None
| |
|
def save_checkpoint(model, optimizer, epoch, path):
    """Serialize a training checkpoint to ``path`` using ``torch.save``.

    The saved object is a dict with three entries:

    * ``'epoch'`` -- the ``epoch`` value passed in, stored unchanged.
    * ``'model_state_dict'`` -- the result of ``model.state_dict()``.
    * ``'optimizer_state_dict'`` -- the result of ``optimizer.state_dict()``.

    ``path`` is handed straight to ``torch.save``.
    """
    state = {'epoch': epoch}
    # Snapshot the model first, then the optimizer, so the key order in the
    # saved dict is epoch -> model -> optimizer.
    state['model_state_dict'] = model.state_dict()
    state['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(state, path)
| |
|
def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint.

    The checkpoint is expected to be a dict with the keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``, i.e. the layout
    written by ``save_checkpoint``.

    Args:
        model: Object whose ``load_state_dict`` receives the saved model
            state. Updated in place.
        optimizer: Object whose ``load_state_dict`` receives the saved
            optimizer state, or ``None`` to skip restoring the optimizer
            (e.g. when loading for inference only).
        path: Checkpoint location, passed straight to ``torch.load``.
        map_location: Forwarded to ``torch.load``. Pass ``'cpu'`` to load a
            checkpoint that was saved on a GPU onto a host without one. The
            default ``None`` keeps ``torch.load``'s default placement.

    Returns:
        The ``'epoch'`` value stored in the checkpoint.

    Raises:
        KeyError: If a required key is missing from the checkpoint.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
| |
|
def show_samples(samples):
    """Display generated samples as a single image in a matplotlib window.

    Args:
        samples: A 3-D ``torch.Tensor``. Its axes are reordered with
            ``(1, 2, 0)`` before display, so it is presumably laid out as
            (channels, height, width) -- TODO confirm against the caller.
            The tensor may require grad or live on a non-CPU device.

    The function opens a 10x10 inch figure, hides the axes, and calls
    ``plt.show()``. It returns ``None``.
    """
    # Tensor.numpy() raises for tensors that require grad or are stored on a
    # non-CPU device, so drop the autograd history and copy to host first.
    # Both calls are no-ops for a plain CPU tensor.
    image = samples.detach().cpu().numpy()
    # Move the first axis last, which is where imshow expects colour channels.
    image = np.transpose(image, (1, 2, 0))

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.axis('off')
    plt.show()