Virtual Cell: Patient Model

A patient-level disease classification model trained on single-cell RNA-seq data. Given a matrix of gene expression profiles (one row per cell), the model produces a disease-category prediction for the patient.

Model architecture

input  [batch, num_cells, 18301 genes]
  → MLP cell embedder         → [batch, num_cells, 512]
  → Attention aggregator      → [batch, 512]
  → Dropout + Linear head     → [batch, 10 classes]
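The data flow above can be illustrated with a small NumPy sketch. It is not the repository's implementation; modeling_virtual_cell.py is authoritative. The sketch assumes a one-hidden-layer ReLU MLP as the cell embedder and a single-query, tanh-scored attention pooling (attention-MIL style) as the aggregator. All function names, the hidden and attention widths, and the random initialization are hypothetical. Masked cells get zero attention weight, and dropout is omitted because it is the identity at inference.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def init_params(num_genes, hidden, embed_dim=512, att_dim=128, num_classes=10, seed=0):
    """Random weights for the hypothetical sketch (not the pretrained checkpoint)."""
    rng = np.random.default_rng(seed)

    def w(*shape):
        return rng.normal(0.0, 0.02, size=shape).astype(np.float32)

    return {
        "w1": w(num_genes, hidden), "b1": np.zeros(hidden, np.float32),
        "w2": w(hidden, embed_dim), "b2": np.zeros(embed_dim, np.float32),
        "w_att": w(embed_dim, att_dim), "v_att": w(att_dim),
        "w_head": w(embed_dim, num_classes), "b_head": np.zeros(num_classes, np.float32),
    }

def forward_sketch(x, mask, params):
    """x: [batch, num_cells, num_genes]; mask: [batch, num_cells] bool."""
    # MLP cell embedder (assumed depth): each cell is embedded independently.
    h = np.maximum(x @ params["w1"] + params["b1"], 0.0)       # [B, C, hidden]
    z = h @ params["w2"] + params["b2"]                         # [B, C, 512]
    # Attention aggregator: one score per cell, softmax over cells, masked cells excluded.
    scores = np.tanh(z @ params["w_att"]) @ params["v_att"]     # [B, C]
    scores = np.where(mask, scores, -np.inf)
    weights = softmax(scores, axis=1)                           # [B, C], rows sum to 1
    patient = (weights[..., None] * z).sum(axis=1)              # [B, 512]
    # Linear head (dropout is a no-op at inference).
    return patient @ params["w_head"] + params["b_head"]       # [B, num_classes]

params = init_params(num_genes=18_301, hidden=256)
x = np.random.default_rng(1).random((1, 500, 18_301), dtype=np.float32)
mask = np.ones((1, 500), dtype=bool)
print(forward_sketch(x, mask, params).shape)  # (1, 10)
```

Because attention weights are computed per patient over a variable set of cells, the same head works for any number of valid cells, which is why the attention_mask can mark padding.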

Pretrained classification task

The pretrained checkpoint classifies patients into 10 disease categories: oncological, immune_inflammatory, neurological, metabolic_vascular, gastrointestinal, respiratory, epithelial_barrier, sensory_specialized, healthy_control, other.
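To turn the 10 logits into readable predictions, the class names need to be paired with logit indices. The sketch below assumes that index i corresponds to the i-th category in the order listed above; this ordering is an assumption, so check config.json (for example an id2label entry, if present) before relying on it. top_k_predictions is a hypothetical helper.

```python
import numpy as np

# Assumption: logit index i maps to the i-th category as listed in this README.
CLASS_NAMES = [
    "oncological", "immune_inflammatory", "neurological", "metabolic_vascular",
    "gastrointestinal", "respiratory", "epithelial_barrier", "sensory_specialized",
    "healthy_control", "other",
]

def top_k_predictions(logits, k=3):
    """Return the k most probable (class_name, probability) pairs for one patient."""
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())  # stable softmax
    probs /= probs.sum()
    order = np.argsort(probs)[::-1][:k]    # indices sorted by descending probability
    return [(CLASS_NAMES[i], float(probs[i])) for i in order]
```

With the model from the quick start, this could be called as top_k_predictions(out.logits[0].numpy()).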

The pretrained embedder generalizes well to other classification tasks. Common fine-tuning scenarios include binary sick-vs-healthy classification and treatment-response prediction; see Fine-tuning below.

Installation

All repository files are required to run train.py. Download them all (or clone the repo) and install dependencies:

pip install -r requirements.txt

wandb is optional and only needed when training with --wandb_project.

Tip: train.py uses multiple worker processes for data loading. A machine with at least 8 CPU cores is recommended for good throughput; set --num_workers to match your core count.
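One way to pick a value for --num_workers is to count the cores the process is actually allowed to use. The sketch below prefers os.sched_getaffinity (Linux, respects CPU-affinity restrictions) and falls back to os.cpu_count elsewhere. The helper name is hypothetical; the flags are the ones used in this README, and the command is only printed, not executed.

```python
import os

def usable_cpu_count():
    """Cores available to this process (affinity-aware on Linux)."""
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1  # cpu_count() may return None

num_workers = usable_cpu_count()
cmd = [
    "python", "train.py",
    "--dataset_path", "<your_dataset>",
    "--num_workers", str(num_workers),
    "--output_dir", "./my_model",
]
print(" ".join(cmd))
```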

Quick start

Verify the model loads

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-patient",
    trust_remote_code=True,
).eval()

x = torch.randn(1, 500, 18_301)  # [batch, num_cells, num_genes]
with torch.no_grad():
    out = model(input_ids=x)

print(out.logits.shape)        # [1, 10]
print(out.logits.softmax(-1))

Inference on real data

from datasets import load_dataset
import torch
from transformers import AutoModel

ds = load_dataset("ConvergeBio/virtual-cell-patient-example", split="validation")

model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-patient",
    trust_remote_code=True,
).eval()

sample = torch.tensor(ds[0]["input_ids"]).unsqueeze(0)  # [1, 500, 18_301]
with torch.no_grad():
    out = model(input_ids=sample)

print(out.logits.softmax(-1))

Note: ConvergeBio/virtual-cell-patient-example is a minimal sample dataset intended only to verify the data format and run a quick end-to-end check. It contains a small number of patients and is not representative of a real training or evaluation distribution. Metrics produced from inference or training on this dataset should not be interpreted.

Preparing your data

train.py expects a HuggingFace dataset with train (and optionally validation) splits. Each row is one random sample of cells drawn from a single patient, with the following required columns:

Column          Shape          Type     Description
input_ids       [500, 18301]   float32  Log-normalized gene expression matrix, aligned to gene_names.txt
attention_mask  [500]          bool     Cell mask (all ones for fixed cell count)
labels          scalar         int      Class index
entity_id       scalar         int      Patient identifier; groups augmented views of the same patient
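Before training, it can help to check that each row matches the column spec in the table above. validate_row is a hypothetical helper sketched under that spec only; it checks shapes, dtypes, and that expression values are finite, and raises ValueError on the first mismatch.

```python
import numpy as np

NUM_CELLS, NUM_GENES = 500, 18_301

def validate_row(row):
    """Raise ValueError if a dataset row does not match the documented columns."""
    x = np.asarray(row["input_ids"], dtype=np.float32)
    if x.shape != (NUM_CELLS, NUM_GENES):
        raise ValueError(f"input_ids has shape {x.shape}, expected {(NUM_CELLS, NUM_GENES)}")
    if not np.isfinite(x).all():
        raise ValueError("input_ids contains NaN or inf")
    mask = np.asarray(row["attention_mask"])
    if mask.shape != (NUM_CELLS,) or mask.dtype != bool:
        raise ValueError(f"attention_mask must be a bool vector of length {NUM_CELLS}")
    for key in ("labels", "entity_id"):
        value = row[key]
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, (int, np.integer)):
            raise ValueError(f"{key} must be a scalar int")
```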

Augmentation is strongly encouraged: include multiple independent random cell samples from the same patient as separate rows that share the same entity_id. At inference, softmax probabilities are averaged across a patient's views for a more robust prediction. Five augmented views per patient is a good default.
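The augmentation and averaging described above can be sketched as two hypothetical helpers. sample_views draws independent random subsets of cells (without replacement within a view) from one patient's [n_cells, n_genes] matrix. The handling of patients with fewer than 500 cells is an assumption: they are zero-padded, and the padding is marked False in attention_mask. aggregate_by_entity averages per-view probabilities into one prediction per patient.

```python
from collections import defaultdict

import numpy as np

def sample_views(cells, num_views=5, num_cells=500, seed=0):
    """Return num_views (input_ids, attention_mask) pairs for one patient."""
    rng = np.random.default_rng(seed)
    n, g = cells.shape
    take = min(n, num_cells)
    views = []
    for _ in range(num_views):
        idx = rng.choice(n, size=take, replace=False)  # independent draw per view
        x = np.zeros((num_cells, g), dtype=np.float32)
        x[:take] = cells[idx]
        mask = np.zeros(num_cells, dtype=bool)
        mask[:take] = True  # assumption: padded rows are masked out
        views.append((x, mask))
    return views

def aggregate_by_entity(entity_ids, probs):
    """Average per-view softmax probabilities, grouped by patient (entity_id)."""
    groups = defaultdict(list)
    for eid, p in zip(entity_ids, probs):
        groups[eid].append(np.asarray(p))
    return {eid: np.mean(ps, axis=0) for eid, ps in groups.items()}
```

Each view would become one dataset row carrying the patient's labels and entity_id; the argmax of each averaged vector gives the patient-level prediction.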

For a guide on building this dataset from raw scRNA-seq (h5ad) files, see the example dataset (ConvergeBio/virtual-cell-patient-example).
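The input_ids column must be log-normalized and aligned to gene_names.txt. The exact normalization used by the example dataset is not stated here, so this sketch assumes a common scheme: scale each cell to target_sum total counts, then apply log1p. Genes in gene_names.txt that are missing from the input are filled with zeros (also an assumption). align_and_lognorm is a hypothetical helper operating on a dense count matrix.

```python
import numpy as np

def align_and_lognorm(counts, gene_symbols, reference_genes, target_sum=1e4):
    """Log-normalize raw counts and reorder columns to match reference_genes.

    counts: dense [n_cells, n_genes] raw count matrix
    gene_symbols: column names of counts (e.g. HGNC symbols)
    reference_genes: ordered target symbols, e.g. the lines of gene_names.txt
    """
    counts = np.asarray(counts, dtype=np.float64)
    # Per-cell library-size scaling followed by log1p (assumed scheme).
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # empty cells stay all-zero instead of dividing by zero
    normed = np.log1p(counts / totals * target_sum)

    # Find each reference gene's source column; -1 means absent from the input.
    col = {g: i for i, g in enumerate(gene_symbols)}
    src = np.array([col.get(g, -1) for g in reference_genes], dtype=np.intp)
    present = src >= 0
    out = np.zeros((counts.shape[0], len(reference_genes)), dtype=np.float32)
    out[:, present] = normed[:, src[present]]
    return out
```

Assuming gene_names.txt holds one symbol per line, reference_genes can be read with [line.strip() for line in open("gene_names.txt") if line.strip()]. With anndata, counts and symbols might come from anndata.read_h5ad(path).X (call .toarray() if it is sparse) and .var_names; check whether X already holds normalized values rather than raw counts.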

Fine-tuning

Binary classification (e.g. sick vs. healthy):

python train.py \
  --dataset_path <your_dataset> \
  --num_classes 2 \
  --freeze_embedder \
  --output_dir ./my_binary_model

--freeze_embedder keeps the pretrained cell embedder frozen and trains only the new head; this is recommended when your dataset is small.

Multi-class fine-tuning on a different label set:

python train.py \
  --dataset_path <your_dataset> \
  --num_classes <N> \
  --output_dir ./my_finetuned_model \
  --num_train_epochs 15 \
  --learning_rate 1e-4
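The labels column must hold integer class indices, and --num_classes must equal the number of distinct classes. If your labels start out as strings, a hypothetical helper like encode_labels can build a contiguous 0..N-1 mapping. Compute the mapping once over all splits and reuse it, so train and validation share the same indices.

```python
def encode_labels(raw_labels):
    """Map label values (e.g. strings) to contiguous ints 0..N-1.

    Returns (encoded_labels, label2id, N); pass N as --num_classes.
    """
    classes = sorted(set(raw_labels))  # sorted for a deterministic mapping
    label2id = {label: i for i, label in enumerate(classes)}
    return [label2id[x] for x in raw_labels], label2id, len(classes)

encoded, label2id, n = encode_labels(["responder", "non_responder", "responder"])
print(encoded, label2id, n)  # [1, 0, 1] {'non_responder': 0, 'responder': 1} 2
```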

Training from scratch

python train.py \
  --dataset_path <your_dataset> \
  --from_scratch \
  --output_dir ./my_scratch_model

Repository contents

File                       Description
modeling_virtual_cell.py   Full model implementation
config.json                Architecture config
gene_names.txt             Ordered list of 18,301 HGNC gene symbols
train.py                   Fine-tuning / training script
requirements.txt           Python dependencies
model.safetensors          Pretrained weights

Citation

If you use this model, please cite:

@article{convergecell2026,
  author    = {ConvergeBio},
  title     = {ConvergeCELL: An end-to-end platform from patient transcriptomics to therapeutic hypotheses},
  year      = {2026},
  note      = {Preprint available on bioRxiv},
}

The model architecture and data processing approach were inspired by:

@article{liu2026pascient,
  author    = {Liu, T. and De Brouwer, E. and Verma, A. and Missarova, A. and
               Kuo, T. and others},
  title     = {Learning multi-cellular representations of single-cell transcriptomics
               data enables characterization of patient-level disease states},
  journal   = {Cell Systems},
  volume    = {17},
  pages     = {101570},
  year      = {2026},
}

License

Apache 2.0; see LICENSE and NOTICE.

Model size: 80M params (Safetensors, F32)