| import torch |
| import torchvision.transforms as T |
| import torchvision.datasets as datasets |
| from torch.utils.data import DataLoader |
| from torchvision.models import resnet18 |
| from sklearn.manifold import TSNE |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from model import SSLModel |
|
|
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the SSL model around a randomly initialised ResNet-18 backbone; the
# trained parameters are restored from the checkpoint below.
# `weights=None` replaces the deprecated `pretrained=False` flag
# (torchvision >= 0.13), which otherwise emits a deprecation warning.
model = SSLModel(resnet18(weights=None)).to(device)
saved_model_path = "models/saves/run2/ssl_checkpoint_epoch_15.pth"
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU can still be restored on a CPU-only host.
# NOTE(review): torch >= 2.6 defaults to weights_only=True. If this checkpoint
# also stores arbitrary Python objects, loading may need weights_only=False,
# which should only be used for trusted files.
checkpoint = torch.load(saved_model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
# Evaluation mode: e.g. ResNet's BatchNorm layers use their running statistics
# instead of per-batch statistics.
model.eval()
print(f"Model loaded from {saved_model_path}")
|
|
# Preprocessing for the CIFAR-10 test split: tensors in [0, 1], then each
# channel shifted/scaled to [-1, 1] via (x - 0.5) / 0.5.
# CIFAR-10 images are already 32x32, so the resize is effectively a no-op;
# presumably kept to mirror the training pipeline.
# NOTE(review): the normalisation stats should match those used in training.
preprocess_steps = [
    T.Resize(32),
    T.ToTensor(),
    T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
transform = T.Compose(preprocess_steps)

dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Keep the dataset's natural order so collected embeddings stay aligned with
# their labels.
dataloader = DataLoader(dataset, shuffle=False, batch_size=256)
|
|
print("Extracting embeddings...")

embedding_chunks = []
label_chunks = []

# Inference only: disable autograd bookkeeping for the forward passes.
with torch.no_grad():
    for batch_images, batch_labels in dataloader:
        batch_embeddings = model(batch_images.to(device))
        # Copy each batch back to host memory as NumPy so it can be stacked.
        embedding_chunks.append(batch_embeddings.cpu().numpy())
        label_chunks.append(batch_labels.numpy())

# Stack the per-batch arrays along the sample axis into single arrays.
embeddings = np.concatenate(embedding_chunks, axis=0)
labels = np.concatenate(label_chunks, axis=0)
|
|
print("Reducing dimensionality...")

# Project the embeddings down to 2-D for plotting. A fixed random_state keeps
# repeated runs comparable; PCA initialisation with an automatic learning rate
# is used for the optimisation.
tsne_settings = {
    "n_components": 2,
    "init": "pca",
    "learning_rate": "auto",
    "random_state": 42,
}
tsne = TSNE(**tsne_settings)
reduced_embeddings = tsne.fit_transform(embeddings)
|
|
def plot_embeddings(embeddings, labels, class_names, save_path=None):
    """Scatter-plot 2-D embeddings coloured by class, with a class legend.

    Args:
        embeddings: Array indexed as ``embeddings[:, 0]`` / ``embeddings[:, 1]``;
            the first column is drawn on the x-axis, the second on the y-axis.
        labels: Integer class index for each row of ``embeddings``.
        class_names: Sequence where ``class_names[i]`` names class index ``i``.
        save_path: Optional file path. When given, the figure is also saved
            there before being shown (useful on headless machines where
            ``plt.show`` displays nothing).
    """
    labels = np.asarray(labels)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        embeddings[:, 0],
        embeddings[:, 1],
        c=labels,
        cmap="tab10",
        alpha=0.7,
    )

    # num=None requests exactly one legend handle per unique label value.
    # The default ("auto") switches to a tick-locator approximation when there
    # are more than 9 unique values, so handles need not match class indices.
    handles, _ = scatter.legend_elements(num=None)
    # Handles are produced in sorted order of the unique labels, so name only
    # the classes actually present instead of assuming every class appears.
    present_classes = np.unique(labels)
    plt.legend(
        handles=handles,
        labels=[class_names[int(idx)] for idx in present_classes],
        loc="upper right",
        title="Classes",
    )

    plt.title("t-SNE Visualization of SSL Embeddings")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.grid(True)
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()
|
|
# torchvision's CIFAR10 lists class names so that classes[i] names target i.
class_names = dataset.classes

plot_embeddings(
    embeddings=reduced_embeddings,
    labels=labels,
    class_names=class_names,
)
|
|
|
|