import os

import torch
import wandb
from tqdm import tqdm
from accelerate import Accelerator

from .training_module import DiffusionTrainingModule
from .logger import ModelLogger


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args=None,
):
    """Run the training loop for `model` on `dataset` under `accelerator`.

    Args:
        accelerator: Prepared `accelerate.Accelerator` handling device placement,
            gradient accumulation, and multi-process coordination.
        dataset: Dataset yielding one sample per batch (`collate_fn` unwraps the
            single-element list produced by the DataLoader).
        model: Training module exposing `trainable_modules()` and a forward that
            returns the loss.
        model_logger: Checkpoint/logging hook with `on_step_end`, `on_epoch_end`,
            and `on_training_end` callbacks.
        learning_rate: AdamW learning rate (overridden by `args` if given).
        weight_decay: AdamW weight decay (overridden by `args` if given).
        num_workers: DataLoader worker count (overridden by `args` if given).
        save_steps: Checkpoint every N steps; when None, checkpoint per epoch.
        num_epochs: Number of epochs to train.
        args: Optional namespace; when provided, its fields override the
            corresponding keyword arguments above and control wandb logging
            (`wandb_mode`, `wandb_log_every`).
    """
    # CLI args, when supplied, take precedence over the keyword defaults.
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(
        model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        # Batch size is effectively 1: unwrap the single sample from the list.
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    global_step = 0  # global step counter used for wandb logging
    for epoch_id in range(num_epochs):
        # Show the progress bar only on the local main process to avoid
        # duplicated bars on multi-GPU runs.
        pbar = tqdm(
            dataloader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch_id}",
        )
        for data in pbar:
            with accelerator.accumulate(model):
                # Cached datasets deliver precomputed model inputs directly.
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()
                # Zero gradients AFTER the optimizer step (accelerate makes this
                # a no-op on non-sync iterations). Zeroing at the top of the
                # iteration would wipe gradients accumulated from earlier
                # micro-batches when gradient accumulation > 1.
                optimizer.zero_grad()
            global_step += 1

            # ============= wandb logging (main process only) =============
            if (
                args is not None
                and hasattr(args, "wandb_mode")
                and args.wandb_mode != "disabled"
                and accelerator.is_main_process
            ):
                log_every = getattr(args, "wandb_log_every", 10)
                if global_step % log_every == 0:
                    # The current process's loss is representative enough here;
                    # no cross-process reduction is performed.
                    loss_value = loss.detach().float().item()
                    try:
                        lr = scheduler.get_last_lr()[0]
                    except Exception:
                        # Fallback for schedulers without get_last_lr support.
                        lr = optimizer.param_groups[0]["lr"]
                    wandb.log(
                        {
                            "train/loss": loss_value,
                            "train/lr": lr,
                            "train/epoch": epoch_id,
                            "train/step": global_step,
                        }
                    )
            # =============================================================

        # Without step-based checkpointing, save once per epoch instead.
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)


def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args=None,
):
    """Run `model` over `dataset` once (no gradients) and cache the outputs.

    Each process writes its results to
    `<output_path>/<process_index>/<data_id>.pth` so multi-process runs do not
    collide.

    Args:
        accelerator: `accelerate.Accelerator` used for device placement and
            per-process output partitioning.
        dataset: Dataset to iterate (not shuffled, to keep ids deterministic).
        model: Module whose forward produces the data to cache.
        model_logger: Provides `output_path`, the root directory for caches.
        num_workers: DataLoader worker count (overridden by `args` if given).
        args: Optional namespace; `args.dataset_num_workers` overrides
            `num_workers` when provided.
    """
    if args is not None:
        num_workers = args.dataset_num_workers
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )
    model, dataloader = accelerator.prepare(model, dataloader)
    for data_id, data in enumerate(tqdm(
        dataloader,
        disable=not accelerator.is_local_main_process,
        desc="Data process",
    )):
        with accelerator.accumulate(model):
            with torch.no_grad():
                # One sub-directory per process keeps parallel writers separate.
                folder = os.path.join(
                    model_logger.output_path, str(accelerator.process_index)
                )
                os.makedirs(folder, exist_ok=True)
                save_path = os.path.join(folder, f"{data_id}.pth")
                data = model(data)
                torch.save(data, save_path)