import argparse
import os
import random

import pytorch_lightning as pl
import torch
from braceexpand import braceexpand
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset, DataLoader

from cldm.hack import disable_verbosity, enable_sliced_attention
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict
from datasets.base import BaseDataset
from datasets.webdataset import MultiWebDataset
|
|
class BaseLogic(BaseDataset):
    """Configuration holder layered on top of BaseDataset.

    Stores the collage-construction thresholds that the webdataset pipeline
    reads via the inherited ``_construct_collage`` callback (BaseLogic itself
    defines no methods beyond ``__init__``).
    """

    def __init__(self, area_ratio, obj_thr):
        # NOTE(review): super().__init__() is intentionally not called here,
        # matching the original code — confirm BaseDataset needs no setup.
        self.area_ratio = area_ratio  # minimum/target area ratio for collage crops
        self.obj_thr = obj_thr        # object-count threshold
|
|
# Report CUDA visibility at import time. Guarded with is_available(): the
# original called current_device()/get_device_name(0) unconditionally, which
# raises on CPU-only machines before training can even start.
if torch.cuda.is_available():
    print("Number of GPUs available: ", torch.cuda.device_count())
    print("Current device: ", torch.cuda.current_device())
    print("Device name: ", torch.cuda.get_device_name(0))
else:
    print("No CUDA device available; running on CPU.")
|
|
def get_args_parser():
    """Build the CLI argument parser for the PICS training script.

    Returns:
        argparse.ArgumentParser: option parser built with ``add_help=False``
        so it can be used as a parent parser in the ``__main__`` entry point.
    """
    parser = argparse.ArgumentParser('PICS Training Script', add_help=False)

    # Fix: argparse's `required` expects a bool; the original passed
    # `required=None` (worked only because None is falsy).
    parser.add_argument('--resume_path', required=False, default=None, type=str,
                        help="Checkpoint to resume from (optional)")
    parser.add_argument('--root_dir', required=True, type=str,
                        help="Root directory for logs and checkpoints")
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--limit_train_batches', default=1, type=float)
    parser.add_argument('--logger_freq', default=1000, type=int)
    parser.add_argument('--learning_rate', default=1e-5, type=float)
    # Typo fix in help text: "Seprate" -> "Separate".
    parser.add_argument('--is_joint', action='store_true',
                        help="Joint/Separate training")
    parser.add_argument("--dataset_name", type=str, default='lvis',
                        help="Dataset name")

    return parser
|
|
def main(args):
    """Train the PICS model on a weighted mix of webdataset shard sources.

    Args:
        args: namespace produced by ``get_args_parser()`` (resume_path,
            root_dir, batch_size, limit_train_batches, logger_freq,
            learning_rate, is_joint, dataset_name).
    """
    save_memory = False
    disable_verbosity()
    if save_memory:
        # Sliced attention trades speed for lower peak VRAM.
        enable_sliced_attention()

    # Fixed training hyper-parameters (not exposed on the CLI).
    sd_locked = False
    only_mid_control = False
    accumulate_grad_batches = 1
    obj_thr = {'obj_thr': 2}

    model = create_model('./configs/pics.yaml').cpu()
    if args.resume_path and os.path.exists(args.resume_path):
        print(f"Loading checkpoint from: {args.resume_path}")
        checkpoint = load_state_dict(args.resume_path, location='cpu')
        # strict=False: tolerate missing/extra keys in the checkpoint.
        model.load_state_dict(checkpoint, strict=False)
    else:
        print("No checkpoint found or provided. Training from scratch...")

    model.learning_rate = args.learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

    DConf = OmegaConf.load('./configs/datasets.yaml')

    weights = _get_dataset_weights(args)
    all_urls = _collect_shard_urls(DConf, weights)

    logic_helper = BaseLogic(
        area_ratio=DConf.Defaults.area_ratio,
        obj_thr=DConf.Defaults.obj_thr,
    )

    dataset = MultiWebDataset(
        urls=all_urls,
        construct_collage_fn=logic_helper._construct_collage,
        shuffle_size=10000,
        seed=42,
        decode_mode="pil",
    )

    dataloader = DataLoader(
        dataset,
        num_workers=8,
        batch_size=args.batch_size,
    )

    logger = ImageLogger(batch_frequency=args.logger_freq, log_images_kwargs=obj_thr)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(args.root_dir, 'checkpoints'),
        filename='pics-{step:06d}',
        every_n_train_steps=2000,
        save_top_k=-1,  # keep every periodic checkpoint
    )

    trainer = pl.Trainer(
        default_root_dir=args.root_dir,
        limit_train_batches=args.limit_train_batches,
        accelerator="gpu",
        devices=1,
        precision=16,
        callbacks=[logger, checkpoint_callback],
        accumulate_grad_batches=accumulate_grad_batches,
        max_epochs=50,
        val_check_interval=2000,
    )
    trainer.fit(model, dataloader)


def _get_dataset_weights(args):
    """Return the shard-replication weight for each training dataset.

    Joint training mixes all six datasets with hand-tuned ratios; otherwise
    exactly the dataset selected by --dataset_name gets weight 1 and the
    rest get 0.

    Raises:
        ValueError: if ``args.dataset_name`` is not a recognized dataset.
    """
    if args.is_joint:
        return {'LVIS': 3, 'VITONHD': 6, 'Objects365': 1,
                'Cityscapes': 18, 'MapillaryVistas': 18, 'BDD100K': 18}

    # CLI name (lowercase) -> canonical dataset key.
    name_map = {
        'lvis': 'LVIS',
        'vitonhd': 'VITONHD',
        'object365': 'Objects365',
        'cityscapes': 'Cityscapes',
        'mapillaryvistas': 'MapillaryVistas',
        'bdd100k': 'BDD100K',
    }
    if args.dataset_name not in name_map:
        raise ValueError(f"Unsupported dataset name: {args.dataset_name}")
    active = name_map[args.dataset_name]
    return {name: (1 if name == active else 0) for name in name_map.values()}


def _collect_shard_urls(DConf, weights):
    """Expand each dataset's brace-pattern shard spec and replicate by weight.

    Datasets with weight 0 contribute no shards. Returns a deterministically
    shuffled list of shard URLs.
    """
    dataset_shards = [
        ('LVIS', DConf.Train.LVIS.shards),
        ('VITONHD', DConf.Train.VITONHD.shards),
        ('Objects365', DConf.Train.Objects365.shards),
        ('Cityscapes', DConf.Train.Cityscapes.shards),
        ('MapillaryVistas', DConf.Train.MapillaryVistas.shards),
        ('BDD100K', DConf.Train.BDD100K.shards),
    ]
    all_urls = []
    for name, path in dataset_shards:
        all_urls.extend(list(braceexpand(path)) * weights.get(name, 1))
    # Fix: the original shuffled with the unseeded global RNG, making the
    # shard order irreproducible even though the dataset itself uses seed=42.
    # A private seeded RNG keeps runs repeatable without touching global state.
    random.Random(42).shuffle(all_urls)
    return all_urls
|
|
|
|
if __name__ == '__main__':
    # Compose the CLI from the shared option parser, then hand off to main().
    cli = argparse.ArgumentParser('PICS Training', parents=[get_args_parser()])
    main(cli.parse_args())
|
|
|
|