# Virtual Cell – Patient Model
A patient-level disease classification model trained on single-cell RNA-seq data. Given a matrix of gene expression profiles (one row per cell), the model produces a disease-category prediction for the patient.
## Model architecture

```
input [batch, num_cells, 18301 genes]
  → MLP cell embedder      → [batch, num_cells, 512]
  → Attention aggregator   → [batch, 512]
  → Dropout + Linear head  → [batch, 10 classes]
```
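The shapes in the diagram can be sketched as a small PyTorch module. This is an illustrative reimplementation only, not the shipped `modeling_virtual_cell.py`; the layer sizes follow the diagram, while the embedder depth, activation, attention scoring, and dropout rate are assumptions:

```python
import torch
import torch.nn as nn

class PatientModelSketch(nn.Module):
    """Illustrative sketch of the architecture diagram above (not the real model)."""

    def __init__(self, num_genes=18_301, dim=512, num_classes=10):
        super().__init__()
        # MLP cell embedder: projects each cell's expression vector to 512-d
        self.embedder = nn.Sequential(
            nn.Linear(num_genes, dim), nn.GELU(), nn.Linear(dim, dim)
        )
        # Attention aggregator: scores each cell, then pools cells into one vector
        self.attn_score = nn.Linear(dim, 1)
        self.head = nn.Sequential(nn.Dropout(0.1), nn.Linear(dim, num_classes))

    def forward(self, x, mask=None):              # x: [batch, num_cells, num_genes]
        h = self.embedder(x)                      # [batch, num_cells, 512]
        scores = self.attn_score(h).squeeze(-1)   # [batch, num_cells]
        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))
        w = scores.softmax(-1).unsqueeze(-1)      # attention weights over cells
        pooled = (w * h).sum(dim=1)               # [batch, 512]
        return self.head(pooled)                  # [batch, num_classes]

logits = PatientModelSketch()(torch.randn(2, 500, 18_301))
print(logits.shape)  # torch.Size([2, 10])
```

The key design point is the attention pooling: the patient prediction is permutation-invariant over cells, so the order of rows in the expression matrix does not matter.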
## Pretrained classification task
The pretrained checkpoint classifies patients into 10 disease categories:
oncological, immune_inflammatory, neurological, metabolic_vascular,
gastrointestinal, respiratory, epithelial_barrier, sensory_specialized,
healthy_control, other.
The pretrained embedder generalizes well to other classification tasks. Common fine-tuning scenarios include binary sick vs. healthy or treatment response prediction; see Fine-tuning below.
## Installation

All repository files are required to run `train.py`. Download them all (or clone the repo) and install dependencies:

```shell
pip install -r requirements.txt
```

`wandb` is optional and only needed when training with `--wandb_project`.
> **Tip:** `train.py` uses multiple workers for data loading. A machine with at least 8 CPU cores is recommended for good throughput; set `--num_workers` to match your core count.
## Quick start

### Verify the model loads

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-patient",
    trust_remote_code=True,
).eval()

x = torch.randn(1, 500, 18_301)  # [batch, num_cells, num_genes]
with torch.no_grad():
    out = model(input_ids=x)
print(out.logits.shape)  # [1, 10]
print(out.logits.softmax(-1))
```
### Inference on real data

```python
import torch
from datasets import load_dataset
from transformers import AutoModel

ds = load_dataset("ConvergeBio/virtual-cell-patient-example", split="validation")
model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-patient",
    trust_remote_code=True,
).eval()

sample = torch.tensor(ds[0]["input_ids"]).unsqueeze(0)  # [1, 500, 18_301]
with torch.no_grad():
    out = model(input_ids=sample)
print(out.logits.softmax(-1))
```
> **Note:** `ConvergeBio/virtual-cell-patient-example` is a minimal sample dataset intended only to verify the data format and run a quick end-to-end check. It contains a small number of patients and is not representative of a real training or evaluation distribution. Metrics produced from inference or training on this dataset should not be interpreted.
## Preparing your data

`train.py` expects a HuggingFace dataset with `train` (and optionally `validation`) splits. Each row represents one cell sample for a patient, with the following required columns:
| Column | Shape | Type | Description |
|---|---|---|---|
| `input_ids` | `[500, 18301]` | float32 | Log-normalized gene expression matrix, aligned to `gene_names.txt` |
| `attention_mask` | `[500]` | bool | Cell mask (all ones for fixed cell count) |
| `labels` | scalar | int | Class index |
| `entity_id` | scalar | int | Patient identifier; groups augmented views of the same patient |
Augmentation is strongly encouraged: multiple independent random cell samples from the same patient should be included as separate rows sharing the same `entity_id`. At inference, the model averages softmax probabilities across views for a more robust prediction. A factor of 5 augmentations per patient is a good default.
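The view-averaging step can be sketched as follows. This is a hypothetical helper, not part of the repository: it assumes you already have one softmax probability row per augmented view, paired with its `entity_id`:

```python
from collections import defaultdict
import torch

def average_views(probs, entity_ids):
    """Average softmax probabilities across augmented views of each patient.

    probs:      [num_views_total, num_classes] tensor (one row per view)
    entity_ids: matching list of patient identifiers
    Returns a dict mapping entity_id -> averaged probability vector.
    """
    groups = defaultdict(list)
    for p, eid in zip(probs, entity_ids):
        groups[eid].append(p)
    return {eid: torch.stack(ps).mean(dim=0) for eid, ps in groups.items()}

# Toy example: patient 7 has two augmented views, patient 3 has one
probs = torch.tensor([[0.8, 0.2], [0.6, 0.4], [0.1, 0.9]])
avg = average_views(probs, [7, 7, 3])
print(avg[7])  # tensor([0.7000, 0.3000])
```

Averaging probabilities (rather than logits) matches the behavior described above and keeps each view's contribution on the same scale.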
For a guide on building this dataset from raw scRNA-seq (h5ad) files, see the example dataset.
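The per-patient sampling step of that pipeline can be sketched in plain NumPy. This is an assumption-laden sketch, not the repository's preprocessing code: it presumes you already have one patient's log-normalized expression matrix with columns aligned to `gene_names.txt`:

```python
import numpy as np

def sample_views(expr, num_views=5, cells_per_view=500, seed=0):
    """Draw independent random cell samples (augmented views) for one patient.

    expr: [num_cells, num_genes] log-normalized expression matrix.
    Returns num_views arrays of shape [cells_per_view, num_genes]; each becomes
    one dataset row, all sharing the patient's entity_id.
    """
    rng = np.random.default_rng(seed)
    views = []
    for _ in range(num_views):
        # Sample with replacement only if the patient has fewer cells than needed
        idx = rng.choice(expr.shape[0], size=cells_per_view,
                         replace=expr.shape[0] < cells_per_view)
        views.append(expr[idx].astype(np.float32))
    return views

views = sample_views(np.zeros((1200, 18_301), dtype=np.float32))
print(len(views), views[0].shape)  # 5 (500, 18301)
```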
## Fine-tuning
Binary classification (e.g. sick vs. healthy):
```shell
python train.py \
    --dataset_path <your_dataset> \
    --num_classes 2 \
    --freeze_embedder \
    --output_dir ./my_binary_model
```
`--freeze_embedder` keeps the pretrained cell embedder frozen and only trains the new head; recommended when your dataset is small.
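The effect of `--freeze_embedder` can be illustrated in plain PyTorch. The attribute names here are hypothetical; see `train.py` for the actual implementation:

```python
import torch.nn as nn

def freeze_embedder(model: nn.Module):
    """Disable gradients for the cell embedder so only the head trains."""
    for p in model.embedder.parameters():  # 'embedder' is an assumed attribute name
        p.requires_grad_(False)

# Toy model with the same two-part structure (embedder + classification head)
model = nn.Module()
model.embedder = nn.Linear(8, 4)
model.head = nn.Linear(4, 2)
freeze_embedder(model)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['head.weight', 'head.bias']
```

With the embedder frozen, the optimizer only updates the head's parameters, which reduces overfitting risk and memory use on small datasets.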
Multi-class fine-tuning on a different label set:
```shell
python train.py \
    --dataset_path <your_dataset> \
    --num_classes <N> \
    --output_dir ./my_finetuned_model \
    --num_train_epochs 15 \
    --learning_rate 1e-4
```
## Training from scratch

```shell
python train.py \
    --dataset_path <your_dataset> \
    --from_scratch \
    --output_dir ./my_scratch_model
```
## Repository contents

| File | Description |
|---|---|
| `modeling_virtual_cell.py` | Full model implementation |
| `config.json` | Architecture config |
| `gene_names.txt` | Ordered list of 18,301 HGNC gene symbols |
| `train.py` | Fine-tuning / training script |
| `requirements.txt` | Python dependencies |
| `model.safetensors` | Pretrained weights |
## Citation

If you use this model, please cite:

```bibtex
@article{convergecell2026,
  author = {ConvergeBio},
  title  = {ConvergeCELL: An end-to-end platform from patient transcriptomics to therapeutic hypotheses},
  year   = {2026},
  note   = {Preprint available on bioRxiv},
}
```
The model architecture and data processing approach were inspired by:
```bibtex
@article{liu2026pascient,
  author  = {Liu, T. and De Brouwer, E. and Verma, A. and Missarova, A. and
             Kuo, T. and others},
  title   = {Learning multi-cellular representations of single-cell transcriptomics
             data enables characterization of patient-level disease states},
  journal = {Cell Systems},
  volume  = {17},
  pages   = {101570},
  year    = {2026},
}
```
## License