| import os, torch |
| import wandb |
| from tqdm import tqdm |
| from accelerate import Accelerator |
| from .training_module import DiffusionTrainingModule |
| from .logger import ModelLogger |
|
|
|
|
def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
):
    """Run the full training loop for a diffusion training module.

    Builds an AdamW optimizer with a constant LR schedule, wraps everything
    with `accelerator.prepare`, iterates the dataset for `num_epochs`, and
    delegates checkpointing to `model_logger`. Optionally logs loss/LR to
    Weights & Biases on the main process.

    Args:
        accelerator: Accelerate handle for distributed training / mixed precision.
        dataset: Dataset yielding one sample per item; `collate_fn` unwraps the
            singleton batch, so the model receives the raw sample.
        model: Training module; must expose `trainable_modules()` and return a
            scalar loss from its forward call.
        model_logger: Checkpoint/logging hooks (`on_step_end`, `on_epoch_end`,
            `on_training_end`).
        learning_rate: AdamW learning rate (overridden by `args`).
        weight_decay: AdamW weight decay (overridden by `args`).
        num_workers: DataLoader worker count (overridden by `args`).
        save_steps: Save a checkpoint every N steps; `None` saves per epoch.
        num_epochs: Number of passes over the dataset.
        args: Optional parsed CLI namespace; when given, its fields override
            the keyword defaults above.
    """
    # CLI args, when provided, take precedence over keyword defaults.
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    # factor=1.0 keeps the LR truly constant. ConstantLR's defaults
    # (factor=1/3, total_iters=5) would silently train the first 5 steps
    # at lr/3 before jumping to the configured lr.
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
    # batch_size defaults to 1; collate_fn unwraps the singleton list so the
    # model sees the dataset item itself rather than a batch wrapper.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )

    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    global_step = 0
    for epoch_id in range(num_epochs):
        pbar = tqdm(
            dataloader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch_id}",
        )
        for data in pbar:
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                # Cached datasets yield pre-computed model inputs; pass them
                # through the `inputs=` path with an empty data dict.
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()

            global_step += 1

            # Periodic wandb logging, main process only.
            # NOTE(review): assumes wandb.init() was called by the caller when
            # wandb_mode != "disabled" — confirm against the launch script.
            if (
                args is not None
                and hasattr(args, "wandb_mode")
                and args.wandb_mode != "disabled"
                and accelerator.is_main_process
            ):
                log_every = getattr(args, "wandb_log_every", 10)
                # log_every > 0 guards against ZeroDivisionError when the
                # caller passes wandb_log_every=0 to disable periodic logging.
                if log_every > 0 and global_step % log_every == 0:
                    loss_value = loss.detach().float().item()
                    try:
                        lr = scheduler.get_last_lr()[0]
                    except Exception:
                        # Fallback for schedulers/wrappers without get_last_lr.
                        lr = optimizer.param_groups[0]["lr"]
                    wandb.log(
                        {
                            "train/loss": loss_value,
                            "train/lr": lr,
                            "train/epoch": epoch_id,
                            "train/step": global_step,
                        }
                    )

        # With no step-based saving configured, checkpoint once per epoch.
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)
|
|
|
|
def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
    """Run the model over the dataset once and cache each output to disk.

    Each process writes its shard of results to
    ``{model_logger.output_path}/{process_index}/{data_id}.pth`` so later
    training runs can load pre-processed inputs instead of recomputing them.

    Args:
        accelerator: Accelerate handle; used for sharding and process index.
        dataset: Dataset yielding one sample per item (singleton batches are
            unwrapped by `collate_fn`).
        model: Module whose forward output is the cached artifact.
        model_logger: Supplies `output_path`, the cache root directory.
        num_workers: DataLoader worker count (overridden by `args`).
        args: Optional parsed CLI namespace; overrides `num_workers`.
    """
    if args is not None:
        num_workers = args.dataset_num_workers

    # shuffle=False keeps data_id stable across runs for a given shard.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )
    model, dataloader = accelerator.prepare(model, dataloader)

    # One output folder per process; create it once instead of on every
    # iteration as before (the path is loop-invariant).
    folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
    os.makedirs(folder, exist_ok=True)

    for data_id, data in enumerate(tqdm(
        dataloader,
        disable=not accelerator.is_local_main_process,
        desc="Data process",
    )):
        # NOTE(review): accumulate() looks unnecessary here since no gradients
        # flow under no_grad; kept to preserve behavior — confirm and remove.
        with accelerator.accumulate(model):
            with torch.no_grad():
                data = model(data)
                torch.save(data, os.path.join(folder, f"{data_id}.pth"))
|
|