Code updates
Browse files- inference_brain2vec_PCA.py +222 -0
- model.py +0 -115
- requirements.txt +6 -3
- brain2vec_PCA.py → train_brain2vec_PCA.py +145 -88
inference_brain2vec_PCA.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
inference_brain2vec_PCA.py
|
| 5 |
+
|
| 6 |
+
Loads a pre-trained PCA-based Brain2Vec model (saved with joblib) and performs
|
| 7 |
+
inference on one or more input images. Produces embeddings (and optional
|
| 8 |
+
reconstructions) for each image.
|
| 9 |
+
|
| 10 |
+
Example usage:
|
| 11 |
+
|
| 12 |
+
python inference_brain2vec_PCA.py \
|
| 13 |
+
--pca_model /path/to/pca_model.joblib \
|
| 14 |
+
--input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
|
| 15 |
+
--output_dir /path/to/out
|
| 16 |
+
|
| 17 |
+
Or, if you have a CSV with image paths:
|
| 18 |
+
|
| 19 |
+
python inference_brain2vec_PCA.py \
|
| 20 |
+
--pca_model /path/to/pca_model.joblib \
|
| 21 |
+
--csv_input /path/to/images.csv \
|
| 22 |
+
--output_dir /path/to/out
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import argparse
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
from joblib import load
|
| 31 |
+
import pandas as pd
|
| 32 |
+
|
| 33 |
+
from monai.transforms import (
|
| 34 |
+
Compose,
|
| 35 |
+
CopyItemsD,
|
| 36 |
+
LoadImageD,
|
| 37 |
+
EnsureChannelFirstD,
|
| 38 |
+
SpacingD,
|
| 39 |
+
ResizeWithPadOrCropD,
|
| 40 |
+
ScaleIntensityD,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Global constants
|
| 44 |
+
RESOLUTION = 2
|
| 45 |
+
INPUT_SHAPE_AE = (80, 96, 80)
|
| 46 |
+
FLATTENED_DIM = INPUT_SHAPE_AE[0] * INPUT_SHAPE_AE[1] * INPUT_SHAPE_AE[2]
|
| 47 |
+
|
| 48 |
+
# Reusable MONAI pipeline for preprocessing
|
| 49 |
+
transforms_fn = Compose([
|
| 50 |
+
CopyItemsD(keys={'image_path'}, names=['image']),
|
| 51 |
+
LoadImageD(image_only=True, keys=['image']),
|
| 52 |
+
EnsureChannelFirstD(keys=['image']),
|
| 53 |
+
SpacingD(pixdim=RESOLUTION, keys=['image']),
|
| 54 |
+
ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
|
| 55 |
+
ScaleIntensityD(minv=0, maxv=1, keys=['image']),
|
| 56 |
+
])
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def preprocess_mri(image_path: str) -> torch.Tensor:
|
| 60 |
+
"""
|
| 61 |
+
Preprocess an MRI using MONAI transforms to produce
|
| 62 |
+
a 5D Torch tensor: (batch=1, channel=1, D, H, W).
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
image_path (str): Path to the MRI (e.g., .nii.gz file).
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
torch.Tensor: Preprocessed 5D tensor of shape (1, 1, D, H, W).
|
| 69 |
+
"""
|
| 70 |
+
data_dict = {"image_path": image_path}
|
| 71 |
+
output_dict = transforms_fn(data_dict)
|
| 72 |
+
# shape => (1, D, H, W)
|
| 73 |
+
image_tensor = output_dict["image"].unsqueeze(0) # => (1, 1, D, H, W)
|
| 74 |
+
return image_tensor.float()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class PCABrain2vec(nn.Module):
|
| 78 |
+
"""
|
| 79 |
+
A PCA-based 'autoencoder' that mimics a typical VAE interface:
|
| 80 |
+
- from_pretrained(...) to load a PCA model from disk
|
| 81 |
+
- forward(...) returns (reconstruction, embedding, None)
|
| 82 |
+
|
| 83 |
+
Steps:
|
| 84 |
+
1. Flatten the input volume (N, 1, D, H, W) => (N, 614400).
|
| 85 |
+
2. Transform -> embeddings => shape (N, n_components).
|
| 86 |
+
3. Inverse transform -> recon => shape (N, 614400).
|
| 87 |
+
4. Reshape => (N, 1, D, H, W).
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self, pca_model=None):
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.pca_model = pca_model
|
| 93 |
+
|
| 94 |
+
def forward(self, x: torch.Tensor):
|
| 95 |
+
"""
|
| 96 |
+
Perform a forward pass of the PCA-based "autoencoder".
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
x (torch.Tensor): Input of shape (N, 1, D, H, W).
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
tuple(torch.Tensor, torch.Tensor, None):
|
| 103 |
+
- reconstruction: (N, 1, D, H, W)
|
| 104 |
+
- embedding: (N, n_components)
|
| 105 |
+
- None (to align with the typical VAE interface).
|
| 106 |
+
"""
|
| 107 |
+
n_samples = x.shape[0]
|
| 108 |
+
x_cpu = x.detach().cpu().numpy() # (N, 1, D, H, W)
|
| 109 |
+
x_flat = x_cpu.reshape(n_samples, -1) # => (N, FLATTENED_DIM)
|
| 110 |
+
|
| 111 |
+
# PCA transform => embeddings shape (N, n_components)
|
| 112 |
+
embedding_np = self.pca_model.transform(x_flat)
|
| 113 |
+
|
| 114 |
+
# PCA inverse_transform => recon shape (N, FLATTENED_DIM)
|
| 115 |
+
recon_np = self.pca_model.inverse_transform(embedding_np)
|
| 116 |
+
recon_np = recon_np.reshape(n_samples, 1, *INPUT_SHAPE_AE)
|
| 117 |
+
|
| 118 |
+
# Convert back to torch
|
| 119 |
+
reconstruction_torch = torch.from_numpy(recon_np).float()
|
| 120 |
+
embedding_torch = torch.from_numpy(embedding_np).float()
|
| 121 |
+
return reconstruction_torch, embedding_torch, None
|
| 122 |
+
|
| 123 |
+
@staticmethod
|
| 124 |
+
def from_pretrained(pca_path: str) -> "PCABrain2vec":
|
| 125 |
+
"""
|
| 126 |
+
Load a pre-trained PCA model (pickled or joblib) from disk.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
pca_path (str): File path to the PCA model.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
PCABrain2vec: An instance wrapping the loaded PCA model.
|
| 133 |
+
"""
|
| 134 |
+
if not os.path.exists(pca_path):
|
| 135 |
+
raise FileNotFoundError(f"Could not find PCA model at {pca_path}")
|
| 136 |
+
|
| 137 |
+
pca_model = load(pca_path)
|
| 138 |
+
return PCABrain2vec(pca_model=pca_model)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def main() -> None:
|
| 142 |
+
"""
|
| 143 |
+
Main function to parse command-line arguments and run inference
|
| 144 |
+
with a pre-trained PCA Brain2Vec model.
|
| 145 |
+
"""
|
| 146 |
+
parser = argparse.ArgumentParser(
|
| 147 |
+
description="PCA-based Brain2Vec Inference Script"
|
| 148 |
+
)
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--pca_model", type=str, required=True,
|
| 151 |
+
help="Path to the saved PCA model (.joblib)."
|
| 152 |
+
)
|
| 153 |
+
parser.add_argument(
|
| 154 |
+
"--output_dir", type=str, default="./pca_inference_outputs",
|
| 155 |
+
help="Directory to save embeddings/reconstructions."
|
| 156 |
+
)
|
| 157 |
+
# Two ways to supply images: multiple files or a CSV
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--input_images", type=str, nargs="*",
|
| 160 |
+
help="One or more image paths for inference."
|
| 161 |
+
)
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--csv_input", type=str, default=None,
|
| 164 |
+
help="Path to a CSV containing column 'image_path'."
|
| 165 |
+
)
|
| 166 |
+
args = parser.parse_args()
|
| 167 |
+
|
| 168 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 169 |
+
|
| 170 |
+
# Build the PCA model
|
| 171 |
+
pca_brain2vec = PCABrain2vec.from_pretrained(args.pca_model)
|
| 172 |
+
pca_brain2vec.eval()
|
| 173 |
+
|
| 174 |
+
# Gather image paths
|
| 175 |
+
if args.csv_input:
|
| 176 |
+
df = pd.read_csv(args.csv_input)
|
| 177 |
+
if "image_path" not in df.columns:
|
| 178 |
+
raise ValueError("CSV must contain a column named 'image_path'.")
|
| 179 |
+
image_paths = df["image_path"].tolist()
|
| 180 |
+
else:
|
| 181 |
+
if not args.input_images:
|
| 182 |
+
raise ValueError(
|
| 183 |
+
"Must provide either --csv_input or --input_images."
|
| 184 |
+
)
|
| 185 |
+
image_paths = args.input_images
|
| 186 |
+
|
| 187 |
+
# Inference loop
|
| 188 |
+
all_embeddings = []
|
| 189 |
+
for i, img_path in enumerate(image_paths):
|
| 190 |
+
if not os.path.exists(img_path):
|
| 191 |
+
raise FileNotFoundError(f"Image not found: {img_path}")
|
| 192 |
+
|
| 193 |
+
# Preprocess
|
| 194 |
+
img_tensor = preprocess_mri(img_path)
|
| 195 |
+
|
| 196 |
+
# Forward pass
|
| 197 |
+
with torch.no_grad():
|
| 198 |
+
recon, embedding, _ = pca_brain2vec(img_tensor)
|
| 199 |
+
|
| 200 |
+
# Convert to CPU numpy
|
| 201 |
+
embedding_np = embedding.detach().cpu().numpy()
|
| 202 |
+
recon_np = recon.detach().cpu().numpy()
|
| 203 |
+
|
| 204 |
+
# Save (one embedding row per image)
|
| 205 |
+
all_embeddings.append(embedding_np)
|
| 206 |
+
|
| 207 |
+
# Optionally save or visualize reconstructions
|
| 208 |
+
out_recon_path = os.path.join(
|
| 209 |
+
args.output_dir, f"reconstruction_{i}.npy"
|
| 210 |
+
)
|
| 211 |
+
np.save(out_recon_path, recon_np)
|
| 212 |
+
print(f"[INFO] Saved reconstruction to: {out_recon_path}")
|
| 213 |
+
|
| 214 |
+
# Save all embeddings stacked
|
| 215 |
+
stacked_embeddings = np.vstack(all_embeddings) # (N, n_components)
|
| 216 |
+
out_embed_path = os.path.join(args.output_dir, "all_pca_embeddings.npy")
|
| 217 |
+
np.save(out_embed_path, stacked_embeddings)
|
| 218 |
+
print(f"[INFO] Saved embeddings of shape {stacked_embeddings.shape} to: {out_embed_path}")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if __name__ == "__main__":
|
| 222 |
+
main()
|
model.py
DELETED
|
@@ -1,115 +0,0 @@
|
|
| 1 |
-
# model.py
|
| 2 |
-
import os
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
|
| 7 |
-
from monai.transforms import (
|
| 8 |
-
Compose,
|
| 9 |
-
CopyItemsD,
|
| 10 |
-
LoadImageD,
|
| 11 |
-
EnsureChannelFirstD,
|
| 12 |
-
SpacingD,
|
| 13 |
-
ResizeWithPadOrCropD,
|
| 14 |
-
ScaleIntensityD,
|
| 15 |
-
)
|
| 16 |
-
|
| 17 |
-
# If you used joblib or pickle to save your PCA model:
|
| 18 |
-
from joblib import load # or "import pickle"
|
| 19 |
-
|
| 20 |
-
#################################################
|
| 21 |
-
# Constants
|
| 22 |
-
#################################################
|
| 23 |
-
RESOLUTION = 2
|
| 24 |
-
INPUT_SHAPE_AE = (80, 96, 80) # The typical shape from your pipelines
|
| 25 |
-
FLATTENED_DIM = INPUT_SHAPE_AE[0] * INPUT_SHAPE_AE[1] * INPUT_SHAPE_AE[2]
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
#################################################
|
| 29 |
-
# Define MONAI Transforms for Preprocessing
|
| 30 |
-
#################################################
|
| 31 |
-
transforms_fn = Compose([
|
| 32 |
-
CopyItemsD(keys={'image_path'}, names=['image']),
|
| 33 |
-
LoadImageD(image_only=True, keys=['image']),
|
| 34 |
-
EnsureChannelFirstD(keys=['image']),
|
| 35 |
-
SpacingD(pixdim=RESOLUTION, keys=['image']),
|
| 36 |
-
ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
|
| 37 |
-
ScaleIntensityD(minv=0, maxv=1, keys=['image']),
|
| 38 |
-
])
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def preprocess_mri(image_path: str) -> torch.Tensor:
|
| 42 |
-
"""
|
| 43 |
-
Preprocess an MRI using MONAI transforms to produce
|
| 44 |
-
a 5D Torch tensor: (batch=1, channel=1, D, H, W).
|
| 45 |
-
"""
|
| 46 |
-
data_dict = {"image_path": image_path}
|
| 47 |
-
output_dict = transforms_fn(data_dict)
|
| 48 |
-
# shape => (1, D, H, W)
|
| 49 |
-
image_tensor = output_dict["image"].unsqueeze(0) # => (batch=1, channel=1, D, H, W)
|
| 50 |
-
return image_tensor.float() # typically float32
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
#################################################
|
| 54 |
-
# PCA "Autoencoder" Wrapper
|
| 55 |
-
#################################################
|
| 56 |
-
class PCABrain2vec(nn.Module):
|
| 57 |
-
"""
|
| 58 |
-
A PCA-based 'autoencoder' that mimics the old interface:
|
| 59 |
-
- from_pretrained(...) to load a PCA model from disk
|
| 60 |
-
- forward(...) returns (reconstruction, embedding, None)
|
| 61 |
-
|
| 62 |
-
Under the hood, it:
|
| 63 |
-
- takes in a torch tensor shape (N, 1, D, H, W)
|
| 64 |
-
- flattens it (N, 614400)
|
| 65 |
-
- uses PCA's transform(...) to get embeddings => shape (N, n_components)
|
| 66 |
-
- uses inverse_transform(...) to get reconstructions => shape (N, 614400)
|
| 67 |
-
- reshapes back to (N, 1, D, H, W)
|
| 68 |
-
"""
|
| 69 |
-
|
| 70 |
-
def __init__(self, pca_model=None):
|
| 71 |
-
super().__init__()
|
| 72 |
-
# We'll store the fitted PCA model (from scikit-learn)
|
| 73 |
-
self.pca_model = pca_model # e.g., an instance of IncrementalPCA or PCA
|
| 74 |
-
|
| 75 |
-
def forward(self, x: torch.Tensor):
|
| 76 |
-
"""
|
| 77 |
-
Returns (reconstruction, embedding, None).
|
| 78 |
-
|
| 79 |
-
1) Convert x => numpy array => flatten => (N, 614400)
|
| 80 |
-
2) embedding = pca_model.transform(flat_x)
|
| 81 |
-
3) reconstruction_np = pca_model.inverse_transform(embedding)
|
| 82 |
-
4) reshape => (N, 1, 80, 96, 80)
|
| 83 |
-
5) convert to torch => return (recon, embed, None)
|
| 84 |
-
"""
|
| 85 |
-
# Expect x shape => (N, 1, D, H, W) => flatten to (N, D*H*W)
|
| 86 |
-
n_samples = x.shape[0]
|
| 87 |
-
# Convert to CPU np
|
| 88 |
-
x_cpu = x.detach().cpu().numpy() # shape: (N, 1, D, H, W)
|
| 89 |
-
x_flat = x_cpu.reshape(n_samples, -1) # shape: (N, 614400)
|
| 90 |
-
|
| 91 |
-
# PCA transform => embeddings shape (N, n_components)
|
| 92 |
-
embedding_np = self.pca_model.transform(x_flat)
|
| 93 |
-
|
| 94 |
-
# PCA inverse_transform => recon shape (N, 614400)
|
| 95 |
-
recon_np = self.pca_model.inverse_transform(embedding_np)
|
| 96 |
-
# Reshape back => (N, 1, 80, 96, 80)
|
| 97 |
-
recon_np = recon_np.reshape(n_samples, 1, *INPUT_SHAPE_AE)
|
| 98 |
-
|
| 99 |
-
# Convert back to torch
|
| 100 |
-
reconstruction_torch = torch.from_numpy(recon_np).float()
|
| 101 |
-
embedding_torch = torch.from_numpy(embedding_np).float()
|
| 102 |
-
return reconstruction_torch, embedding_torch, None
|
| 103 |
-
|
| 104 |
-
@staticmethod
|
| 105 |
-
def from_pretrained(pca_path: str):
|
| 106 |
-
"""
|
| 107 |
-
Load a pre-trained PCA model (pickled or joblib).
|
| 108 |
-
Returns an instance of PCABrain2vec with that model.
|
| 109 |
-
"""
|
| 110 |
-
if not os.path.exists(pca_path):
|
| 111 |
-
raise FileNotFoundError(f"Could not find PCA model at {pca_path}")
|
| 112 |
-
# Example: pca_model = pickle.load(open(pca_path, 'rb'))
|
| 113 |
-
# or use joblib:
|
| 114 |
-
pca_model = load(pca_path)
|
| 115 |
-
return PCABrain2vec(pca_model=pca_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,12 +1,15 @@
|
|
| 1 |
# requirements.txt
|
| 2 |
|
| 3 |
-
# PyTorch (CUDA or CPU version).
|
| 4 |
torch>=1.12
|
| 5 |
|
| 6 |
-
# MONAI
|
| 7 |
-
monai-weekly
|
| 8 |
monai-generative
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
# For perceptual losses in MONAI's generative module.
|
| 11 |
lpips
|
| 12 |
|
|
|
|
| 1 |
# requirements.txt
|
| 2 |
|
| 3 |
+
# PyTorch (CUDA or CPU version).
|
| 4 |
torch>=1.12
|
| 5 |
|
| 6 |
+
# Install MONAI Generative first
|
|
|
|
| 7 |
monai-generative
|
| 8 |
|
| 9 |
+
# Now force reinstall MONAI Weekly so its (newer) MONAI version takes precedence
|
| 10 |
+
--force-reinstall
|
| 11 |
+
monai-weekly
|
| 12 |
+
|
| 13 |
# For perceptual losses in MONAI's generative module.
|
| 14 |
lpips
|
| 15 |
|
brain2vec_PCA.py → train_brain2vec_PCA.py
RENAMED
|
@@ -1,101 +1,115 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
|
| 3 |
"""
|
| 4 |
-
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
|
| 10 |
Example usage:
|
| 11 |
-
python
|
| 12 |
--inputs_csv /path/to/inputs.csv \
|
| 13 |
--output_dir ./pca_outputs \
|
| 14 |
--pca_type standard \
|
| 15 |
-
--n_components
|
| 16 |
"""
|
| 17 |
|
| 18 |
import os
|
| 19 |
import argparse
|
| 20 |
import numpy as np
|
| 21 |
import pandas as pd
|
| 22 |
-
|
| 23 |
import torch
|
| 24 |
from torch.utils.data import DataLoader
|
| 25 |
-
|
| 26 |
from monai import transforms
|
| 27 |
from monai.data import Dataset, PersistentDataset
|
| 28 |
-
|
| 29 |
-
# We'll import both PCA classes, and decide which to use based on CLI arg.
|
| 30 |
from sklearn.decomposition import PCA, IncrementalPCA
|
|
|
|
| 31 |
|
| 32 |
-
|
| 33 |
-
###################################################################
|
| 34 |
-
# Constants for your typical config
|
| 35 |
-
###################################################################
|
| 36 |
RESOLUTION = 2
|
|
|
|
|
|
|
| 37 |
INPUT_SHAPE_AE = (80, 96, 80)
|
|
|
|
| 38 |
DEFAULT_N_COMPONENTS = 1200
|
| 39 |
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
| 45 |
"""
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
"""
|
|
|
|
| 49 |
if cache_dir and cache_dir.strip():
|
| 50 |
os.makedirs(cache_dir, exist_ok=True)
|
| 51 |
-
dataset = PersistentDataset(
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
| 54 |
else:
|
| 55 |
-
dataset = Dataset(data=
|
| 56 |
-
transform=transforms_fn)
|
| 57 |
return dataset
|
| 58 |
|
| 59 |
|
| 60 |
-
###################################################################
|
| 61 |
-
# PCAAutoencoder
|
| 62 |
-
###################################################################
|
| 63 |
class PCAAutoencoder:
|
| 64 |
"""
|
| 65 |
A PCA 'autoencoder' that can use either standard PCA or IncrementalPCA:
|
| 66 |
- fit(X): trains the model
|
| 67 |
- transform(X): get embeddings
|
| 68 |
- inverse_transform(Z): reconstruct data from embeddings
|
| 69 |
-
- forward(X): returns (X_recon, Z)
|
| 70 |
-
|
| 71 |
-
If using standard PCA,
|
| 72 |
-
If using incremental PCA,
|
| 73 |
"""
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
"""
|
|
|
|
|
|
|
| 76 |
Args:
|
| 77 |
-
n_components (int):
|
| 78 |
-
batch_size (int):
|
| 79 |
-
pca_type (str): 'incremental' or 'standard'
|
| 80 |
"""
|
| 81 |
self.n_components = n_components
|
| 82 |
self.batch_size = batch_size
|
| 83 |
self.pca_type = pca_type.lower()
|
| 84 |
|
| 85 |
-
if self.pca_type == '
|
| 86 |
-
self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')
|
| 87 |
-
else:
|
| 88 |
-
# default to incremental
|
| 89 |
self.ipca = IncrementalPCA(n_components=self.n_components)
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
def fit(self, X: np.ndarray):
|
| 92 |
"""
|
| 93 |
-
Fit the PCA model. If incremental, calls partial_fit in batches
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
| 96 |
"""
|
| 97 |
if self.pca_type == 'standard':
|
| 98 |
-
# Potentially large memory usage, so be sure your system can handle it.
|
| 99 |
self.ipca.fit(X)
|
| 100 |
else:
|
| 101 |
# IncrementalPCA
|
|
@@ -107,7 +121,12 @@ class PCAAutoencoder:
|
|
| 107 |
def transform(self, X: np.ndarray) -> np.ndarray:
|
| 108 |
"""
|
| 109 |
Project data into the PCA latent space in batches for memory efficiency.
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
"""
|
| 112 |
results = []
|
| 113 |
n_samples = X.shape[0]
|
|
@@ -120,7 +139,12 @@ class PCAAutoencoder:
|
|
| 120 |
def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
|
| 121 |
"""
|
| 122 |
Reconstruct data from PCA latent space in batches.
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
"""
|
| 125 |
results = []
|
| 126 |
n_samples = Z.shape[0]
|
|
@@ -130,80 +154,113 @@ class PCAAutoencoder:
|
|
| 130 |
results.append(X_chunk)
|
| 131 |
return np.vstack(results)
|
| 132 |
|
| 133 |
-
def forward(self, X: np.ndarray) ->
|
| 134 |
"""
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
"""
|
| 137 |
Z = self.transform(X)
|
| 138 |
X_recon = self.inverse_transform(Z)
|
| 139 |
return X_recon, Z
|
| 140 |
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
| 146 |
"""
|
|
|
|
|
|
|
| 147 |
1) Reads CSV.
|
| 148 |
-
2) Filters rows if 'split' in columns => only keep
|
| 149 |
-
3) Applies transforms to each image, flattening them into a 1D vector
|
| 150 |
-
4) Returns a NumPy array X
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
"""
|
| 152 |
df = pd.read_csv(csv_path)
|
| 153 |
|
| 154 |
-
#
|
| 155 |
if 'split' in df.columns:
|
| 156 |
df = df[df['split'] == 'train']
|
| 157 |
-
# If there is no 'split' column, we assume the entire CSV is for training.
|
| 158 |
|
| 159 |
dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
|
| 160 |
loader = DataLoader(dataset, batch_size=1, num_workers=0)
|
| 161 |
|
| 162 |
-
# We'll store each flattened volume in a list, then stack
|
| 163 |
X_list = []
|
| 164 |
for batch in loader:
|
| 165 |
-
# batch["image"]
|
| 166 |
-
img = batch["image"].squeeze(0) # => (1, 80, 96, 80)
|
| 167 |
-
|
| 168 |
-
flattened = img_np.flatten() # => (614400,)
|
| 169 |
X_list.append(flattened)
|
| 170 |
|
| 171 |
-
if
|
| 172 |
-
raise ValueError(
|
|
|
|
|
|
|
| 173 |
|
| 174 |
X = np.vstack(X_list)
|
| 175 |
return X
|
| 176 |
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
parser
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
args = parser.parse_args()
|
| 197 |
|
| 198 |
os.makedirs(args.output_dir, exist_ok=True)
|
| 199 |
|
| 200 |
-
# define transforms as in brain2vec_linearAE.py
|
| 201 |
transforms_fn = transforms.Compose([
|
| 202 |
transforms.CopyItemsD(keys={'image_path'}, names=['image']),
|
| 203 |
transforms.LoadImageD(image_only=True, keys=['image']),
|
| 204 |
transforms.EnsureChannelFirstD(keys=['image']),
|
| 205 |
transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
|
| 206 |
-
transforms.ResizeWithPadOrCropD(
|
|
|
|
|
|
|
| 207 |
transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
|
| 208 |
])
|
| 209 |
|
|
@@ -225,10 +282,10 @@ def main():
|
|
| 225 |
|
| 226 |
# Get embeddings & reconstruction
|
| 227 |
X_recon, Z = model.forward(X)
|
| 228 |
-
print("Embeddings shape:", Z.shape)
|
| 229 |
-
print("Reconstruction shape:", X_recon.shape)
|
| 230 |
|
| 231 |
-
# Save
|
| 232 |
embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
|
| 233 |
recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
|
| 234 |
np.save(embeddings_path, Z)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
|
| 3 |
"""
|
| 4 |
+
train_brain2vec_PCA.py
|
| 5 |
|
| 6 |
+
A PCA-based "autoencoder" script for brain MRI data, with support for both
|
| 7 |
+
incremental PCA and standard PCA. Only scans labeled 'train' in the CSV
|
| 8 |
+
(split == 'train') will be used for fitting.
|
| 9 |
|
| 10 |
Example usage:
|
| 11 |
+
python train_brain2vec_PCA.py \
|
| 12 |
--inputs_csv /path/to/inputs.csv \
|
| 13 |
--output_dir ./pca_outputs \
|
| 14 |
--pca_type standard \
|
| 15 |
+
--n_components 1200
|
| 16 |
"""
|
| 17 |
|
| 18 |
import os
|
| 19 |
import argparse
|
| 20 |
import numpy as np
|
| 21 |
import pandas as pd
|
|
|
|
| 22 |
import torch
|
| 23 |
from torch.utils.data import DataLoader
|
|
|
|
| 24 |
from monai import transforms
|
| 25 |
from monai.data import Dataset, PersistentDataset
|
| 26 |
+
from monai.transforms.transform import Transform
|
|
|
|
| 27 |
from sklearn.decomposition import PCA, IncrementalPCA
|
| 28 |
+
from typing import Optional, Union, Tuple
|
| 29 |
|
| 30 |
+
# voxel resolution
|
|
|
|
|
|
|
|
|
|
| 31 |
RESOLUTION = 2
|
| 32 |
+
|
| 33 |
+
# cropped image dimensions after transform
|
| 34 |
INPUT_SHAPE_AE = (80, 96, 80)
|
| 35 |
+
|
| 36 |
DEFAULT_N_COMPONENTS = 1200
|
| 37 |
|
| 38 |
|
| 39 |
+
def get_dataset_from_pd(
|
| 40 |
+
df: pd.DataFrame,
|
| 41 |
+
transforms_fn: Transform,
|
| 42 |
+
cache_dir: Optional[str]
|
| 43 |
+
) -> Union[Dataset, PersistentDataset]:
|
| 44 |
"""
|
| 45 |
+
Create a MONAI Dataset or PersistentDataset from the given DataFrame.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
df (pd.DataFrame): DataFrame with at least 'image_path' column.
|
| 49 |
+
transforms_fn (Transform): MONAI transform pipeline.
|
| 50 |
+
cache_dir (Optional[str]): If provided, use PersistentDataset caching.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Dataset|PersistentDataset: A dataset for training or inference.
|
| 54 |
"""
|
| 55 |
+
data_dicts = df.to_dict(orient='records')
|
| 56 |
if cache_dir and cache_dir.strip():
|
| 57 |
os.makedirs(cache_dir, exist_ok=True)
|
| 58 |
+
dataset = PersistentDataset(
|
| 59 |
+
data=data_dicts,
|
| 60 |
+
transform=transforms_fn,
|
| 61 |
+
cache_dir=cache_dir
|
| 62 |
+
)
|
| 63 |
else:
|
| 64 |
+
dataset = Dataset(data=data_dicts, transform=transforms_fn)
|
|
|
|
| 65 |
return dataset
|
| 66 |
|
| 67 |
|
|
|
|
|
|
|
|
|
|
| 68 |
class PCAAutoencoder:
|
| 69 |
"""
|
| 70 |
A PCA 'autoencoder' that can use either standard PCA or IncrementalPCA:
|
| 71 |
- fit(X): trains the model
|
| 72 |
- transform(X): get embeddings
|
| 73 |
- inverse_transform(Z): reconstruct data from embeddings
|
| 74 |
+
- forward(X): returns (X_recon, Z).
|
| 75 |
+
|
| 76 |
+
If using standard PCA, a single call to .fit(X) is made.
|
| 77 |
+
If using incremental PCA, .partial_fit is called in batches.
|
| 78 |
"""
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
n_components: int = DEFAULT_N_COMPONENTS,
|
| 83 |
+
batch_size: int = 128,
|
| 84 |
+
pca_type: str = 'standard'
|
| 85 |
+
) -> None:
|
| 86 |
"""
|
| 87 |
+
Initialize the PCAAutoencoder.
|
| 88 |
+
|
| 89 |
Args:
|
| 90 |
+
n_components (int): Number of principal components to keep.
|
| 91 |
+
batch_size (int): Chunk size for partial_fit or chunked transform.
|
| 92 |
+
pca_type (str): Either 'incremental' or 'standard'.
|
| 93 |
"""
|
| 94 |
self.n_components = n_components
|
| 95 |
self.batch_size = batch_size
|
| 96 |
self.pca_type = pca_type.lower()
|
| 97 |
|
| 98 |
+
if self.pca_type == 'incremental':
|
|
|
|
|
|
|
|
|
|
| 99 |
self.ipca = IncrementalPCA(n_components=self.n_components)
|
| 100 |
+
else:
|
| 101 |
+
# Default to standard PCA
|
| 102 |
+
self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')
|
| 103 |
|
| 104 |
+
def fit(self, X: np.ndarray) -> None:
|
| 105 |
"""
|
| 106 |
+
Fit the PCA model. If incremental PCA, calls partial_fit in batches;
|
| 107 |
+
otherwise calls .fit once on the entire data array.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
X (np.ndarray): Shape (n_samples, n_features).
|
| 111 |
"""
|
| 112 |
if self.pca_type == 'standard':
|
|
|
|
| 113 |
self.ipca.fit(X)
|
| 114 |
else:
|
| 115 |
# IncrementalPCA
|
|
|
|
| 121 |
def transform(self, X: np.ndarray) -> np.ndarray:
|
| 122 |
"""
|
| 123 |
Project data into the PCA latent space in batches for memory efficiency.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
X (np.ndarray): Shape (n_samples, n_features).
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
np.ndarray: Latent embeddings of shape (n_samples, n_components).
|
| 130 |
"""
|
| 131 |
results = []
|
| 132 |
n_samples = X.shape[0]
|
|
|
|
| 139 |
def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
|
| 140 |
"""
|
| 141 |
Reconstruct data from PCA latent space in batches.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
Z (np.ndarray): Latent embeddings of shape (n_samples, n_components).
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
np.ndarray: Reconstructed data of shape (n_samples, n_features).
|
| 148 |
"""
|
| 149 |
results = []
|
| 150 |
n_samples = Z.shape[0]
|
|
|
|
| 154 |
results.append(X_chunk)
|
| 155 |
return np.vstack(results)
|
| 156 |
|
| 157 |
+
def forward(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 158 |
"""
|
| 159 |
+
Mimic a linear AE's forward() returning (X_recon, Z).
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
X (np.ndarray): Original data of shape (n_samples, n_features).
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
tuple[np.ndarray, np.ndarray]: (X_recon, Z).
|
| 166 |
"""
|
| 167 |
Z = self.transform(X)
|
| 168 |
X_recon = self.inverse_transform(Z)
|
| 169 |
return X_recon, Z
|
| 170 |
|
| 171 |
|
| 172 |
+
def load_and_flatten_dataset(
|
| 173 |
+
csv_path: str,
|
| 174 |
+
cache_dir: str,
|
| 175 |
+
transforms_fn: Transform
|
| 176 |
+
) -> np.ndarray:
|
| 177 |
"""
|
| 178 |
+
Load and flatten MRI volumes from the provided CSV.
|
| 179 |
+
|
| 180 |
1) Reads CSV.
|
| 181 |
+
2) Filters rows if 'split' in columns => only keep rows with split == 'train'.
|
| 182 |
+
3) Applies transforms to each image, flattening them into a 1D vector.
|
| 183 |
+
4) Returns a NumPy array X of shape (n_samples, 614400) after flattening.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
csv_path (str): Path to a CSV containing at least 'image_path' column.
|
| 187 |
+
Optionally has a 'split' column.
|
| 188 |
+
cache_dir (str): Path to cache directory for MONAI PersistentDataset.
|
| 189 |
+
transforms_fn (Transform): MONAI transform pipeline.
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
np.ndarray: Flattened image data of shape (n_samples, 614400).
|
| 193 |
"""
|
| 194 |
df = pd.read_csv(csv_path)
|
| 195 |
|
| 196 |
+
# Keep only 'train' samples if split column exists
|
| 197 |
if 'split' in df.columns:
|
| 198 |
df = df[df['split'] == 'train']
|
|
|
|
| 199 |
|
| 200 |
dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
|
| 201 |
loader = DataLoader(dataset, batch_size=1, num_workers=0)
|
| 202 |
|
|
|
|
| 203 |
X_list = []
|
| 204 |
for batch in loader:
|
| 205 |
+
# batch["image"] => shape (1, 1, 80, 96, 80)
|
| 206 |
+
img = batch["image"].squeeze(0) # => shape (1, 80, 96, 80)
|
| 207 |
+
flattened = img.numpy().flatten() # => (614400,)
|
|
|
|
| 208 |
X_list.append(flattened)
|
| 209 |
|
| 210 |
+
if not X_list:
|
| 211 |
+
raise ValueError(
|
| 212 |
+
"No training samples found (split='train'). Check your CSV or 'split' values."
|
| 213 |
+
)
|
| 214 |
|
| 215 |
X = np.vstack(X_list)
|
| 216 |
return X
|
| 217 |
|
| 218 |
|
| 219 |
+
def main() -> None:
|
| 220 |
+
"""
|
| 221 |
+
Main function to parse command-line arguments and fit a PCA or IncrementalPCA model,
|
| 222 |
+
then save embeddings and reconstructions.
|
| 223 |
+
"""
|
| 224 |
+
parser = argparse.ArgumentParser(
|
| 225 |
+
description="PCA Autoencoder with MONAI transforms and 'split' filtering."
|
| 226 |
+
)
|
| 227 |
+
parser.add_argument(
|
| 228 |
+
"--inputs_csv", type=str, required=True,
|
| 229 |
+
help="Path to CSV with at least 'image_path' column and optional 'split' column."
|
| 230 |
+
)
|
| 231 |
+
parser.add_argument(
|
| 232 |
+
"--cache_dir", type=str, default="",
|
| 233 |
+
help="Cache directory for MONAI PersistentDataset (optional)."
|
| 234 |
+
)
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
"--output_dir", type=str, default="./pca_outputs",
|
| 237 |
+
help="Where to save PCA model and embeddings."
|
| 238 |
+
)
|
| 239 |
+
parser.add_argument(
|
| 240 |
+
"--batch_size_ipca", type=int, default=128,
|
| 241 |
+
help="Batch size for partial_fit or chunked transform."
|
| 242 |
+
)
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--n_components", type=int, default=1200,
|
| 245 |
+
help="Number of PCA components to keep."
|
| 246 |
+
)
|
| 247 |
+
parser.add_argument(
|
| 248 |
+
"--pca_type", type=str, default="incremental",
|
| 249 |
+
choices=["incremental", "standard"],
|
| 250 |
+
help="Which PCA algorithm to use: 'incremental' or 'standard'."
|
| 251 |
+
)
|
| 252 |
args = parser.parse_args()
|
| 253 |
|
| 254 |
os.makedirs(args.output_dir, exist_ok=True)
|
| 255 |
|
|
|
|
| 256 |
transforms_fn = transforms.Compose([
|
| 257 |
transforms.CopyItemsD(keys={'image_path'}, names=['image']),
|
| 258 |
transforms.LoadImageD(image_only=True, keys=['image']),
|
| 259 |
transforms.EnsureChannelFirstD(keys=['image']),
|
| 260 |
transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
|
| 261 |
+
transforms.ResizeWithPadOrCropD(
|
| 262 |
+
spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']
|
| 263 |
+
),
|
| 264 |
transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
|
| 265 |
])
|
| 266 |
|
|
|
|
| 282 |
|
| 283 |
# Get embeddings & reconstruction
|
| 284 |
X_recon, Z = model.forward(X)
|
| 285 |
+
print("Embeddings shape:", Z.shape)
|
| 286 |
+
print("Reconstruction shape:", X_recon.shape)
|
| 287 |
|
| 288 |
+
# Save embeddings and reconstructions
|
| 289 |
embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
|
| 290 |
recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
|
| 291 |
np.save(embeddings_path, Z)
|