| | import hashlib |
| | from pathlib import Path |
| |
|
| | import anndata as ad |
| | import numpy as np |
| | import pytest |
| | import torch |
| |
|
| | from teddy.data_processing.tokenization.tokenization import ( |
| | _bin_values, |
| | _build_batch_tensors, |
| | _check_genes_in_tokenizer, |
| | _prepare_tokenizer_args, |
| | _rank_continuous, |
| | tokenize, |
| | ) |
| |
|
| | |
| | |
| | |
def test_bin_values_no_sorting():
    """
    Test _bin_values with no_sorting=True on a simple array.

    With bins=2 and include_zero_genes=True, the four values [0, 1, 2, 3]
    are expected to map to bin ids [1, 1, 2, 2].
    """
    vals_list = [np.array([0, 1, 2, 3], dtype=float)]

    class TokenArgs:
        # Minimal stand-in carrying only the attributes this test sets;
        # assumes _bin_values reads nothing else from it — TODO confirm.
        include_zero_genes = True
        bins = 2

    token_args = TokenArgs()

    binned = _bin_values(vals_list, token_args, no_sorting=True)
    # One input array in, one binned result out.
    assert len(binned) == 1

    result = binned[0]
    expected = np.array([1, 1, 2, 2], dtype=float)
    # assert_array_equal also checks shape, so a wrongly shaped result cannot
    # pass via broadcasting, and on failure it reports the mismatching items.
    np.testing.assert_array_equal(result.numpy(), expected)
| |
|
| |
|
def test_bin_values_with_sorting():
    """
    Exercise _bin_values on the no_sorting=False path (the 'positional chunk'
    approach) and check that exactly one result comes back for the single
    input array, with one entry per input value.
    """

    class TokenArgs:
        bins = 2
        include_zero_genes = True

    descending_vals = np.array([5, 4, 3, 0], dtype=float)
    outputs = _bin_values([descending_vals], TokenArgs(), no_sorting=False)
    assert len(outputs) == 1

    binned_row = outputs[0]
    n_out = len(binned_row)
    assert n_out == 4, f"Expected 4 bins in the result, got {n_out}"
| |
|
| |
|
def test_rank_continuous_normal():
    """
    Should produce a descending linear scale across the entire array, running
    from about +1 down to about -1.

    The input is already sorted in descending order with distinct values, so
    the ranked output is expected to be strictly decreasing and to bottom out
    at -1.
    """
    arr = np.array([3, 2, 1, 0], dtype=float)

    class TokenArgs:
        # Empty args object; presumably _rank_continuous needs no attributes
        # from it for this input — TODO confirm.
        pass

    token_args = TokenArgs()
    ranked = _rank_continuous(arr, token_args)

    assert ranked.shape[0] == 4
    # Check every adjacent pair, not just the first two entries.
    assert bool((ranked[:-1] > ranked[1:]).all()), f"Not strictly descending: {ranked}"
    # torch.isclose requires both inputs to share a dtype, so build the
    # reference value in ranked's dtype (e.g. float64 from a numpy input).
    target = torch.tensor(-1.0, dtype=ranked.dtype)
    assert torch.isclose(ranked.min(), target, atol=1e-5)
| |
|
| |
|
| |
|
def test_prepare_tokenizer_args_dict():
    """
    Test that a dict tokenization_args is converted into an args object and
    that load_dir / save_dir are returned alongside it.

    Asserts that both directories are passed through unchanged and that
    gene_seed is carried onto the returned object.

    NOTE(review): the claim that random seeds are set when gene_seed is not
    None is NOT asserted here — add a check (e.g. on the RNG state) if that
    side effect is part of the contract.
    """
    args_dict = {
        "load_dir": "/mock/load",
        "save_dir": "/mock/save",
        "gene_seed": 42,
        "tokenizer_name_or_path": "some/tokenizer",
    }
    token_args_obj, load_dir, save_dir = _prepare_tokenizer_args(args_dict)

    assert load_dir == "/mock/load"
    assert save_dir == "/mock/save"
    assert token_args_obj.gene_seed == 42
| |
|
| |
|
def test_check_genes_in_tokenizer():
    """
    Test _check_genes_in_tokenizer with a minimal mock GeneTokenizer vocab.

    Two of the four genes (G2, G3) are present in the mock vocab, so two
    matches and a ratio of 0.5 are expected.
    """
    # Relies on the module-level `anndata as ad` import. The real
    # GeneTokenizer import was never used; the mock below replaces it.
    data = ad.AnnData(X=np.zeros((10, 4)))
    data.var["gene_name"] = ["G1", "G2", "G3", "G4"]

    class MockGeneTokenizer:
        # Stand-in exposing only `.vocab`; assumes _check_genes_in_tokenizer
        # reads nothing else from the tokenizer — TODO confirm.
        def __init__(self, vocab):
            self.vocab = vocab

    mock_tokenizer = MockGeneTokenizer({"G2": 1, "G3": 2})

    gene_in_vocab, _coding_genes, ratio = _check_genes_in_tokenizer(
        data,
        gene_id_column="gene_name",
        tokenizer=mock_tokenizer,
    )

    assert len(gene_in_vocab) == 2
    assert ratio == 0.5
| |
|
| |
|
def test_build_batch_tensors():
    """
    Run _build_batch_tensors on a 2x3 value batch (one token id per column)
    with random_genes disabled and max_seq_len equal to the column count.

    Checks the returned shapes: two rows of gene ids, three entries per row
    in both gene_ids and vals, and no labels or decoder values.
    """

    class TokenArgs:
        random_genes = False
        add_cls = False
        max_seq_len = 3

    counts = torch.tensor([[5, 0, 1], [0, 2, 2]], dtype=torch.float)
    tokens = torch.tensor([10, 11, 12])

    outputs = _build_batch_tensors(counts, tokens, TokenArgs())
    gene_ids, vals, labels_list, decoder_vals = outputs

    assert len(gene_ids) == 2
    assert labels_list is None

    first_ids, first_vals = gene_ids[0], vals[0]
    assert len(first_ids) == 3
    assert len(first_vals) == 3
    assert decoder_vals is None
| |
|
| |
|