| import os |
| import sys |
| os.environ["CUDA_VISIBLE_DEVICES"] = "6" |
|
|
| import torch |
| import wandb |
|
|
| from entangledcell_module_unseen import EntangledNetTrainCellUnseen |
| from entangledcell_module_three import EntangledNetTrainCellThree |
|
|
| |
| from dataloaders.three_branch_data import ThreeBranchTahoeDataModule |
| from dataloaders.clonidine_v2_data import ClonidineV2DataModule |
|
|
| from geo_metrics.metric_factory import DataManifoldMetric |
| from pytorch_lightning.loggers import WandbLogger |
| from pytorch_lightning import Trainer |
|
|
| from torchcfm.optimal_transport import OTPlanSampler |
| from parser import parse_args |
| from train_utils import load_config, merge_config |
| from bias import BiasForceTransformer, BiasForceTransformerNoVel |
|
|
| def main(): |
| |
| args = parse_args() |
| if args.config_path: |
| config = load_config(args.config_path) |
| args = merge_config(args, config) |
| |
| args.training = True |
| args.save_dir = args.save_dir |
| |
| |
| positions_dir = f"{args.save_dir}/positions" |
| if not os.path.exists(positions_dir): |
| os.makedirs(positions_dir) |
| |
| wandb.init(project="entangled-cell", |
| config=args, |
| name=args.run_name) |
| |
| torch.manual_seed(args.seed) |
| |
| ot_sampler = ( |
| OTPlanSampler(method=args.optimal_transport_method) |
| if args.optimal_transport_method != "None" |
| else None |
| ) |
| |
| |
| if args.data_name == "trametinib": |
| datamodule = ThreeBranchTahoeDataModule(args=args) |
| else: |
| datamodule = ClonidineV2DataModule(args=args) |
| |
| |
| data_manifold_metric = DataManifoldMetric( |
| args=args, |
| skipped_time_points=[], |
| datamodule=datamodule, |
| ) |
| |
| if args.vel_conditioned: |
| bias_net = BiasForceTransformer(args) |
| else: |
| print("Using no velocity conditioned model") |
| bias_net = BiasForceTransformerNoVel(args) |
| |
| timepoint_data = datamodule.get_timepoint_data() |
| |
| if args.data_name == "trametinib": |
| entangled_train = EntangledNetTrainCellThree(args=args, |
| bias_net=bias_net, |
| data_manifold_metric=data_manifold_metric, |
| timepoint_data=timepoint_data, |
| ot_sampler=ot_sampler, |
| vel_conditioned=args.vel_conditioned) |
| else: |
| entangled_train = EntangledNetTrainCellUnseen(args=args, |
| bias_net=bias_net, |
| data_manifold_metric=data_manifold_metric, |
| timepoint_data=timepoint_data, |
| ot_sampler=ot_sampler, |
| vel_conditioned=args.vel_conditioned) |
| |
| wandb_logger = WandbLogger() |
| |
| trainer = Trainer( |
| max_epochs=args.num_rollouts, |
| logger=wandb_logger, |
| num_sanity_val_steps=0, |
| default_root_dir=args.root_dir, |
| gradient_clip_val=None, |
| devices=[0], |
| ) |
| |
| trainer.fit( |
| entangled_train, datamodule=datamodule |
| ) |
| trainer.test(entangled_train, datamodule=datamodule) |
|
|
| if __name__ == "__main__": |
| main() |