| from v2a_model import V2AModel |
| from lib.datasets import create_dataset |
| import hydra |
| import pytorch_lightning as pl |
| from pytorch_lightning.loggers import WandbLogger |
| import os |
| import glob |
|
|
def _latest_checkpoint():
    """Return the path of the lexically-last checkpoint in checkpoints/.

    Checkpoint filenames use a zero-padded epoch ({epoch:04d}-...), so a plain
    lexical sort puts the most recent epoch last.  Raises FileNotFoundError
    with a clear message instead of a bare IndexError when none exist.
    """
    checkpoints = sorted(glob.glob("checkpoints/*.ckpt"))
    if not checkpoints:
        raise FileNotFoundError(
            "No checkpoint found in checkpoints/ to resume from.")
    return checkpoints[-1]


def _build_trainer(checkpoint_callback, logger, max_epochs):
    """Construct a single-GPU Trainer with the project's fixed settings.

    Only max_epochs varies between the training stages; everything else is
    shared, so it is centralized here.
    """
    return pl.Trainer(
        gpus=1,
        accelerator="gpu",
        callbacks=[checkpoint_callback],
        max_epochs=max_epochs,
        check_val_every_n_epoch=50,
        logger=logger,
        log_every_n_steps=1,
        num_sanity_val_steps=0,
    )


def _build_model_and_data(opt):
    """Create a fresh model plus train/valid datasets from the current config.

    Rebuilt at every stage because incremental sampling mutates
    opt.dataset.train between stages.
    """
    model = V2AModel(opt)
    trainset = create_dataset(opt.dataset.metainfo, opt.dataset.train)
    validset = create_dataset(opt.dataset.metainfo, opt.dataset.valid)
    return model, trainset, validset


def _fit_initial(trainer, model, trainset, validset, opt):
    """Run the first fit, resuming from the latest checkpoint if requested."""
    if opt.model.is_continue:
        trainer.fit(model, trainset, validset, ckpt_path=_latest_checkpoint())
    else:
        trainer.fit(model, trainset, validset)


def _adjust_sampling(opt):
    """Apply one incremental-sampling step to the config, in place.

    "Squared": double the pixel samples per image and halve all ray-sample
    counts.  "Linear": add/subtract fixed amounts; more than 3 steps would
    drive sample counts negative, so that is rejected up front.
    """
    sampler = opt.model.ray_sampler
    if opt.model.increment_profile == "Squared":
        opt.dataset.train.num_sample = int(opt.dataset.train.num_sample * 2)
        sampler.N_samples = int(sampler.N_samples / 2)
        sampler.N_samples_eval = int(sampler.N_samples_eval / 2)
        sampler.N_samples_extra = int(sampler.N_samples_extra / 2)
    if opt.model.increment_profile == "Linear":
        if opt.model.incremental_sampling_steps > 3:
            raise ValueError("The training will result in a negative number of samples, please adjust the values in train.py accordingly.")
        opt.dataset.train.num_sample = int(opt.dataset.train.num_sample + 1024)
        sampler.N_samples = int(sampler.N_samples - 16)
        sampler.N_samples_eval = int(sampler.N_samples_eval - 32)
        sampler.N_samples_extra = int(sampler.N_samples_extra - 8)


@hydra.main(config_path="confs", config_name="base")
def main(opt):
    """Train a V2A model, optionally with an incremental-sampling schedule.

    Without incremental sampling: a single fit to 8000 epochs.  With it:
    a warm-up stage, then one resumed stage per increment step (each stage
    mutates the sampling config and rebuilds model/datasets/trainer), then
    a final resumed stage to 8000 epochs.  Checkpoints are written to
    checkpoints/ under Hydra's working directory.
    """
    pl.seed_everything(42)
    print("Working dir:", os.getcwd())

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath="checkpoints/",
        filename="{epoch:04d}-{loss}",
        save_on_train_epoch_end=True,
        save_last=True)
    logger = WandbLogger(project=opt.project_name, name=f"{opt.exp}/{opt.run}")

    if not opt.model.incremental_sampling:
        # Plain single-stage training.
        trainer = _build_trainer(checkpoint_callback, logger, max_epochs=8000)
        model, trainset, validset = _build_model_and_data(opt)
        _fit_initial(trainer, model, trainset, validset, opt)
    else:
        # Warm-up stage: train for one increment interval at the initial
        # sampling configuration.
        trainer = _build_trainer(
            checkpoint_callback, logger,
            max_epochs=opt.model.epochs_increment_interval)
        model, trainset, validset = _build_model_and_data(opt)
        _fit_initial(trainer, model, trainset, validset, opt)

        # One resumed stage per increment step; stage i trains up to
        # (i + 2) * interval total epochs (warm-up counted as the first).
        for i in range(opt.model.incremental_sampling_steps):
            _adjust_sampling(opt)
            trainer = _build_trainer(
                checkpoint_callback, logger,
                max_epochs=(i + 2) * opt.model.epochs_increment_interval)
            model, trainset, validset = _build_model_and_data(opt)
            trainer.fit(model, trainset, validset,
                        ckpt_path=_latest_checkpoint())

        # Final stage: continue from the last checkpoint to the full budget.
        trainer = _build_trainer(checkpoint_callback, logger, max_epochs=8000)
        model, trainset, validset = _build_model_and_data(opt)
        trainer.fit(model, trainset, validset, ckpt_path=_latest_checkpoint())
|
|
# Script entry point: hydra parses the CLI overrides and injects the config.
if __name__ == '__main__':
    main()