File size: 8,835 Bytes
0161e74 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 | import abc
from typing import Any, Literal
import jax
import jax.numpy as jnp
import numpy as np
from src.data_process._data import PredictionData, TrainingData, ValidationData
import pdb
import torch
__all__ = ["TrainSampler", "ValidationSampler", "PredictionSampler"]
class TrainSampler:
"""Data sampler for :class:`~cellflow.data.TrainingData`.
Parameters
----------
data
The training data.
batch_size
The batch size.
"""
def __init__(self, data: TrainingData, batch_size: int = 1024):
self._data = data
self._data_idcs = jnp.arange(data.cell_data.shape[0])
self.batch_size = batch_size
self.n_source_dists = data.n_controls
self.n_target_dists = data.n_perturbations
self.conditional_samplings = [
lambda key, i=i: jax.random.choice(key, self._data.control_to_perturbation[i])
for i in range(self.n_source_dists)
]
self.get_embeddings = lambda idx: {
pert_cov: jnp.expand_dims(arr[idx], 0) for pert_cov, arr in self._data.condition_data.items()
}
# @jax.jit
def _sample(rng: jax.Array) -> Any:
rng_1, rng_2, rng_3, rng_4 = jax.random.split(rng, 4)
source_dist_idx = jax.random.choice(rng_1, self.n_source_dists)
source_cells_mask = self._data.split_covariates_mask == source_dist_idx
src_cond_p = source_cells_mask / jnp.count_nonzero(source_cells_mask)
source_batch_idcs = jax.random.choice(rng_2, self._data_idcs, [self.batch_size], replace=True, p=src_cond_p)
source_batch = self._data.cell_data[source_batch_idcs]
target_dist_idx = jax.lax.switch(source_dist_idx, self.conditional_samplings, rng_3)
target_cells_mask = self._data.perturbation_covariates_mask == target_dist_idx
tgt_cond_p = target_cells_mask / jnp.count_nonzero(target_cells_mask)
target_batch_idcs = jax.random.choice(
rng_4,
self._data_idcs,
[self.batch_size],
replace=True,
p=tgt_cond_p,
)
target_batch = self._data.cell_data[target_batch_idcs]
if self._data.condition_data is None:
return {"src_cell_data": source_batch, "tgt_cell_data": target_batch}
condition_batch = self.get_embeddings(target_dist_idx)
return {
"src_cell_data": source_batch,
"tgt_cell_data": target_batch,
"condition": condition_batch,
"condition_id": self._data.condition_data_id[target_dist_idx],
"src_cell_id": self._data.cell_data_id[source_batch_idcs],
"tgt_cell_id": self._data.cell_data_id[target_batch_idcs]
}
self.sample = _sample
@property
def data(self) -> TrainingData:
"""The training data."""
return self._data
def precompute_latent_embedding(self, model,):
"""
compute latent representation
return
----
latent_embedding : torch.Tensor
"""
cell_data = self._data.cell_data
with torch.no_grad():
self._data.cell_embedding = model(cell_data)
class BaseValidSampler(abc.ABC):
@abc.abstractmethod
def sample(*args, **kwargs):
pass
def _get_key(self, cond_idx: int) -> tuple[str, ...]:
if len(self._data.perturbation_idx_to_id): # type: ignore[attr-defined]
return self._data.perturbation_idx_to_id[cond_idx] # type: ignore[attr-defined]
cov_combination = self._data.perturbation_idx_to_covariates[cond_idx] # type: ignore[attr-defined]
return tuple(cov_combination[i] for i in range(len(cov_combination)))
def _get_perturbation_to_control(self, data: ValidationData | PredictionData) -> dict[int, int]:
d = {}
for k, v in data.control_to_perturbation.items():
for el in v:
d[el] = k
return d
def _get_condition_data(self, cond_idx: int) -> jnp.ndarray:
return {k: v[[cond_idx], ...] for k, v in self._data.condition_data.items()} # type: ignore[attr-defined]
class ValidationSampler(BaseValidSampler):
"""Data sampler for :class:`~cellflow.data.ValidationData`.
Parameters
----------
val_data
The validation data.
seed
Random seed.
"""
def __init__(self, val_data: ValidationData, seed: int = 0) -> None:
self._data = val_data
self.perturbation_to_control = self._get_perturbation_to_control(val_data)
self.n_conditions_on_log_iteration = (
val_data.n_conditions_on_log_iteration
if val_data.n_conditions_on_log_iteration is not None
else val_data.n_perturbations
)
self.n_conditions_on_train_end = (
val_data.n_conditions_on_train_end
if val_data.n_conditions_on_train_end is not None
else val_data.n_perturbations
)
self.rng = np.random.default_rng(seed)
if self._data.condition_data is None:
raise NotImplementedError("Validation data must have condition data.")
def sample(self, mode: Literal["on_log_iteration", "on_train_end"]) -> Any:
"""Sample data for validation.
Parameters
----------
mode
Sampling mode. Either ``"on_log_iteration"`` or ``"on_train_end"``.
Returns
-------
Dictionary with source, condition, and target data from the validation data.
"""
import pdb; pdb.set_trace()
size = self.n_conditions_on_log_iteration if mode == "on_log_iteration" else self.n_conditions_on_train_end
condition_idcs = self.rng.choice(self._data.n_perturbations, size=(size,), replace=False)
source_idcs = [self.perturbation_to_control[cond_idx] for cond_idx in condition_idcs]
source_cells_mask = [self._data.split_covariates_mask == source_idx for source_idx in source_idcs]
source_cells = [self._data.cell_data[mask] for mask in source_cells_mask]
target_cells_mask = [cond_idx == self._data.perturbation_covariates_mask for cond_idx in condition_idcs]
target_cells = [self._data.cell_data[mask] for mask in target_cells_mask]
conditions = [self._get_condition_data(cond_idx) for cond_idx in condition_idcs]
cell_rep_dict = {}
cond_dict = {}
true_dict = {}
for i in range(len(condition_idcs)):
k = self._get_key(condition_idcs[i])
cell_rep_dict[k] = source_cells[i]
cond_dict[k] = conditions[i]
true_dict[k] = target_cells[i]
return {"source": cell_rep_dict, "condition": cond_dict, "target": true_dict}
@property
def data(self) -> ValidationData:
"""The validation data."""
return self._data
class PredictionSampler(BaseValidSampler):
"""Data sampler for :class:`~cellflow.data.PredictionData`.
Parameters
----------
pred_data
The prediction data.
"""
def __init__(self, pred_data: PredictionData) -> None:
self._data = pred_data
self.perturbation_to_control = self._get_perturbation_to_control(pred_data)
if self._data.condition_data is None:
raise NotImplementedError("Validation data must have condition data.")
def sample(self) -> Any:
"""Sample data for prediction.
Returns
-------
Dictionary with source and condition data from the prediction data.
"""
condition_idcs = range(self._data.n_perturbations)
source_idcs = [self.perturbation_to_control[cond_idx] for cond_idx in condition_idcs]
source_cells_mask = [self._data.split_covariates_mask == source_idx for source_idx in source_idcs]
source_cells = [self._data.cell_data[mask] for mask in source_cells_mask]
target_cells_mask = [cond_idx == self._data.perturbation_covariates_mask for cond_idx in condition_idcs]
target_cells = [self._data.cell_data[mask] for mask in target_cells_mask]
conditions = [self._get_condition_data(cond_idx) for cond_idx in condition_idcs]
cell_rep_dict = {}
cond_dict = {}
true_dict = {}
for i in range(len(condition_idcs)):
k = self._get_key(condition_idcs[i])
cell_rep_dict[k] = source_cells[i]
cond_dict[k] = conditions[i]
true_dict[k] = target_cells[i]
return {
"source": cell_rep_dict,
"target": true_dict,
"condition": cond_dict,
}
@property
def data(self) -> PredictionData:
"""The training data."""
return self._data
|