| | """ |
| | Module: tokenization.py |
| | |
| | This module provides a tokenization pipeline for preprocessed single-cell RNA sequencing (scRNA-seq) data. |
| | It converts gene expression data stored in AnnData format into tokenized sequences that can |
| | be used for downstream machine learning tasks, such as masked language modeling or classification. |
| | |
| | Main Features: |
| | - Tokenizes gene expression data into integer tokens using a custom GeneTokenizer. |
| | - Supports additional biological annotations (e.g., disease, tissue, cell type, sex). |
| | - Handles both top-k and random gene selection for tokenization. |
| | - Configurable via JSON-based hyperparameters or TokenizationArgs objects. |
| | - Saves tokenized data in Hugging Face Dataset format for efficient processing. |
| | |
| | Dependencies: |
| | - anndata, numpy, torch, datasets, tqdm |
| | |
| | Usage: |
| | - Run this script as a standalone program with a configuration file specifying the hyperparameters. |
| | - Import the `tokenize` function and call it with the data path, metadata path, and tokenization arguments. |
| | """ |

import gc
import os
import json
import random
import shutil
from argparse import ArgumentParser
from typing import Union

import anndata as ad
import numpy as np
import torch
from datasets import Dataset, load_from_disk
from tqdm import tqdm

from teddy.tokenizer.gene_tokenizer import GeneTokenizer
from teddy.tokenizer.tokenization_args import TokenizationArgs


def _bin_values(vals_list, tokenization_args, no_sorting=False):
    """
    Bins expression values into the configured number of bins, assigning bin 0 to
    non-expressed genes when `include_zero_genes` is True.

    no_sorting=False => "positional chunk" binning for values that are already sorted
                        in descending order by top-k expression (the input values).
    no_sorting=True  => plain bucketize binning that ignores top-k order, for data
                        that is not sorted (e.g. labels).
    """
    binned_vals = []
    for vals in vals_list:
        if isinstance(vals, np.ndarray):
            vals = torch.tensor(vals)

        vals_to_bin = vals

        if not no_sorting:
            num_repetitions = max(1, len(vals_to_bin) // tokenization_args.bins)
            bin_pattern = torch.arange(0, tokenization_args.bins).unsqueeze(1).repeat(1, num_repetitions).flatten()

            if len(bin_pattern) > len(vals_to_bin):
                bin_pattern = bin_pattern[-len(vals_to_bin) :]
            else:
                extra = len(vals_to_bin) - len(bin_pattern)
                if extra > 0:
                    bin_pattern = torch.cat([torch.zeros(extra), bin_pattern])
            bin_pattern = bin_pattern.flip(0)

            binned_vals.append(bin_pattern)
        else:
            if len(vals_to_bin) > 0:
                bin_edges = torch.linspace(vals_to_bin.min(), vals_to_bin.max(), steps=tokenization_args.bins + 1)
                binned_non_zero_vals = torch.bucketize(vals_to_bin, bin_edges)
                binned_non_zero_vals = torch.clamp(binned_non_zero_vals, min=1)
                binned_tensor = binned_non_zero_vals.float()
                binned_vals.append(binned_tensor)
            else:
                binned_tensor = torch.zeros_like(vals_to_bin, dtype=torch.float)
                binned_vals.append(binned_tensor)
    return binned_vals


def _rank_continuous(vals, tokenization_args):
    """
    Maps gene expression values to evenly spaced ranks in the range [-1, 1], with
    the first (highest-expression) value mapped to 1 and the last to -1.
    """
    if isinstance(vals, np.ndarray):
        vals = torch.tensor(vals)

    if len(vals) > 0:
        ranked_vals = torch.linspace(-1, 1, steps=len(vals)).flip(0)
    else:
        ranked_vals = vals
    return ranked_vals


def _prepare_tokenizer_args(tokenization_args: Union[dict, TokenizationArgs]):
    """
    Prepares and validates tokenization arguments, ensuring reproducibility
    by setting random seeds if specified.
    """
    if isinstance(tokenization_args, dict):
        load_dir = tokenization_args["load_dir"]
        save_dir = tokenization_args["save_dir"]
        token_args_obj = TokenizationArgs(**tokenization_args)
    else:
        load_dir = tokenization_args.load_dir
        save_dir = tokenization_args.save_dir
        token_args_obj = tokenization_args

    if token_args_obj.gene_seed is not None:
        random.seed(token_args_obj.gene_seed)
        np.random.seed(token_args_obj.gene_seed)
        torch.manual_seed(token_args_obj.gene_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(token_args_obj.gene_seed)

    return token_args_obj, load_dir, save_dir


def _check_genes_in_tokenizer(data: ad.AnnData, gene_id_column: str, tokenizer: GeneTokenizer):
    """
    Checks which genes in the dataset are present in the tokenizer's vocabulary and
    raises if fewer than 10% of them are found.
    """
    if gene_id_column == "index":
        gene_index = data.var.index
    else:
        gene_index = data.var[gene_id_column]

    gene_in_vocab = np.where([g in tokenizer.vocab for g in gene_index])[0]
    coding_genes = gene_index[gene_in_vocab]
    ratio = len(gene_in_vocab) / len(data.var)
    if ratio < 0.1:
        raise ValueError(
            f"Only {ratio:.2%} of gene IDs found in tokenizer vocab. "
            "Check gene_id_column or vocab mismatch."
        )
    return gene_in_vocab, coding_genes, ratio


def _build_batch_tensors(X_batch: torch.Tensor, token_array: torch.Tensor, token_args, data=None, obs_indices=None):
    """
    Builds the top-k (or random) gene subset for each row in X_batch (batch_size x num_genes).
    Returns gene_ids, values, labels, and decoder values (the last two are always None here).
    """
    batch_size = X_batch.shape[0]
    seq_tokens = token_args.max_seq_len - 1 if token_args.add_cls else token_args.max_seq_len

    if token_args.random_genes:
        # Sample a random subset of gene columns per cell, then sort that subset by expression.
        random_indices = torch.stack([torch.randperm(X_batch.shape[1])[:seq_tokens] for _ in range(batch_size)])
        random_vals = torch.gather(X_batch, 1, random_indices)
        top_vals, rel_indices = torch.topk(
            random_vals, k=min(seq_tokens, random_vals.shape[1]), largest=True, sorted=True
        )
        # Map the relative indices back to the original gene columns.
        top_indices = torch.gather(random_indices, 1, rel_indices)
    else:
        # Keep the top-k most highly expressed genes per cell.
        top_vals, top_indices = torch.topk(X_batch, k=min(seq_tokens, X_batch.shape[1]), largest=True, sorted=True)

    gene_ids = token_array[top_indices]

    if token_args.add_cls:
        # Prepend the CLS token (with a placeholder value of 1) to every row.
        cls_col = torch.tensor(token_args.cls_token_id).repeat(batch_size, 1)
        gene_ids = torch.cat([cls_col, gene_ids], dim=1)
        ones_col = torch.ones(batch_size, 1, dtype=top_vals.dtype)
        top_vals = torch.cat([ones_col, top_vals], dim=1)

    labels_list = None

    return gene_ids, top_vals, labels_list, None


def tokenize(data_path: str, metadata_path: str, tokenization_args: Union[dict, TokenizationArgs]):
    """
    Tokenizes gene expression data stored in AnnData format and saves the result as a
    Hugging Face Dataset.

    Args:
        data_path (str): Path to the AnnData file containing preprocessed gene expression data.
        metadata_path (str): Path to the metadata file in JSON format.
        tokenization_args (Union[dict, TokenizationArgs]): Configuration for tokenization.
    """
    token_args, load_dir, save_dir = _prepare_tokenizer_args(tokenization_args)

    tokenizer = GeneTokenizer.from_pretrained(token_args.tokenizer_name_or_path)
    if token_args.cls_token_id is None:
        token_args.cls_token_id = tokenizer.cls_token_id

    data = ad.read_h5ad(data_path)

    if "processed" not in data.layers:
        raise ValueError(f"Missing 'processed' layer in {data_path}")

    gene_in_vocab, coding_genes, ratio = _check_genes_in_tokenizer(data, token_args.gene_id_column, tokenizer)
    print(f"{ratio:.2%} of genes found in tokenizer vocab")

    token_array = torch.tensor(tokenizer.encode(coding_genes.tolist(), add_special_tokens=False))

    X_matrix = data.layers["processed"].toarray()

    all_data = {"gene_ids": [], "values": []}

    BATCH_SIZE = 512
    n_obs = data.shape[0]

    for start_idx in tqdm(range(0, n_obs, BATCH_SIZE), desc="Tokenizing in batches"):
        end_idx = min(start_idx + BATCH_SIZE, n_obs)
        obs_indices = np.arange(start_idx, end_idx)

        X_batch = torch.tensor(X_matrix[obs_indices, :][:, gene_in_vocab], dtype=torch.float)
        gene_ids_batch, vals_batch, labels_batch, decoder_vals_batch = _build_batch_tensors(
            X_batch,
            token_array,
            token_args,
            data=None,
            obs_indices=None,
        )

        final_gene_list = []
        final_vals_list = []
        final_labels_list = []
        if "decoder_values" in data.layers:
            final_decoder_vals_list = []

        for row_idx in range(len(gene_ids_batch)):
            g_row = gene_ids_batch[row_idx]
            v_row = vals_batch[row_idx]

            if labels_batch is not None:
                lb_row = labels_batch[row_idx]
            else:
                lb_row = None

            # _build_batch_tensors does not currently produce decoder values, so dec_v_row stays None.
            if decoder_vals_batch is not None:
                dec_v_row = decoder_vals_batch[row_idx]
            else:
                dec_v_row = None

            # Drop non-expressed genes unless zeros are explicitly kept.
            if not token_args.include_zero_genes:
                nonzero_mask = v_row != 0
                g_row = g_row[nonzero_mask]
                v_row = v_row[nonzero_mask]
                if lb_row is not None:
                    lb_row = lb_row[nonzero_mask]
                if dec_v_row is not None:
                    dec_v_row = dec_v_row[nonzero_mask]

            final_gene_list.append(g_row)
            final_vals_list.append(v_row)
            final_labels_list.append(lb_row)
            if "decoder_values" in data.layers:
                final_decoder_vals_list.append(dec_v_row)

        if token_args.bins and token_args.continuous_rank:
            raise ValueError("Should not use bins and continuous_rank simultaneously.")

        if token_args.bins:
            final_vals_list = _bin_values(final_vals_list, token_args, no_sorting=False)
        elif token_args.continuous_rank:
            for i, vals in enumerate(final_vals_list):
                final_vals_list[i] = _rank_continuous(vals, token_args)

        for row_idx in range(len(final_gene_list)):
            all_data["gene_ids"].append(final_gene_list[row_idx].tolist())
            all_data["values"].append(final_vals_list[row_idx].tolist())

    if token_args.label_column:
        all_data["labels"] = data.obs[token_args.label_column].cat.codes.values.tolist()

    if token_args.bio_annotations:
        with open(token_args.disease_mapping) as f:
            disease_mapping = json.load(f)
        with open(token_args.tissue_mapping) as f:
            tissue_mapping = json.load(f)
        with open(token_args.cell_mapping) as f:
            cell_mapping = json.load(f)
        with open(token_args.sex_mapping) as f:
            sex_mapping = json.load(f)

        if "disease" not in data.obs.columns:
            data.obs["disease"] = "normal"
        if "tissue" not in data.obs.columns:
            data.obs["tissue"] = "cultured cell"
        if "sex" not in data.obs.columns:
            data.obs["sex"] = "unknown"
        if "cell_type" not in data.obs.columns:
            data.obs["cell_type"] = "unknown"

        # Map raw obs annotations to vocabulary terms, then encode them as token IDs.
        mapped_diseases = [disease_mapping[k] for k in data.obs["disease"].tolist()]
        mapped_tissues = [tissue_mapping[k] for k in data.obs["tissue"].tolist()]
        mapped_cell_types = [cell_mapping[k] for k in data.obs["cell_type"].tolist()]
        mapped_sexes = [sex_mapping[k] for k in data.obs["sex"].tolist()]

        all_data["disease"] = tokenizer.encode(mapped_diseases, add_special_tokens=False)
        all_data["tissue"] = tokenizer.encode(mapped_tissues, add_special_tokens=False)
        all_data["cell_type"] = tokenizer.encode(mapped_cell_types, add_special_tokens=False)
        all_data["sex"] = tokenizer.encode(mapped_sexes, add_special_tokens=False)

        if token_args.add_disease_annotation:
            all_data["labels"] = all_data["disease"]

    del data
    gc.collect()

    dataset = Dataset.from_dict(all_data)
    num_samples = len(dataset)
    if token_args.max_shard_samples:
        num_shards = num_samples // min(token_args.max_shard_samples, num_samples)
    else:
        num_shards = 1

    relative_data_path = os.path.relpath(data_path, load_dir)
    relative_metadata_path = os.path.relpath(metadata_path, load_dir)

    no_extension_data_path = os.path.splitext(relative_data_path)[0]

    save_tokenized_data_path = os.path.join(save_dir, no_extension_data_path)
    save_metadata_path = os.path.join(save_dir, relative_metadata_path)

    dataset.save_to_disk(save_tokenized_data_path, num_shards=num_shards)
    shutil.copy(metadata_path, save_metadata_path)


def shard_hf_dataset(data_path: str, metadata_path: str, tokenization_args: Union[dict, TokenizationArgs]):
    """
    Shards a Hugging Face Dataset into smaller chunks for efficient storage and processing.
    """
    if isinstance(tokenization_args, dict):
        load_dir = tokenization_args["load_dir"]
        save_dir = tokenization_args["save_dir"]
        token_args_obj = TokenizationArgs(**tokenization_args)
    else:
        load_dir = tokenization_args.load_dir
        save_dir = tokenization_args.save_dir
        token_args_obj = tokenization_args

    all_data = load_from_disk(data_path)
    num_samples = len(all_data)
    if token_args_obj.max_shard_samples:
        num_shards = num_samples // min(token_args_obj.max_shard_samples, num_samples)
    else:
        num_shards = 1

    save_tokenized_data_path = data_path.replace(load_dir, save_dir)
    save_metadata_path = metadata_path.replace(load_dir, save_dir)
    all_data.save_to_disk(save_tokenized_data_path, num_shards=num_shards)
    shutil.copy(metadata_path, save_metadata_path)


if __name__ == "__main__":
    parser = ArgumentParser(description="Tokenize an AnnData file for downstream ML tasks.")
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to the .h5ad file containing the preprocessed scRNA-seq data.",
    )
    parser.add_argument(
        "--metadata_path",
        type=str,
        required=True,
        help="Path to the JSON file containing metadata.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to the JSON file specifying the tokenization hyperparameters.",
    )
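
    # A minimal, hypothetical config for --config_path; the full set of supported keys
    # is defined by TokenizationArgs (e.g. gene_id_column, max_seq_len, add_cls, bins,
    # continuous_rank, include_zero_genes, random_genes, bio_annotations, label_column,
    # max_shard_samples):
    #
    #   {
    #       "load_dir": "/data/preprocessed",
    #       "save_dir": "/data/tokenized",
    #       "tokenizer_name_or_path": "/models/gene_tokenizer",
    #       "max_seq_len": 2048,
    #       "include_zero_genes": false
    #   }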

    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        tokenization_args = json.load(f)

    tokenize(
        data_path=args.data_path,
        metadata_path=args.metadata_path,
        tokenization_args=tokenization_args,
    )